diff --git a/include/linux/cgroup-defs.h b/include/linux/cgroup-defs.h
index 504d8591b6d3c18ba1ca2f432af235739ac77276..ed128fed03356e48fac3ff5e1aab06ff58051540 100644
--- a/include/linux/cgroup-defs.h
+++ b/include/linux/cgroup-defs.h
@@ -542,4 +542,40 @@ static inline void cgroup_threadgroup_change_end(struct task_struct *tsk) {}
 
 #endif	/* CONFIG_CGROUPS */
 
+#ifdef CONFIG_SOCK_CGROUP_DATA
+
+struct sock_cgroup_data {
+	u16	prioidx;
+	u32	classid;
+};
+
+static inline u16 sock_cgroup_prioidx(struct sock_cgroup_data *skcd)
+{
+	return skcd->prioidx;
+}
+
+static inline u32 sock_cgroup_classid(struct sock_cgroup_data *skcd)
+{
+	return skcd->classid;
+}
+
+static inline void sock_cgroup_set_prioidx(struct sock_cgroup_data *skcd,
+					   u16 prioidx)
+{
+	skcd->prioidx = prioidx;
+}
+
+static inline void sock_cgroup_set_classid(struct sock_cgroup_data *skcd,
+					   u32 classid)
+{
+	skcd->classid = classid;
+}
+
+#else	/* CONFIG_SOCK_CGROUP_DATA */
+
+struct sock_cgroup_data {
+};
+
+#endif	/* CONFIG_SOCK_CGROUP_DATA */
+
 #endif	/* _LINUX_CGROUP_DEFS_H */
diff --git a/include/net/cls_cgroup.h b/include/net/cls_cgroup.h
index ccd6d8bffa4d8d0744c70c3591f948d6520b634a..c0a92e2c286d6cc0942591dfc02d9b485ae6810e 100644
--- a/include/net/cls_cgroup.h
+++ b/include/net/cls_cgroup.h
@@ -41,13 +41,12 @@ static inline u32 task_cls_classid(struct task_struct *p)
 	return classid;
 }
 
-static inline void sock_update_classid(struct sock *sk)
+static inline void sock_update_classid(struct sock_cgroup_data *skcd)
 {
 	u32 classid;
 
 	classid = task_cls_classid(current);
-	if (classid != sk->sk_classid)
-		sk->sk_classid = classid;
+	sock_cgroup_set_classid(skcd, classid);
 }
 
 static inline u32 task_get_classid(const struct sk_buff *skb)
@@ -64,17 +63,17 @@ static inline u32 task_get_classid(const struct sk_buff *skb)
 	 * softirqs always disables bh.
 	 */
 	if (in_serving_softirq()) {
-		/* If there is an sk_classid we'll use that. */
+		/* If there is an sock_cgroup_classid we'll use that. */
 		if (!skb->sk)
 			return 0;
 
-		classid = skb->sk->sk_classid;
+		classid = sock_cgroup_classid(&skb->sk->sk_cgrp_data);
 	}
 
 	return classid;
 }
 #else /* !CONFIG_CGROUP_NET_CLASSID */
-static inline void sock_update_classid(struct sock *sk)
+static inline void sock_update_classid(struct sock_cgroup_data *skcd)
 {
 }
 
diff --git a/include/net/netprio_cgroup.h b/include/net/netprio_cgroup.h
index f2a9597ff53c089cb3aa4a810e6286355c9f194c..604190596cde8c1b2cc719800b32ab98803df41d 100644
--- a/include/net/netprio_cgroup.h
+++ b/include/net/netprio_cgroup.h
@@ -25,8 +25,6 @@ struct netprio_map {
 	u32 priomap[];
 };
 
-void sock_update_netprioidx(struct sock *sk);
-
 static inline u32 task_netprioidx(struct task_struct *p)
 {
 	struct cgroup_subsys_state *css;
@@ -38,13 +36,25 @@ static inline u32 task_netprioidx(struct task_struct *p)
 	rcu_read_unlock();
 	return idx;
 }
+
+static inline void sock_update_netprioidx(struct sock_cgroup_data *skcd)
+{
+	if (in_interrupt())
+		return;
+
+	sock_cgroup_set_prioidx(skcd, task_netprioidx(current));
+}
+
 #else /* !CONFIG_CGROUP_NET_PRIO */
+
 static inline u32 task_netprioidx(struct task_struct *p)
 {
 	return 0;
 }
 
