diff --git a/include/linux/filter.h b/include/linux/filter.h
index d29e58fde364f01168c059f85faac389622a3fc3..818a0b26249ea2a836638204226cd73b106e7ccb 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -728,7 +728,7 @@ void xdp_do_flush_map(void);
 void bpf_warn_invalid_xdp_action(u32 act);
 void bpf_warn_invalid_xdp_redirect(u32 ifindex);
 
-struct sock *do_sk_redirect_map(void);
+struct sock *do_sk_redirect_map(struct sk_buff *skb);
 
 #ifdef CONFIG_BPF_JIT
 extern int bpf_jit_enable;
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 89974c5286d8ae4db63def6822ae5b0dc93cddf9..b1ef98ebce53cd259a2894f4575cb18cc7755331 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -840,6 +840,11 @@ struct tcp_skb_cb {
 			struct inet6_skb_parm	h6;
 #endif
 		} header;	/* For incoming skbs */
+		struct {
+			__u32 key;
+			__u32 flags;
+			struct bpf_map *map;
+		} bpf;
 	};
 };
 
diff --git a/kernel/bpf/devmap.c b/kernel/bpf/devmap.c
index 920428d84da23c224035bb82172c425e7aa0efae..52e0548ba548713b4e152b688f5597dcf91f2f88 100644
--- a/kernel/bpf/devmap.c
+++ b/kernel/bpf/devmap.c
@@ -78,6 +78,9 @@ static struct bpf_map *dev_map_alloc(union bpf_attr *attr)
 	int err = -EINVAL;
 	u64 cost;
 
+	if (!capable(CAP_NET_ADMIN))
+		return ERR_PTR(-EPERM);
+
 	/* check sanity of attributes */
 	if (attr->max_entries == 0 || attr->key_size != 4 ||
 	    attr->value_size != 4 || attr->map_flags & ~BPF_F_NUMA_NODE)
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 6424ce0e49698abee1da6242a0a8494d8ba0f03f..2b6eb35ae5d39799a9ae33fe5bf07508d2b82a54 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -39,6 +39,7 @@
 #include <linux/workqueue.h>
 #include <linux/list.h>
 #include <net/strparser.h>
+#include <net/tcp.h>
 
 struct bpf_stab {
 	struct bpf_map map;
@@ -101,9 +102,16 @@ static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
 		return SK_DROP;
 
 	skb_orphan(skb);
+	/* We need to ensure that BPF metadata for maps is also cleared
+	 * when we orphan the skb so that we don't have the possibility
+	 * to reference a stale map.
+	 */
+	TCP_SKB_CB(skb)->bpf.map = NULL;
 	skb->sk = psock->sock;
 	bpf_compute_data_end(skb);
+	preempt_disable();
 	rc = (*prog->bpf_func)(skb, prog->insnsi);
+	preempt_enable();
 	skb->sk = NULL;
 
 	return rc;
@@ -114,17 +122,10 @@ static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
 	struct sock *sk;
 	int rc;
 
-	/* Because we use per cpu values to feed input from sock redirect
-	 * in BPF program to do_sk_redirect_map() call we need to ensure we
-	 * are not preempted. RCU read lock is not sufficient in this case
-	 * with CONFIG_PREEMPT_RCU enabled so we must be explicit here.
-	 */
-	preempt_disable();
 	rc = smap_verdict_func(psock, skb);
 	switch (rc) {
 	case SK_REDIRECT:
-		sk = do_sk_redirect_map();
-		preempt_enable();
+		sk = do_sk_redirect_map(skb);
 		if (likely(sk)) {
 			struct smap_psock *peer = smap_psock_sk(sk);
 
@@ -141,8 +142,6 @@ static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
 	/* Fall through and free skb otherwise */
 	case SK_DROP:
 	default:
-		if (rc != SK_REDIRECT)
-			preempt_enable();
 		kfree_skb(skb);
 	}
 }
@@ -487,6 +486,9 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
 	int err = -EINVAL;
 	u64 cost;
 
+	if (!capable(CAP_NET_ADMIN))
+		return ERR_PTR(-EPERM);
+
 	/* check sanity of attributes */
 	if (attr->max_entries == 0 || attr->key_size != 4 ||
 	    attr->value_size != 4 || attr->map_flags & ~BPF_F_NUMA_NODE)
@@ -840,6 +842,12 @@ static int sock_map_update_elem(struct bpf_map *map,
 		return -EINVAL;
 	}
 
+	if (skops.sk->sk_type != SOCK_STREAM ||
+	    skops.sk->sk_protocol != IPPROTO_TCP) {
+		fput(socket->file);
+		return -EOPNOTSUPP;
+	}
+
 	err = sock_map_ctx_update_elem(&skops, map, key, flags);
 	fput(socket->file);
 	return err;
diff --git a/net/core/filter.c b/net/core/filter.c
index 74b8c91fb5f4461da58c73568976bc9834c4612b..aa0265997f930c86229ee1658063ba9718c6dc79 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1839,31 +1839,31 @@ static const struct bpf_func_proto bpf_redirect_proto = {
 	.arg2_type      = ARG_ANYTHING,
 };
 
-BPF_CALL_3(bpf_sk_redirect_map, struct bpf_map *, map, u32, key, u64, flags)
+BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
+	   struct bpf_map *, map, u32, key, u64, flags)
 {
-	struct redirect_info *ri = this_cpu_ptr(&redirect_info);
+	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
 
 	if (unlikely(flags))
 		return SK_ABORTED;
 
-	ri->ifindex = key;
-	ri->flags = flags;
-	ri->map = map;
+	tcb->bpf.key = key;
+	tcb->bpf.flags = flags;
+	tcb->bpf.map = map;
 
 	return SK_REDIRECT;
 }
 
