diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 3ac3f8570ca1bc9c11ba1e6a4cd49bf170153941..a8f151b1b5fab52d061efdb3e49975076c93d3ac 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -45,17 +45,18 @@ static int wg_open(struct net_device *dev)
 	if (dev_v6)
 		dev_v6->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_NONE;
 
+	mutex_lock(&wg->device_update_lock);
 	ret = wg_socket_init(wg, wg->incoming_port);
 	if (ret < 0)
-		return ret;
-	mutex_lock(&wg->device_update_lock);
+		goto out;
 	list_for_each_entry(peer, &wg->peer_list, peer_list) {
 		wg_packet_send_staged_packets(peer);
 		if (peer->persistent_keepalive_interval)
 			wg_packet_send_keepalive(peer);
 	}
+out:
 	mutex_unlock(&wg->device_update_lock);
-	return 0;
+	return ret;
 }
 
 #ifdef CONFIG_PM_SLEEP
@@ -225,6 +226,7 @@ static void wg_destruct(struct net_device *dev)
 	list_del(&wg->device_list);
 	rtnl_unlock();
 	mutex_lock(&wg->device_update_lock);
+	rcu_assign_pointer(wg->creating_net, NULL);
 	wg->incoming_port = 0;
 	wg_socket_reinit(wg, NULL, NULL);
 	/* The final references are cleared in the below calls to destroy_workqueue. */
@@ -240,13 +242,11 @@ static void wg_destruct(struct net_device *dev)
 	skb_queue_purge(&wg->incoming_handshakes);
 	free_percpu(dev->tstats);
 	free_percpu(wg->incoming_handshakes_worker);
-	if (wg->have_creating_net_ref)
-		put_net(wg->creating_net);
 	kvfree(wg->index_hashtable);
 	kvfree(wg->peer_hashtable);
 	mutex_unlock(&wg->device_update_lock);
 
-	pr_debug("%s: Interface deleted\n", dev->name);
+	pr_debug("%s: Interface destroyed\n", dev->name);
 	free_netdev(dev);
 }
 
@@ -292,7 +292,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
 	struct wg_device *wg = netdev_priv(dev);
 	int ret = -ENOMEM;
 
-	wg->creating_net = src_net;
+	rcu_assign_pointer(wg->creating_net, src_net);
 	init_rwsem(&wg->static_identity.lock);
 	mutex_init(&wg->socket_update_lock);
 	mutex_init(&wg->device_update_lock);
@@ -393,30 +393,26 @@ static struct rtnl_link_ops link_ops __read_mostly = {
 	.newlink		= wg_newlink,
 };
 
-static int wg_netdevice_notification(struct notifier_block *nb,
-				     unsigned long action, void *data)
+static void wg_netns_pre_exit(struct net *net)
 {
-	struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
-	struct wg_device *wg = netdev_priv(dev);
-
-	ASSERT_RTNL();
-
-	if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops)
-		return 0;
+	struct wg_device *wg;
 
-	if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
-		put_net(wg->creating_net);
-		wg->have_creating_net_ref = false;
-	} else if (dev_net(dev) != wg->creating_net &&
-		   !wg->have_creating_net_ref) {
-		wg->have_creating_net_ref = true;
-		get_net(wg->creating_net);
+	rtnl_lock();
+	list_for_each_entry(wg, &device_list, device_list) {
+		if (rcu_access_pointer(wg->creating_net) == net) {
+			pr_debug("%s: Creating namespace exiting\n", wg->dev->name);
+			netif_carrier_off(wg->dev);
+			mutex_lock(&wg->device_update_lock);
+			rcu_assign_pointer(wg->creating_net, NULL);
+			wg_socket_reinit(wg, NULL, NULL);
+			mutex_unlock(&wg->device_update_lock);
+		}
 	}
-	return 0;
+	rtnl_unlock();
 }
 
