diff --git a/fs/io_uring.c b/fs/io_uring.c
index 7cf96be691d8978fa40dffa98a14e56325b3f6f8..2a3542b487ff0bf5be1f3db22d8b2a32c5eecd47 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -258,12 +258,11 @@ enum {
 
 struct io_sq_data {
 	refcount_t		refs;
-	struct mutex		lock;
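+	/*
+	 * Serializes the SQPOLL thread against park/stop: the thread holds
+	 * this for reading while it runs; parking and stopping take the
+	 * write side, which also protects the ctx lists below.
+	 */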
+	struct rw_semaphore	rw_lock;
 
 	/* ctx's that are using this sqd */
 	struct list_head	ctx_list;
 	struct list_head	ctx_new_list;
-	struct mutex		ctx_lock;
 
 	struct task_struct	*thread;
 	struct wait_queue_head	wait;
@@ -274,7 +273,6 @@ struct io_sq_data {
 
 	unsigned long		state;
 	struct completion	startup;
-	struct completion	parked;
 	struct completion	exited;
 };
 
@@ -6638,45 +6636,6 @@ static void io_sqd_init_new(struct io_sq_data *sqd)
 	io_sqd_update_thread_idle(sqd);
 }
 
-static bool io_sq_thread_should_stop(struct io_sq_data *sqd)
-{
-	return test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-}
-
-static bool io_sq_thread_should_park(struct io_sq_data *sqd)
-{
-	return test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-}
-
-static void io_sq_thread_parkme(struct io_sq_data *sqd)
-{
-	for (;;) {
-		/*
-		 * TASK_PARKED is a special state; we must serialize against
-		 * possible pending wakeups to avoid store-store collisions on
-		 * task->state.
-		 *
-		 * Such a collision might possibly result in the task state
-		 * changing from TASK_PARKED and us failing the
-		 * wait_task_inactive() in kthread_park().
-		 */
-		set_special_state(TASK_PARKED);
-		if (!test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state))
-			break;
-
-		/*
-		 * Thread is going to call schedule(), do not preempt it,
-		 * or the caller of kthread_park() may spend more time in
-		 * wait_task_inactive().
-		 */
-		preempt_disable();
-		complete(&sqd->parked);
-		schedule_preempt_disabled();
-		preempt_enable();
-	}
-	__set_current_state(TASK_RUNNING);
-}
-
 static int io_sq_thread(void *data)
 {
 	struct io_sq_data *sqd = data;
@@ -6697,17 +6656,16 @@ static int io_sq_thread(void *data)
 
 	wait_for_completion(&sqd->startup);
 
-	while (!io_sq_thread_should_stop(sqd)) {
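+	/* run with the read side held; parkers synchronize via the write side */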
+	down_read(&sqd->rw_lock);
+
+	while (!test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state)) {
 		int ret;
 		bool cap_entries, sqt_spin, needs_sched;
 
-		/*
-		 * Any changes to the sqd lists are synchronized through the
-		 * thread parking. This synchronizes the thread vs users,
-		 * the users are synchronized on the sqd->ctx_lock.
-		 */
-		if (io_sq_thread_should_park(sqd)) {
-			io_sq_thread_parkme(sqd);
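+		/*
+		 * Parking is now a lock handoff: drop the read lock so the
+		 * parker can take the write side, then re-take it once the
+		 * parker lets go.
+		 */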
+		if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state)) {
+			up_read(&sqd->rw_lock);
+			cond_resched();
+			down_read(&sqd->rw_lock);
 			continue;
 		}
 		if (unlikely(!list_empty(&sqd->ctx_new_list))) {
@@ -6752,12 +6710,14 @@ static int io_sq_thread(void *data)
 			}
 		}
 
-		if (needs_sched && !io_sq_thread_should_park(sqd)) {
+		if (needs_sched && !test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state)) {
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				io_ring_set_wakeup_flag(ctx);
 
+			up_read(&sqd->rw_lock);
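+			/* sleep without the read lock held, or parking would block */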
 			schedule();
 			try_to_freeze();
+			down_read(&sqd->rw_lock);
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 				io_ring_clear_wakeup_flag(ctx);
 		}
@@ -6768,28 +6728,16 @@ static int io_sq_thread(void *data)
 
 	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 		io_uring_cancel_sqpoll(ctx);
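+	/* cancellation ran under the read lock, keeping the ctx_list stable */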
+	up_read(&sqd->rw_lock);
 
 	io_run_task_work();
 
-	/*
-	 * Ensure that we park properly if racing with someone trying to park
-	 * while we're exiting. If we fail to grab the lock, check park and
-	 * park if necessary. The ordering with the park bit and the lock
-	 * ensures that we catch this reliably.
-	 */
-	if (!mutex_trylock(&sqd->lock)) {
-		if (io_sq_thread_should_park(sqd))
-			io_sq_thread_parkme(sqd);
-		mutex_lock(&sqd->lock);
-	}
-
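+	/* clear ->thread under the write lock so park/stop observe the exit */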
+	down_write(&sqd->rw_lock);
 	sqd->thread = NULL;
-	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 		io_ring_set_wakeup_flag(ctx);
-	}
-
+	up_write(&sqd->rw_lock);
 	complete(&sqd->exited);
-	mutex_unlock(&sqd->lock);
 	do_exit(0);
 }
 
@@ -7088,44 +7036,40 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 }
 
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
-	__releases(&sqd->lock)
+	__releases(&sqd->rw_lock)
 {
 	if (sqd->thread == current)
 		return;
 	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	if (sqd->thread)
-		wake_up_state(sqd->thread, TASK_PARKED);
-	mutex_unlock(&sqd->lock);
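+	/* dropping the write lock lets the thread resume under its read lock */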
+	up_write(&sqd->rw_lock);
 }
 
 static void io_sq_thread_park(struct io_sq_data *sqd)
-	__acquires(&sqd->lock)
+	__acquires(&sqd->rw_lock)
 {
 	if (sqd->thread == current)
 		return;
 	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	mutex_lock(&sqd->lock);
-	if (sqd->thread) {
+	down_write(&sqd->rw_lock);
+	/* set again under the lock: a concurrent unpark may have cleared it */
+	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	if (sqd->thread)
 		wake_up_process(sqd->thread);
-		wait_for_completion(&sqd->parked);
-	}
 }
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
 	if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
 		return;
-	mutex_lock(&sqd->lock);
-	if (sqd->thread) {
-		set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-		WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state));
-		wake_up_process(sqd->thread);
-		mutex_unlock(&sqd->lock);
-		wait_for_completion(&sqd->exited);
-		WARN_ON_ONCE(sqd->thread);
-	} else {
-		mutex_unlock(&sqd->lock);
+	down_write(&sqd->rw_lock);
+	if (!sqd->thread) {
+		up_write(&sqd->rw_lock);
+		return;
 	}
+	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+	wake_up_process(sqd->thread);
+	up_write(&sqd->rw_lock);
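+	/* the thread sees STOP, drops its read lock, cancels and exits */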
+	wait_for_completion(&sqd->exited);
 }
 
 static void io_put_sq_data(struct io_sq_data *sqd)
@@ -7142,18 +7086,13 @@ static void io_sq_thread_finish(struct io_ring_ctx *ctx)
 
 	if (sqd) {
 		complete(&sqd->startup);
-		if (sqd->thread) {
+		if (sqd->thread)
 			wait_for_completion(&ctx->sq_thread_comp);
-			io_sq_thread_park(sqd);
-		}
 
-		mutex_lock(&sqd->ctx_lock);
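+		/* the park write lock now covers the ctx lists, thread or not */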
+		io_sq_thread_park(sqd);
 		list_del(&ctx->sqd_list);
 		io_sqd_update_thread_idle(sqd);
-		mutex_unlock(&sqd->ctx_lock);
-
-		if (sqd->thread)
-			io_sq_thread_unpark(sqd);
+		io_sq_thread_unpark(sqd);
 
 		io_put_sq_data(sqd);
 		ctx->sq_data = NULL;
@@ -7202,11 +7141,9 @@ static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
 	refcount_set(&sqd->refs, 1);
 	INIT_LIST_HEAD(&sqd->ctx_list);
 	INIT_LIST_HEAD(&sqd->ctx_new_list);
-	mutex_init(&sqd->ctx_lock);
-	mutex_init(&sqd->lock);
+	init_rwsem(&sqd->rw_lock);
 	init_waitqueue_head(&sqd->wait);
 	init_completion(&sqd->startup);
-	init_completion(&sqd->parked);
 	init_completion(&sqd->exited);
 	return sqd;
 }
@@ -7880,9 +7817,7 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 		ctx->sq_creds = get_current_cred();
 		ctx->sq_data = sqd;
 		io_sq_thread_park(sqd);
-		mutex_lock(&sqd->ctx_lock);
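+		/* the park write lock replaces the old ctx_lock here */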
 		list_add(&ctx->sqd_list, &sqd->ctx_new_list);
-		mutex_unlock(&sqd->ctx_lock);
 		io_sq_thread_unpark(sqd);
 
 		ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);