-struct sock *do_sk_redirect_map(void)
+struct sock *do_sk_redirect_map(struct sk_buff *skb)
 {
-	struct redirect_info *ri = this_cpu_ptr(&redirect_info);
+	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
 	struct sock *sk = NULL;
 
-	if (ri->map) {
-		sk = __sock_map_lookup_elem(ri->map, ri->ifindex);
+	if (tcb->bpf.map) {
+		sk = __sock_map_lookup_elem(tcb->bpf.map, tcb->bpf.key);
 
-		ri->ifindex = 0;
-		ri->map = NULL;
-		/* we do not clear flags for future lookup */
+		tcb->bpf.key = 0;
+		tcb->bpf.map = NULL;
 	}
 
 	return sk;
@@ -1873,9 +1873,10 @@ static const struct bpf_func_proto bpf_sk_redirect_map_proto = {
 	.func           = bpf_sk_redirect_map,
 	.gpl_only       = false,
 	.ret_type       = RET_INTEGER,
-	.arg1_type      = ARG_CONST_MAP_PTR,
-	.arg2_type      = ARG_ANYTHING,
+	.arg1_type	= ARG_PTR_TO_CTX,
+	.arg2_type      = ARG_CONST_MAP_PTR,
 	.arg3_type      = ARG_ANYTHING,
+	.arg4_type      = ARG_ANYTHING,
 };
 
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
@@ -3683,7 +3684,6 @@ static bool sk_skb_is_valid_access(int off, int size,
 {
 	if (type == BPF_WRITE) {
 		switch (off) {
-		case bpf_ctx_range(struct __sk_buff, mark):
 		case bpf_ctx_range(struct __sk_buff, tc_index):
 		case bpf_ctx_range(struct __sk_buff, priority):
 			break;
@@ -3693,6 +3693,7 @@ static bool sk_skb_is_valid_access(int off, int size,
 	}
 
 	switch (off) {
+	case bpf_ctx_range(struct __sk_buff, mark):
 	case bpf_ctx_range(struct __sk_buff, tc_classid):
 		return false;
 	case bpf_ctx_range(struct __sk_buff, data):
diff --git a/samples/sockmap/sockmap_kern.c b/samples/sockmap/sockmap_kern.c
index f9b38ef82dc2449e56094b9c8c0c805673c0a7b0..52b0053274f425a6a07eb70dfbce89ca9c32d4fb 100644
--- a/samples/sockmap/sockmap_kern.c
+++ b/samples/sockmap/sockmap_kern.c
@@ -62,7 +62,7 @@ int bpf_prog2(struct __sk_buff *skb)
 		ret = 1;
 
 	bpf_printk("sockmap: %d -> %d @ %d\n", lport, bpf_ntohl(rport), ret);
-	return bpf_sk_redirect_map(&sock_map, ret, 0);
+	return bpf_sk_redirect_map(skb, &sock_map, ret, 0);
 }
 
 SEC("sockops")
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 43ab5c402f98f3c2dc15ccc49ec475530a39e401..be9a631a69f78ca03d8514a798484a62840712d2 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -569,9 +569,10 @@ union bpf_attr {
  *     @flags: reserved for future use
  *     Return: 0 on success or negative error code
  *
- * int bpf_sk_redirect_map(map, key, flags)
+ * int bpf_sk_redirect_map(skb, map, key, flags)
  *     Redirect skb to a sock in map using key as a lookup key for the
  *     sock in map.
+ *     @skb: pointer to skb
  *     @map: pointer to sockmap
  *     @key: key to lookup sock in map
  *     @flags: reserved for future use
diff --git a/tools/testing/selftests/bpf/bpf_helpers.h b/tools/testing/selftests/bpf/bpf_helpers.h
index 36fb9161b34acf8c14a585f4778905db9185b163..b2e02bdcd098f6380c80bc190e3986acb1007188 100644
--- a/tools/testing/selftests/bpf/bpf_helpers.h
+++ b/tools/testing/selftests/bpf/bpf_helpers.h
@@ -65,7 +65,7 @@ static int (*bpf_xdp_adjust_head)(void *ctx, int offset) =
 static int (*bpf_setsockopt)(void *ctx, int level, int optname, void *optval,
 			     int optlen) =
 	(void *) BPF_FUNC_setsockopt;
-static int (*bpf_sk_redirect_map)(void *map, int key, int flags) =
+static int (*bpf_sk_redirect_map)(void *ctx, void *map, int key, int flags) =
 	(void *) BPF_FUNC_sk_redirect_map;
 static int (*bpf_sock_map_update)(void *map, void *key, void *value,
 				  unsigned long long flags) =
diff --git a/tools/testing/selftests/bpf/sockmap_verdict_prog.c b/tools/testing/selftests/bpf/sockmap_verdict_prog.c
index 9b99bd10807d8990a82ba1a390e9f1304e04d214..2cd2d552938b19211da94c4b51ea7895408d35b0 100644
--- a/tools/testing/selftests/bpf/sockmap_verdict_prog.c
+++ b/tools/testing/selftests/bpf/sockmap_verdict_prog.c
@@ -61,8 +61,8 @@ int bpf_prog2(struct __sk_buff *skb)
 	bpf_printk("verdict: data[0] = redir(%u:%u)\n", map, sk);
 
 	if (!map)
-		return bpf_sk_redirect_map(&sock_map_rx, sk, 0);
-	return bpf_sk_redirect_map(&sock_map_tx, sk, 0);
+		return bpf_sk_redirect_map(skb, &sock_map_rx, sk, 0);
+	return bpf_sk_redirect_map(skb, &sock_map_tx, sk, 0);
 }
 
 char _license[] SEC("license") = "GPL";
diff --git a/tools/testing/selftests/bpf/test_maps.c b/tools/testing/selftests/bpf/test_maps.c
index fe3a443a110228efb99bc893d2f0820cd03e6763..50ce52d2013d6feddadcb6a2f61cdc4667b3d78d 100644
--- a/tools/testing/selftests/bpf/test_maps.c
+++ b/tools/testing/selftests/bpf/test_maps.c
@@ -466,7 +466,7 @@ static void test_sockmap(int tasks, void *data)
 	int one = 1, map_fd_rx, map_fd_tx, map_fd_break, s, sc, rc;
 	struct bpf_map *bpf_map_rx, *bpf_map_tx, *bpf_map_break;
 	int ports[] = {50200, 50201, 50202, 50204};
-	int err, i, fd, sfd[6] = {0xdeadbeef};
+	int err, i, fd, udp, sfd[6] = {0xdeadbeef};
 	u8 buf[20] = {0x0, 0x5, 0x3, 0x2, 0x1, 0x0};
 	int parse_prog, verdict_prog;
 	struct sockaddr_in addr;
@@ -548,6 +548,16 @@ static void test_sockmap(int tasks, void *data)
 		goto out_sockmap;
 	}
 
+	/* Test update with unsupported UDP socket */
+	udp = socket(AF_INET, SOCK_DGRAM, 0);
+	i = 0;
+	err = bpf_map_update_elem(fd, &i, &udp, BPF_ANY);
+	if (!err) {
+		printf("Failed socket SOCK_DGRAM allowed '%i:%i'\n",
+		       i, udp);
+		goto out_sockmap;
+	}
+
 	/* Test update without programs */
 	for (i = 0; i < 6; i++) {
 		err = bpf_map_update_elem(fd, &i, &sfd[i], BPF_ANY);
diff --git a/tools/testing/selftests/bpf/test_verifier.c b/tools/testing/selftests/bpf/test_verifier.c
index 3c7d3a45a3c551b291fb715354754504c62ab8c8..50e15cedbb7ffad813cfc737d7618f198357f3c8 100644
--- a/tools/testing/selftests/bpf/test_verifier.c
+++ b/tools/testing/selftests/bpf/test_verifier.c
@@ -1130,15 +1130,27 @@ static struct bpf_test tests[] = {
 		.errstr = "invalid bpf_context access",
 	},
 	{
-		"check skb->mark is writeable by SK_SKB",
+		"invalid access of skb->mark for SK_SKB",
+		.insns = {
+			BPF_LDX_MEM(BPF_W, BPF_REG_0, BPF_REG_1,
+				    offsetof(struct __sk_buff, mark)),
+			BPF_EXIT_INSN(),
+		},
+		.result =  REJECT,
+		.prog_type = BPF_PROG_TYPE_SK_SKB,
+		.errstr = "invalid bpf_context access",
+	},
+	{
+		"check skb->mark is not writeable by SK_SKB",
 		.insns = {
 			BPF_MOV64_IMM(BPF_REG_0, 0),
 			BPF_STX_MEM(BPF_W, BPF_REG_1, BPF_REG_0,
 				    offsetof(struct __sk_buff, mark)),
 			BPF_EXIT_INSN(),
 		},
-		.result = ACCEPT,
+		.result =  REJECT,
 		.prog_type = BPF_PROG_TYPE_SK_SKB,
+		.errstr = "invalid bpf_context access",
 	},
 	{
 		"check skb->tc_index is writeable by SK_SKB",