diff --git a/net/xdp/xdp_umem.c b/net/xdp/xdp_umem.c
index 20c91f02d3d8056614112a3adbf9a27fedd5629f..83de74ca729a39bc22cd83ac3f1cb2e7b184db0f 100644
--- a/net/xdp/xdp_umem.c
+++ b/net/xdp/xdp_umem.c
@@ -87,21 +87,20 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
 	struct netdev_bpf bpf;
 	int err = 0;
 
+	ASSERT_RTNL();
+
 	force_zc = flags & XDP_ZEROCOPY;
 	force_copy = flags & XDP_COPY;
 
 	if (force_zc && force_copy)
 		return -EINVAL;
 
-	rtnl_lock();
-	if (xdp_get_umem_from_qid(dev, queue_id)) {
-		err = -EBUSY;
-		goto out_rtnl_unlock;
-	}
+	if (xdp_get_umem_from_qid(dev, queue_id))
+		return -EBUSY;
 
 	err = xdp_reg_umem_at_qid(dev, umem, queue_id);
 	if (err)
-		goto out_rtnl_unlock;
+		return err;
 
 	umem->dev = dev;
 	umem->queue_id = queue_id;
@@ -110,7 +109,7 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
 
 	if (force_copy)
 		/* For copy-mode, we are done. */
-		goto out_rtnl_unlock;
+		return 0;
 
 	if (!dev->netdev_ops->ndo_bpf ||
 	    !dev->netdev_ops->ndo_xsk_async_xmit) {
@@ -125,7 +124,6 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
 	err = dev->netdev_ops->ndo_bpf(dev, &bpf);
 	if (err)
 		goto err_unreg_umem;
-	rtnl_unlock();
 
 	umem->zc = true;
 	return 0;
@@ -135,8 +133,6 @@ int xdp_umem_assign_dev(struct xdp_umem *umem, struct net_device *dev,
 		err = 0; /* fallback to copy mode */
 	if (err)
 		xdp_clear_umem_at_qid(dev, queue_id);
-out_rtnl_unlock:
-	rtnl_unlock();
 	return err;
 }
 
diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c
index b994c32a664ab2db8c5eba564e1297211d8f929b..59b57d7086970b53b140bc500c881ab9cf1fb98b 100644
--- a/net/xdp/xsk.c
+++ b/net/xdp/xsk.c
@@ -430,6 +430,7 @@ static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
 	if (flags & ~(XDP_SHARED_UMEM | XDP_COPY | XDP_ZEROCOPY))
 		return -EINVAL;
 
+	rtnl_lock();
 	mutex_lock(&xs->mutex);
 	if (xs->state != XSK_READY) {
 		err = -EBUSY;
@@ -515,6 +516,7 @@ static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
 		xs->state = XSK_BOUND;
 out_release:
 	mutex_unlock(&xs->mutex);
+	rtnl_unlock();
 	return err;
 }