-#define sock_update_netprioidx(sk)
+static inline void sock_update_netprioidx(struct sock_cgroup_data *skcd)
+{
+}
 
 #endif /* CONFIG_CGROUP_NET_PRIO */
 #endif  /* _NET_CLS_CGROUP_H */
diff --git a/include/net/sock.h b/include/net/sock.h
index a95bcf7d6efaed67befba4b687111d347b3966fd..0ca22b014de1a0a31e539a89e35fa5cc14875384 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -59,6 +59,7 @@
 #include <linux/static_key.h>
 #include <linux/sched.h>
 #include <linux/wait.h>
+#include <linux/cgroup-defs.h>
 
 #include <linux/filter.h>
 #include <linux/rculist_nulls.h>
@@ -308,8 +309,7 @@ struct cg_proto;
   *	@sk_send_head: front of stuff to transmit
   *	@sk_security: used by security modules
   *	@sk_mark: generic packet mark
-  *	@sk_cgrp_prioidx: socket group's priority map index
-  *	@sk_classid: this socket's cgroup classid
+  *	@sk_cgrp_data: cgroup data for this cgroup
   *	@sk_cgrp: this socket's cgroup-specific proto data
   *	@sk_write_pending: a write to stream socket waits to start
   *	@sk_state_change: callback to indicate change in the state of the sock
@@ -443,12 +443,7 @@ struct sock {
 #ifdef CONFIG_SECURITY
 	void			*sk_security;
 #endif
-#if IS_ENABLED(CONFIG_CGROUP_NET_PRIO)
-	u16			sk_cgrp_prioidx;
-#endif
-#ifdef CONFIG_CGROUP_NET_CLASSID
-	u32			sk_classid;
-#endif
+	struct sock_cgroup_data	sk_cgrp_data;
 	struct cg_proto		*sk_cgrp;
 	void			(*sk_state_change)(struct sock *sk);
 	void			(*sk_data_ready)(struct sock *sk);
diff --git a/net/Kconfig b/net/Kconfig
index 127da94ae25eb73e8ffd45e7e7dbc9e07d937033..11f8c22af34d09f59755c8c87de137e1879e6845 100644
--- a/net/Kconfig
+++ b/net/Kconfig
@@ -250,9 +250,14 @@ config XPS
 	depends on SMP
 	default y
 
+config SOCK_CGROUP_DATA
+	bool
+	default n
+
 config CGROUP_NET_PRIO
 	bool "Network priority cgroup"
 	depends on CGROUPS
+	select SOCK_CGROUP_DATA
 	---help---
 	  Cgroup subsystem for use in assigning processes to network priorities on
 	  a per-interface basis.
@@ -260,6 +265,7 @@ config CGROUP_NET_PRIO
 config CGROUP_NET_CLASSID
 	bool "Network classid cgroup"
 	depends on CGROUPS
+	select SOCK_CGROUP_DATA
 	---help---
 	  Cgroup subsystem for use as general purpose socket classid marker that is
 	  being used in cls_cgroup and for netfilter matching.
diff --git a/net/core/dev.c b/net/core/dev.c
index e5c395473eba800fb9fe7839ffdc19756d898263..8f705fcedb94b0e4ef56cdf70b363fdd396a14ee 100644
--- a/net/core/dev.c
+++ b/net/core/dev.c
@@ -2929,7 +2929,8 @@ static void skb_update_prio(struct sk_buff *skb)
 	struct netprio_map *map = rcu_dereference_bh(skb->dev->priomap);
 
 	if (!skb->priority && skb->sk && map) {
-		unsigned int prioidx = skb->sk->sk_cgrp_prioidx;
+		unsigned int prioidx =
+			sock_cgroup_prioidx(&skb->sk->sk_cgrp_data);
 
 		if (prioidx < map->priomap_len)
 			skb->priority = map->priomap[prioidx];
diff --git a/net/core/netclassid_cgroup.c b/net/core/netclassid_cgroup.c
index 2e4df84c34a194ad61d818f282d1124521b9a828..e60ded46b3ac5fff089402807a4eddc77217611f 100644
--- a/net/core/netclassid_cgroup.c
+++ b/net/core/netclassid_cgroup.c
@@ -62,8 +62,8 @@ static int update_classid_sock(const void *v, struct file *file, unsigned n)
 	struct socket *sock = sock_from_file(file, &err);
 
 	if (sock)
-		sock->sk->sk_classid = (u32)(unsigned long)v;
-
+		sock_cgroup_set_classid(&sock->sk->sk_cgrp_data,
+					(unsigned long)v);
 	return 0;
 }
 
diff --git a/net/core/netprio_cgroup.c b/net/core/netprio_cgroup.c
index 2b9159b7a28a43f1772963a9df361f49cc7f6a63..de42aa7f6c7702fc50e36475544c4075b4f3c453 100644
--- a/net/core/netprio_cgroup.c
+++ b/net/core/netprio_cgroup.c
@@ -223,7 +223,8 @@ static int update_netprio(const void *v, struct file *file, unsigned n)
 	int err;
 	struct socket *sock = sock_from_file(file, &err);
 	if (sock)
-		sock->sk->sk_cgrp_prioidx = (u32)(unsigned long)v;
+		sock_cgroup_set_prioidx(&sock->sk->sk_cgrp_data,
+					(unsigned long)v);
 	return 0;
 }
 