-static struct notifier_block netdevice_notifier = {
-	.notifier_call = wg_netdevice_notification
+static struct pernet_operations pernet_ops = {
+	.pre_exit = wg_netns_pre_exit
 };
 
 int __init wg_device_init(void)
@@ -429,18 +425,18 @@ int __init wg_device_init(void)
 		return ret;
 #endif
 
-	ret = register_netdevice_notifier(&netdevice_notifier);
+	ret = register_pernet_device(&pernet_ops);
 	if (ret)
 		goto error_pm;
 
 	ret = rtnl_link_register(&link_ops);
 	if (ret)
-		goto error_netdevice;
+		goto error_pernet;
 
 	return 0;
 
-error_netdevice:
-	unregister_netdevice_notifier(&netdevice_notifier);
+error_pernet:
+	unregister_pernet_device(&pernet_ops);
 error_pm:
 #ifdef CONFIG_PM_SLEEP
 	unregister_pm_notifier(&pm_notifier);
@@ -451,7 +447,7 @@ int __init wg_device_init(void)
 void wg_device_uninit(void)
 {
 	rtnl_link_unregister(&link_ops);
-	unregister_netdevice_notifier(&netdevice_notifier);
+	unregister_pernet_device(&pernet_ops);
 #ifdef CONFIG_PM_SLEEP
 	unregister_pm_notifier(&pm_notifier);
 #endif
diff --git a/drivers/net/wireguard/device.h b/drivers/net/wireguard/device.h
index b15a8be9d8169688d07ffb295b9f0fa2afe4722f..4d0144e169478f1a3c6982effe8deed2146d40d7 100644
--- a/drivers/net/wireguard/device.h
+++ b/drivers/net/wireguard/device.h
@@ -40,7 +40,7 @@ struct wg_device {
 	struct net_device *dev;
 	struct crypt_queue encrypt_queue, decrypt_queue;
 	struct sock __rcu *sock4, *sock6;
-	struct net *creating_net;
+	struct net __rcu *creating_net;
 	struct noise_static_identity static_identity;
 	struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
 	struct workqueue_struct *packet_crypt_wq;
@@ -56,7 +56,6 @@ struct wg_device {
 	unsigned int num_peers, device_update_gen;
 	u32 fwmark;
 	u16 incoming_port;
-	bool have_creating_net_ref;
 };
 
 int wg_device_init(void);
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index 802099c8828a6b1aed3a476b2144e88175d857ec..20a4f3c0a0a1992ad638175192dd301191fee8dd 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -511,11 +511,15 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info)
 	if (flags & ~__WGDEVICE_F_ALL)
 		goto out;
 
-	ret = -EPERM;
-	if ((info->attrs[WGDEVICE_A_LISTEN_PORT] ||
-	     info->attrs[WGDEVICE_A_FWMARK]) &&
-	    !ns_capable(wg->creating_net->user_ns, CAP_NET_ADMIN))
-		goto out;
+	if (info->attrs[WGDEVICE_A_LISTEN_PORT] || info->attrs[WGDEVICE_A_FWMARK]) {
+		struct net *net;
+		rcu_read_lock();
+		net = rcu_dereference(wg->creating_net);
+		ret = !net || !ns_capable(net->user_ns, CAP_NET_ADMIN) ? -EPERM : 0;
+		rcu_read_unlock();
+		if (ret)
+			goto out;
+	}
 
 	++wg->device_update_gen;
 
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index f9018027fc133db48e7d20972690c7e52f4942f5..c33e2c81635fa1aaa49ac3b0a917f26e01ac6e2c 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -347,6 +347,7 @@ static void set_sock_opts(struct socket *sock)
 
 int wg_socket_init(struct wg_device *wg, u16 port)
 {
+	struct net *net;
 	int ret;
 	struct udp_tunnel_sock_cfg cfg = {
 		.sk_user_data = wg,
@@ -371,37 +372,47 @@ int wg_socket_init(struct wg_device *wg, u16 port)
 	};
 #endif
 
+	rcu_read_lock();
+	net = rcu_dereference(wg->creating_net);
+	net = net ? maybe_get_net(net) : NULL;
+	rcu_read_unlock();
+	if (unlikely(!net))
+		return -ENONET;
+
 #if IS_ENABLED(CONFIG_IPV6)
 retry:
 #endif
 
-	ret = udp_sock_create(wg->creating_net, &port4, &new4);
+	ret = udp_sock_create(net, &port4, &new4);
 	if (ret < 0) {
 		pr_err("%s: Could not create IPv4 socket\n", wg->dev->name);
-		return ret;
+		goto out;
 	}
 	set_sock_opts(new4);
-	setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
+	setup_udp_tunnel_sock(net, new4, &cfg);
 
 #if IS_ENABLED(CONFIG_IPV6)
 	if (ipv6_mod_enabled()) {
 		port6.local_udp_port = inet_sk(new4->sk)->inet_sport;
-		ret = udp_sock_create(wg->creating_net, &port6, &new6);
+		ret = udp_sock_create(net, &port6, &new6);
 		if (ret < 0) {
 			udp_tunnel_sock_release(new4);
 			if (ret == -EADDRINUSE && !port && retries++ < 100)
 				goto retry;
 			pr_err("%s: Could not create IPv6 socket\n",
 			       wg->dev->name);
-			return ret;
+			goto out;
 		}
 		set_sock_opts(new6);
-		setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
+		setup_udp_tunnel_sock(net, new6, &cfg);
 	}
 #endif
 
 	wg_socket_reinit(wg, new4->sk, new6 ? new6->sk : NULL);
-	return 0;
+	ret = 0;
+out:
+	put_net(net);
+	return ret;
 }
 
 void wg_socket_reinit(struct wg_device *wg, struct sock *new4,
diff --git a/tools/testing/selftests/wireguard/netns.sh b/tools/testing/selftests/wireguard/netns.sh
index 17a1f53ceba01aa7536fd6f57029bbb6b2918ea6..d77f4829f1e0702fa63f4019683101df9fac9e60 100755
--- a/tools/testing/selftests/wireguard/netns.sh
+++ b/tools/testing/selftests/wireguard/netns.sh
@@ -587,9 +587,20 @@ ip0 link set wg0 up
 kill $ncat_pid
 ip0 link del wg0
 
+# Ensure there aren't circular reference loops
+ip1 link add wg1 type wireguard
+ip2 link add wg2 type wireguard
+ip1 link set wg1 netns $netns2
+ip2 link set wg2 netns $netns1
+pp ip netns delete $netns1
+pp ip netns delete $netns2
+pp ip netns add $netns1
+pp ip netns add $netns2
+
+sleep 2 # Wait for cleanup and grace periods
 declare -A objects
 while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
-	[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
+	[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
 	objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
 done < /dev/kmsg
 alldeleted=1