diff --git a/fs/io-wq.c b/fs/io-wq.c
index 0ae9ecadf295ce0b25e9440a861bb917a5f1333a..e05f996d088f17437834bcb96f4987d257773ab0 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -488,6 +488,8 @@ static int io_wqe_worker(void *data)
 	set_task_comm(current, buf);
 
 	while (!test_bit(IO_WQ_BIT_EXIT, &wq->state)) {
+		long ret;
+
 		set_current_state(TASK_INTERRUPTIBLE);
 loop:
 		raw_spin_lock_irq(&wqe->lock);
@@ -498,7 +500,8 @@ loop:
 		__io_worker_idle(wqe, worker);
 		raw_spin_unlock_irq(&wqe->lock);
 		io_flush_signals();
-		if (schedule_timeout(WORKER_IDLE_TIMEOUT))
+		ret = schedule_timeout(WORKER_IDLE_TIMEOUT);
+		if (try_to_freeze() || ret)
 			continue;
 		if (fatal_signal_pending(current))
 			break;
@@ -709,6 +712,7 @@ static int io_wq_manager(void *data)
 		set_current_state(TASK_INTERRUPTIBLE);
 		io_wq_check_workers(wq);
 		schedule_timeout(HZ);
+		try_to_freeze();
 		if (fatal_signal_pending(current))
 			set_bit(IO_WQ_BIT_EXIT, &wq->state);
 	} while (!test_bit(IO_WQ_BIT_EXIT, &wq->state));
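
The io-wq changes above adopt the usual freezable-sleep shape: try_to_freeze()
runs after the timed wait, so a freezer wakeup is retried rather than mistaken
for an idle timeout (schedule_timeout() returns the remaining jiffies on an
early wakeup and 0 on a full timeout). A minimal sketch of the pattern for a
plain kthread; example_worker is illustrative, not kernel code:

#include <linux/freezer.h>
#include <linux/kthread.h>
#include <linux/sched.h>

static int example_worker(void *data)
{
	while (!kthread_should_stop()) {
		long ret;

		set_current_state(TASK_INTERRUPTIBLE);
		ret = schedule_timeout(HZ);	/* sleep up to one second */
		/*
		 * try_to_freeze() returns true if we were frozen and thawed;
		 * that, like an early wakeup (ret != 0), is not a real idle
		 * timeout, so go around the loop again.
		 */
		if (try_to_freeze() || ret)
			continue;
		/* full timeout elapsed: idle/exit handling would go here */
	}
	return 0;
}
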
diff --git a/fs/io-wq.h b/fs/io-wq.h
index 1ac2f3248088e701c0067d45b2ae3909e619d562..80d590564ff93e00f1a457fffd7bcee97f73e837 100644
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -2,7 +2,6 @@
 #define INTERNAL_IO_WQ_H
 
 #include <linux/refcount.h>
-#include <linux/io_uring.h>
 
 struct io_wq;
 
@@ -21,6 +20,15 @@ enum io_wq_cancel {
 	IO_WQ_CANCEL_NOTFOUND,	/* work not found */
 };
 
+struct io_wq_work_node {
+	struct io_wq_work_node *next;
+};
+
+struct io_wq_work_list {
+	struct io_wq_work_node *first;
+	struct io_wq_work_node *last;
+};
+
 static inline void wq_list_add_after(struct io_wq_work_node *node,
 				     struct io_wq_work_node *pos,
 				     struct io_wq_work_list *list)
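
io_wq_work_node and io_wq_work_list move into this header from
<linux/io_uring.h>, where they are deleted further down, so io-wq.h no longer
needs that include; struct io_uring_task likewise moves out of the public
header into fs/io_uring.c, where its last and io_wq members gain real types in
place of void *. The two list types form an intrusive singly linked list: the
node is embedded in the work item and the container is recovered with
container_of(). A hypothetical user, with my_work purely illustrative:

#include <linux/kernel.h>	/* container_of() */

struct my_work {
	struct io_wq_work_node node;	/* embedded link */
	int payload;
};

static struct my_work *first_entry(struct io_wq_work_list *list)
{
	if (!list->first)
		return NULL;
	return container_of(list->first, struct my_work, node);
}
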
diff --git a/fs/io_uring.c b/fs/io_uring.c
index a4bce17af506bccbf4aebcb828da8b3761d3d7c3..ea2d3e120555e9fbf57c3da3f895e71d3d69dc61 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -258,7 +258,8 @@ enum {
 
 struct io_sq_data {
 	refcount_t		refs;
-	struct rw_semaphore	rw_lock;
+	atomic_t		park_pending;
+	struct mutex		lock;
 
 	/* ctx's that are using this sqd */
 	struct list_head	ctx_list;
@@ -273,6 +274,7 @@ struct io_sq_data {
 
 	unsigned long		state;
 	struct completion	exited;
+	struct callback_head	*park_task_work;
 };
 
 #define IO_IOPOLL_BATCH			8
@@ -402,7 +404,7 @@ struct io_ring_ctx {
 	struct socket		*ring_sock;
 #endif
 
-	struct idr		io_buffer_idr;
+	struct xarray		io_buffers;
 
 	struct xarray		personalities;
 	u32			pers_next;
@@ -454,6 +456,22 @@ struct io_ring_ctx {
 	struct list_head		tctx_list;
 };
 
+struct io_uring_task {
+	/* submission side */
+	struct xarray		xa;
+	struct wait_queue_head	wait;
+	const struct io_ring_ctx *last;
+	struct io_wq		*io_wq;
+	struct percpu_counter	inflight;
+	atomic_t		in_idle;
+	bool			sqpoll;
+
+	spinlock_t		task_lock;
+	struct io_wq_work_list	task_list;
+	unsigned long		task_state;
+	struct callback_head	task_work;
+};
+
 /*
  * First field must be the file pointer in all the
  * iocb unions! See also 'struct kiocb' in <linux/fs.h>
@@ -1135,7 +1153,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 	init_waitqueue_head(&ctx->cq_wait);
 	INIT_LIST_HEAD(&ctx->cq_overflow_list);
 	init_completion(&ctx->ref_comp);
-	idr_init(&ctx->io_buffer_idr);
+	xa_init_flags(&ctx->io_buffers, XA_FLAGS_ALLOC1);
 	xa_init_flags(&ctx->personalities, XA_FLAGS_ALLOC1);
 	mutex_init(&ctx->uring_lock);
 	init_waitqueue_head(&ctx->wait);
@@ -1550,14 +1568,17 @@ static void io_req_complete_post(struct io_kiocb *req, long res,
 		io_put_task(req->task, 1);
 		list_add(&req->compl.list, &cs->locked_free_list);
 		cs->locked_free_nr++;
-	} else
-		req = NULL;
+	} else {
+		if (!percpu_ref_tryget(&ctx->refs))
+			req = NULL;
+	}
 	io_commit_cqring(ctx);
 	spin_unlock_irqrestore(&ctx->completion_lock, flags);
-	io_cqring_ev_posted(ctx);
 
-	if (req)
+	if (req) {
+		io_cqring_ev_posted(ctx);
 		percpu_ref_put(&ctx->refs);
+	}
 }
 
 static void io_req_complete_state(struct io_kiocb *req, long res,
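
The reordering in io_req_complete_post() ensures io_cqring_ev_posted() only
runs while a ctx reference is certainly held: either the completed request's
own reference (the recycled-request branch, where the final put drops it) or
one freshly taken with percpu_ref_tryget(); if the tryget fails the ring is
already being torn down and the event signalling is skipped. The guard pattern
in isolation, as a sketch rather than the full function:

	if (percpu_ref_tryget(&ctx->refs)) {
		io_cqring_ev_posted(ctx);
		percpu_ref_put(&ctx->refs);
	}
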
@@ -1925,17 +1946,44 @@ static int io_req_task_work_add(struct io_kiocb *req)
 	return ret;
 }
 
-static void io_req_task_work_add_fallback(struct io_kiocb *req,
-					  task_work_func_t cb)
+static bool io_run_task_work_head(struct callback_head **work_head)
+{
+	struct callback_head *work, *next;
+	bool executed = false;
+
+	do {
+		work = xchg(work_head, NULL);
+		if (!work)
+			break;
+
+		do {
+			next = work->next;
+			work->func(work);
+			work = next;
+			cond_resched();
+		} while (work);
+		executed = true;
+	} while (1);
+
+	return executed;
+}
+
+static void io_task_work_add_head(struct callback_head **work_head,
+				  struct callback_head *task_work)
 {
-	struct io_ring_ctx *ctx = req->ctx;
 	struct callback_head *head;
 
-	init_task_work(&req->task_work, cb);
 	do {
-		head = READ_ONCE(ctx->exit_task_work);
-		req->task_work.next = head;
-	} while (cmpxchg(&ctx->exit_task_work, head, &req->task_work) != head);
+		head = READ_ONCE(*work_head);
+		task_work->next = head;
+	} while (cmpxchg(work_head, head, task_work) != head);
+}
+
+static void io_req_task_work_add_fallback(struct io_kiocb *req,
+					  task_work_func_t cb)
+{
+	init_task_work(&req->task_work, cb);
+	io_task_work_add_head(&req->ctx->exit_task_work, &req->task_work);
 }
 
 static void __io_req_task_cancel(struct io_kiocb *req, int error)
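
io_run_task_work_head() and io_task_work_add_head() generalise what
io_req_task_work_add_fallback() used to open-code against ctx->exit_task_work,
so the same lock-free list can also back the new sqd->park_task_work.
Producers push with a cmpxchg() loop; the consumer detaches the whole LIFO
chain at once with xchg(..., NULL) and loops in case more entries were pushed
meanwhile. A hypothetical synchronous caller, mirroring how
io_sqpoll_cancel_sync() uses it below (the my_sync names are illustrative):

#include <linux/completion.h>
#include <linux/task_work.h>

struct my_sync {
	struct callback_head cb;
	struct completion done;
};

static void my_sync_cb(struct callback_head *cb)
{
	struct my_sync *s = container_of(cb, struct my_sync, cb);

	complete(&s->done);
}

/* queue a callback for 'target' and wait until it has been run */
static void queue_and_wait(struct callback_head **head,
			   struct task_struct *target)
{
	struct my_sync s;

	init_completion(&s.done);
	init_task_work(&s.cb, my_sync_cb);
	io_task_work_add_head(head, &s.cb);
	wake_up_process(target);	/* target drains via io_run_task_work_head() */
	wait_for_completion(&s.done);
}
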
@@ -2843,7 +2891,7 @@ static struct io_buffer *io_buffer_select(struct io_kiocb *req, size_t *len,
 
 	lockdep_assert_held(&req->ctx->uring_lock);
 
-	head = idr_find(&req->ctx->io_buffer_idr, bgid);
+	head = xa_load(&req->ctx->io_buffers, bgid);
 	if (head) {
 		if (!list_empty(&head->list)) {
 			kbuf = list_last_entry(&head->list, struct io_buffer,
@@ -2851,7 +2899,7 @@ static struct io_buffer *io_buffer_select(struct io_kiocb *req, size_t *len,
 			list_del(&kbuf->list);
 		} else {
 			kbuf = head;
-			idr_remove(&req->ctx->io_buffer_idr, bgid);
+			xa_erase(&req->ctx->io_buffers, bgid);
 		}
 		if (*len > kbuf->len)
 			*len = kbuf->len;
@@ -3892,7 +3940,7 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx, struct io_buffer *buf,
 	}
 	i++;
 	kfree(buf);
-	idr_remove(&ctx->io_buffer_idr, bgid);
+	xa_erase(&ctx->io_buffers, bgid);
 
 	return i;
 }
@@ -3910,7 +3958,7 @@ static int io_remove_buffers(struct io_kiocb *req, unsigned int issue_flags)
 	lockdep_assert_held(&ctx->uring_lock);
 
 	ret = -ENOENT;
-	head = idr_find(&ctx->io_buffer_idr, p->bgid);
+	head = xa_load(&ctx->io_buffers, p->bgid);
 	if (head)
 		ret = __io_remove_buffers(ctx, head, p->bgid, p->nbufs);
 	if (ret < 0)
@@ -3993,21 +4041,14 @@ static int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
 
 	lockdep_assert_held(&ctx->uring_lock);
 
-	list = head = idr_find(&ctx->io_buffer_idr, p->bgid);
+	list = head = xa_load(&ctx->io_buffers, p->bgid);
 
 	ret = io_add_buffers(p, &head);
-	if (ret < 0)
-		goto out;
-
-	if (!list) {
-		ret = idr_alloc(&ctx->io_buffer_idr, head, p->bgid, p->bgid + 1,
-					GFP_KERNEL);
-		if (ret < 0) {
+	if (ret >= 0 && !list) {
+		ret = xa_insert(&ctx->io_buffers, p->bgid, head, GFP_KERNEL);
+		if (ret < 0)
 			__io_remove_buffers(ctx, head, p->bgid, -1U);
-			goto out;
-		}
 	}
-out:
 	if (ret < 0)
 		req_set_fail_links(req);
 
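
These hunks are part of converting the provided-buffer bookkeeping from an IDR
to an XArray. The buffer group ID is chosen by the application, so the old
idr_alloc(..., p->bgid, p->bgid + 1, ...) call was only emulating a keyed
insert; xa_insert() expresses that directly and fails with -EBUSY when the key
is already populated. An empty XArray also needs no teardown, which is why the
idr_destroy() call disappears from io_destroy_buffers() below. The three calls
in isolation; map, key and item are illustrative:

static int example(struct xarray *map, u32 key, void *item)
{
	/* fails with -EBUSY instead of overwriting an existing entry */
	int ret = xa_insert(map, key, item, GFP_KERNEL);

	if (ret)
		return ret;
	WARN_ON(xa_load(map, key) != item);	/* xa_load(): NULL when absent */
	xa_erase(map, key);			/* removes and returns the entry */
	return 0;
}
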
@@ -4359,7 +4400,7 @@ static int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
 		kmsg = &iomsg;
 	}
 
-	flags = req->sr_msg.msg_flags;
+	flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
 	if (flags & MSG_DONTWAIT)
 		req->flags |= REQ_F_NOWAIT;
 	else if (issue_flags & IO_URING_F_NONBLOCK)
@@ -4403,7 +4444,7 @@ static int io_send(struct io_kiocb *req, unsigned int issue_flags)
 	msg.msg_controllen = 0;
 	msg.msg_namelen = 0;
 
-	flags = req->sr_msg.msg_flags;
+	flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
 	if (flags & MSG_DONTWAIT)
 		req->flags |= REQ_F_NOWAIT;
 	else if (issue_flags & IO_URING_F_NONBLOCK)
@@ -4593,7 +4634,7 @@ static int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 				1, req->sr_msg.len);
 	}
 
-	flags = req->sr_msg.msg_flags;
+	flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
 	if (flags & MSG_DONTWAIT)
 		req->flags |= REQ_F_NOWAIT;
 	else if (force_nonblock)
@@ -4652,7 +4693,7 @@ static int io_recv(struct io_kiocb *req, unsigned int issue_flags)
 	msg.msg_iocb = NULL;
 	msg.msg_flags = 0;
 
-	flags = req->sr_msg.msg_flags;
+	flags = req->sr_msg.msg_flags | MSG_NOSIGNAL;
 	if (flags & MSG_DONTWAIT)
 		req->flags |= REQ_F_NOWAIT;
 	else if (force_nonblock)
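
All four send/recv paths now OR in MSG_NOSIGNAL: io-wq workers are now real
threads of the submitting task, so a peer-closed socket would otherwise raise
SIGPIPE there, while with the flag the operation simply fails with -EPIPE and
that error completes the request. The userspace analogue of the flag, for
illustration only:

#include <sys/socket.h>

static ssize_t send_nosig(int fd, const void *buf, size_t len)
{
	/* on a peer-closed socket: fail with EPIPE, don't raise SIGPIPE */
	return send(fd, buf, len, MSG_NOSIGNAL);
}
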
@@ -6204,7 +6245,6 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
 	spin_unlock_irqrestore(&ctx->completion_lock, flags);
 
 	if (prev) {
-		req_set_fail_links(prev);
 		io_async_find_and_cancel(ctx, req, prev->user_data, -ETIME);
 		io_put_req_deferred(prev, 1);
 	} else {
@@ -6694,17 +6734,17 @@ static int io_sq_thread(void *data)
 		set_cpus_allowed_ptr(current, cpu_online_mask);
 	current->flags |= PF_NO_SETAFFINITY;
 
-	down_read(&sqd->rw_lock);
-
+	mutex_lock(&sqd->lock);
 	while (!test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state)) {
 		int ret;
 		bool cap_entries, sqt_spin, needs_sched;
 
 		if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state)) {
-			up_read(&sqd->rw_lock);
+			mutex_unlock(&sqd->lock);
 			cond_resched();
-			down_read(&sqd->rw_lock);
+			mutex_lock(&sqd->lock);
 			io_run_task_work();
+			io_run_task_work_head(&sqd->park_task_work);
 			timeout = jiffies + sqd->sq_thread_idle;
 			continue;
 		}
@@ -6750,32 +6790,28 @@ static int io_sq_thread(void *data)
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				io_ring_set_wakeup_flag(ctx);
 
-			up_read(&sqd->rw_lock);
+			mutex_unlock(&sqd->lock);
 			schedule();
-			down_read(&sqd->rw_lock);
+			try_to_freeze();
+			mutex_lock(&sqd->lock);
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				io_ring_clear_wakeup_flag(ctx);
 		}
 
 		finish_wait(&sqd->wait, &wait);
+		io_run_task_work_head(&sqd->park_task_work);
 		timeout = jiffies + sqd->sq_thread_idle;
 	}
-	up_read(&sqd->rw_lock);
-	down_write(&sqd->rw_lock);
-	/*
-	 * someone may have parked and added a cancellation task_work, run
-	 * it first because we don't want it in io_uring_cancel_sqpoll()
-	 */
-	io_run_task_work();
 
 	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 		io_uring_cancel_sqpoll(ctx);
 	sqd->thread = NULL;
 	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 		io_ring_set_wakeup_flag(ctx);
-	up_write(&sqd->rw_lock);
+	mutex_unlock(&sqd->lock);
 
 	io_run_task_work();
+	io_run_task_work_head(&sqd->park_task_work);
 	complete(&sqd->exited);
 	do_exit(0);
 }
@@ -7075,23 +7111,28 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 }
 
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
-	__releases(&sqd->rw_lock)
+	__releases(&sqd->lock)
 {
 	WARN_ON_ONCE(sqd->thread == current);
 
+	/*
+	 * Clear the bit first, then re-set it if parkers remain; a conditional
+	 * clear_bit() would race with other threads setting the bit again.
+	 */
 	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	up_write(&sqd->rw_lock);
+	if (atomic_dec_return(&sqd->park_pending))
+		set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	mutex_unlock(&sqd->lock);
 }
 
 static void io_sq_thread_park(struct io_sq_data *sqd)
-	__acquires(&sqd->rw_lock)
+	__acquires(&sqd->lock)
 {
 	WARN_ON_ONCE(sqd->thread == current);
 
+	atomic_inc(&sqd->park_pending);
 	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	down_write(&sqd->rw_lock);
-	/* set again for consistency, in case concurrent parks are happening */
-	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	mutex_lock(&sqd->lock);
 	if (sqd->thread)
 		wake_up_process(sqd->thread);
 }
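
The rwsem is replaced by a plain mutex plus a park_pending count so that
concurrent parkers nest: each parker bumps the counter before taking the lock,
and unpark only leaves the bit cleared when no parkers remain. The
interleaving the unpark comment guards against, shown for contrast as an
illustrative timeline rather than code from the patch:

	/*
	 * Broken variant (conditional clear):
	 *
	 *   parker A: atomic_dec_return(&park_pending) -> 0
	 *   parker B: atomic_inc(&park_pending); set_bit(SHOULD_PARK)
	 *   parker A: clear_bit(SHOULD_PARK)	<- B's park request is lost
	 *
	 * Clearing first and re-setting while park_pending is still nonzero
	 * keeps the bit consistent with the counter.
	 */
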
@@ -7100,17 +7141,19 @@ static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
 	WARN_ON_ONCE(sqd->thread == current);
 
-	down_write(&sqd->rw_lock);
+	mutex_lock(&sqd->lock);
 	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
 	if (sqd->thread)
 		wake_up_process(sqd->thread);
-	up_write(&sqd->rw_lock);
+	mutex_unlock(&sqd->lock);
 	wait_for_completion(&sqd->exited);
 }
 
 static void io_put_sq_data(struct io_sq_data *sqd)
 {
 	if (refcount_dec_and_test(&sqd->refs)) {
+		WARN_ON_ONCE(atomic_read(&sqd->park_pending));
+
 		io_sq_thread_stop(sqd);
 		kfree(sqd);
 	}
@@ -7184,9 +7227,10 @@ static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
 	if (!sqd)
 		return ERR_PTR(-ENOMEM);
 
+	atomic_set(&sqd->park_pending, 0);
 	refcount_set(&sqd->refs, 1);
 	INIT_LIST_HEAD(&sqd->ctx_list);
-	init_rwsem(&sqd->rw_lock);
+	mutex_init(&sqd->lock);
 	init_waitqueue_head(&sqd->wait);
 	init_completion(&sqd->exited);
 	return sqd;
@@ -7866,22 +7910,17 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 
 		ret = 0;
 		io_sq_thread_park(sqd);
+		list_add(&ctx->sqd_list, &sqd->ctx_list);
+		io_sqd_update_thread_idle(sqd);
 		/* don't attach to a dying SQPOLL thread, would be racy */
-		if (attached && !sqd->thread) {
+		if (attached && !sqd->thread)
 			ret = -ENXIO;
-		} else {
-			list_add(&ctx->sqd_list, &sqd->ctx_list);
-			io_sqd_update_thread_idle(sqd);
-		}
 		io_sq_thread_unpark(sqd);
 
-		if (ret < 0) {
-			io_put_sq_data(sqd);
-			ctx->sq_data = NULL;
-			return ret;
-		} else if (attached) {
+		if (ret < 0)
+			goto err;
+		if (attached)
 			return 0;
-		}
 
 		if (p->flags & IORING_SETUP_SQ_AFF) {
 			int cpu = p->sq_thread_cpu;
@@ -8332,19 +8371,13 @@ static int io_eventfd_unregister(struct io_ring_ctx *ctx)
 	return -ENXIO;
 }
 
-static int __io_destroy_buffers(int id, void *p, void *data)
-{
-	struct io_ring_ctx *ctx = data;
-	struct io_buffer *buf = p;
-
-	__io_remove_buffers(ctx, buf, id, -1U);
-	return 0;
-}
-
 static void io_destroy_buffers(struct io_ring_ctx *ctx)
 {
-	idr_for_each(&ctx->io_buffer_idr, __io_destroy_buffers, ctx);
-	idr_destroy(&ctx->io_buffer_idr);
+	struct io_buffer *buf;
+	unsigned long index;
+
+	xa_for_each(&ctx->io_buffers, index, buf)
+		__io_remove_buffers(ctx, buf, index, -1U);
 }
 
 static void io_req_cache_free(struct list_head *list, struct task_struct *tsk)
@@ -8386,11 +8419,13 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
 	/*
 	 * Some may use context even when all refs and requests have been put,
-	 * and they are free to do so while still holding uring_lock, see
-	 * __io_req_task_submit(). Wait for them to finish.
+	 * and they are free to do so while still holding uring_lock or
+	 * completion_lock, see __io_req_task_submit(). Wait for them to finish.
 	 */
 	mutex_lock(&ctx->uring_lock);
 	mutex_unlock(&ctx->uring_lock);
+	spin_lock_irq(&ctx->completion_lock);
+	spin_unlock_irq(&ctx->completion_lock);
 
 	io_sq_thread_finish(ctx);
 	io_sqe_buffers_unregister(ctx);
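
The empty lock/unlock pairs in io_ring_ctx_free() are quiescing barriers, not
critical sections: completion_lock joins uring_lock here because
__io_req_task_submit() may still touch the ctx under either lock after the
last reference is dropped. The idiom in isolation, as a sketch:

static void wait_for_lock_holders(struct io_ring_ctx *ctx)
{
	mutex_lock(&ctx->uring_lock);	/* returns once current holders are done */
	mutex_unlock(&ctx->uring_lock);
	spin_lock_irq(&ctx->completion_lock);
	spin_unlock_irq(&ctx->completion_lock);
}
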
@@ -8478,26 +8513,9 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 	return -EINVAL;
 }
 
-static bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
+static inline bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
 {
-	struct callback_head *work, *next;
-	bool executed = false;
-
-	do {
-		work = xchg(&ctx->exit_task_work, NULL);
-		if (!work)
-			break;
-
-		do {
-			next = work->next;
-			work->func(work);
-			work = next;
-			cond_resched();
-		} while (work);
-		executed = true;
-	} while (1);
-
-	return executed;
+	return io_run_task_work_head(&ctx->exit_task_work);
 }
 
 struct io_tctx_exit {
@@ -8580,6 +8598,14 @@ static void io_ring_ctx_wait_and_kill(struct io_ring_ctx *ctx)
 		io_unregister_personality(ctx, index);
 	mutex_unlock(&ctx->uring_lock);
 
+	/* prevent SQPOLL from submitting new requests */
+	if (ctx->sq_data) {
+		io_sq_thread_park(ctx->sq_data);
+		list_del_init(&ctx->sqd_list);
+		io_sqd_update_thread_idle(ctx->sq_data);
+		io_sq_thread_unpark(ctx->sq_data);
+	}
+
 	io_kill_timeouts(ctx, NULL, NULL);
 	io_poll_remove_all(ctx, NULL, NULL);
 
@@ -8879,7 +8905,7 @@ static void io_sqpoll_cancel_sync(struct io_ring_ctx *ctx)
 	if (task) {
 		init_completion(&work.completion);
 		init_task_work(&work.task_work, io_sqpoll_cancel_cb);
-		WARN_ON_ONCE(task_work_add(task, &work.task_work, TWA_SIGNAL));
+		io_task_work_add_head(&sqd->park_task_work, &work.task_work);
 		wake_up_process(task);
 	}
 	io_sq_thread_unpark(sqd);
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index 9761a0ec9f95c35b615c19a7aa14d2f5999b7d1c..79cde9906be0486c702dc09a464c193ddf22f6b8 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -5,31 +5,6 @@
 #include <linux/sched.h>
 #include <linux/xarray.h>
 
-struct io_wq_work_node {
-	struct io_wq_work_node *next;
-};
-
-struct io_wq_work_list {
-	struct io_wq_work_node *first;
-	struct io_wq_work_node *last;
-};
-
-struct io_uring_task {
-	/* submission side */
-	struct xarray		xa;
-	struct wait_queue_head	wait;
-	void			*last;
-	void			*io_wq;
-	struct percpu_counter	inflight;
-	atomic_t		in_idle;
-	bool			sqpoll;
-
-	spinlock_t		task_lock;
-	struct io_wq_work_list	task_list;
-	unsigned long		task_state;
-	struct callback_head	task_work;
-};
-
 #if defined(CONFIG_IO_URING)
 struct sock *io_uring_get_socket(struct file *file);
 void __io_uring_task_cancel(void);
diff --git a/kernel/fork.c b/kernel/fork.c
index 0acc8ed1076b7a0e92499dad63065fa458efc5cc..54cc905e5fe095d0081ec943ac2a16e487f2235b 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -2444,7 +2444,6 @@ struct task_struct *create_io_thread(int (*fn)(void *), void *arg, int node)
 	if (!IS_ERR(tsk)) {
 		sigfillset(&tsk->blocked);
 		sigdelsetmask(&tsk->blocked, sigmask(SIGKILL));
-		tsk->flags |= PF_NOFREEZE;
 	}
 	return tsk;
 }
diff --git a/kernel/freezer.c b/kernel/freezer.c
index dc520f01f99ddc053366ab5fdbaffd126f519c24..1a2d57d1327cd6d467d8ce86117bc2174652068e 100644
--- a/kernel/freezer.c
+++ b/kernel/freezer.c
@@ -134,7 +134,7 @@ bool freeze_task(struct task_struct *p)
 		return false;
 	}
 
-	if (!(p->flags & PF_KTHREAD))
+	if (!(p->flags & (PF_KTHREAD | PF_IO_WORKER)))
 		fake_signal_wake_up(p);
 	else
 		wake_up_state(p, TASK_INTERRUPTIBLE);
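
This pairs with dropping PF_NOFREEZE from create_io_thread() above: io-wq
workers and the SQPOLL thread are now freezable, but they do not process
signals like normal user tasks, so the freezer's fake_signal_wake_up() would
never thaw them. Treating PF_IO_WORKER like PF_KTHREAD wakes them directly,
letting their loops reach the try_to_freeze() calls added in fs/io-wq.c and
fs/io_uring.c.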