diff --git a/net/core/scm.c b/net/core/scm.c
index 8a1741b14302bd0cecdc265848feba8222400d17..14596fb3717270d62fa70544b7ec2496de96e1ce 100644
--- a/net/core/scm.c
+++ b/net/core/scm.c
@@ -289,8 +289,8 @@ void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm)
 		/* Bump the usage count and install the file. */
 		sock = sock_from_file(fp[i], &err);
 		if (sock) {
-			sock_update_netprioidx(sock->sk);
-			sock_update_classid(sock->sk);
+			sock_update_netprioidx(&sock->sk->sk_cgrp_data);
+			sock_update_classid(&sock->sk->sk_cgrp_data);
 		}
 		fd_install(new_fd, get_file(fp[i]));
 	}
diff --git a/net/core/sock.c b/net/core/sock.c
index 7965ef487375631035d9b5e72324a320a39a8c5f..947741dc43fa68cc68549347a6428bbc7fb1e1be 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1393,17 +1393,6 @@ static void sk_prot_free(struct proto *prot, struct sock *sk)
 	module_put(owner);
 }
 
-#if IS_ENABLED(CONFIG_CGROUP_NET_PRIO)
-void sock_update_netprioidx(struct sock *sk)
-{
-	if (in_interrupt())
-		return;
-
-	sk->sk_cgrp_prioidx = task_netprioidx(current);
-}
-EXPORT_SYMBOL_GPL(sock_update_netprioidx);
-#endif
-
 /**
  *	sk_alloc - All socket objects are allocated here
  *	@net: the applicable net namespace
@@ -1432,8 +1421,8 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
 		sock_net_set(sk, net);
 		atomic_set(&sk->sk_wmem_alloc, 1);
 
-		sock_update_classid(sk);
-		sock_update_netprioidx(sk);
+		sock_update_classid(&sk->sk_cgrp_data);
+		sock_update_netprioidx(&sk->sk_cgrp_data);
 	}
 
 	return sk;
diff --git a/net/netfilter/nft_meta.c b/net/netfilter/nft_meta.c
index 9dfaf4d55ee0b1bd92d8d7d2e64eebc42e8e5b04..1915cab7f32d34601498673cdbb16c42e5ef1ed7 100644
--- a/net/netfilter/nft_meta.c
+++ b/net/netfilter/nft_meta.c
@@ -174,7 +174,7 @@ void nft_meta_get_eval(const struct nft_expr *expr,
 		sk = skb_to_full_sk(skb);
 		if (!sk || !sk_fullsock(sk))
 			goto err;
-		*dest = sk->sk_classid;
+		*dest = sock_cgroup_classid(&sk->sk_cgrp_data);
 		break;
 #endif
 	default:
diff --git a/net/netfilter/xt_cgroup.c b/net/netfilter/xt_cgroup.c
index a1d126f2946305a10ccc04ce92e469b1255f60f9..54eaeb45ce996359ebd3c79adbe8f27d23ad7c1b 100644
--- a/net/netfilter/xt_cgroup.c
+++ b/net/netfilter/xt_cgroup.c
@@ -42,7 +42,8 @@ cgroup_mt(const struct sk_buff *skb, struct xt_action_param *par)
 	if (skb->sk == NULL || !sk_fullsock(skb->sk))
 		return false;
 
-	return (info->id == skb->sk->sk_classid) ^ info->invert;
+	return (info->id == sock_cgroup_classid(&skb->sk->sk_cgrp_data)) ^
+		info->invert;
 }
 
 static struct xt_match cgroup_mt_reg __read_mostly = {