diff --git a/fs/io_uring.c b/fs/io_uring.c
index 5c2de43b99f53f084d486b91004ee8871576606f..5f4e312111ea730b167d9cd987137ba09e435236 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -274,6 +274,7 @@ struct io_sq_data {
 
 	unsigned long		state;
 	struct completion	exited;
+	struct callback_head	*park_task_work;
 };
 
 #define IO_IOPOLL_BATCH			8
@@ -6727,6 +6728,7 @@ static int io_sq_thread(void *data)
 			cond_resched();
 			mutex_lock(&sqd->lock);
 			io_run_task_work();
+			io_run_task_work_head(&sqd->park_task_work);
 			timeout = jiffies + sqd->sq_thread_idle;
 			continue;
 		}
@@ -6781,6 +6783,7 @@ static int io_sq_thread(void *data)
 		}
 
 		finish_wait(&sqd->wait, &wait);
+		io_run_task_work_head(&sqd->park_task_work);
 		timeout = jiffies + sqd->sq_thread_idle;
 	}
 
@@ -6792,6 +6795,7 @@ static int io_sq_thread(void *data)
 	mutex_unlock(&sqd->lock);
 
 	io_run_task_work();
+	io_run_task_work_head(&sqd->park_task_work);
 	complete(&sqd->exited);
 	do_exit(0);
 }
@@ -8890,7 +8894,7 @@ static void io_sqpoll_cancel_sync(struct io_ring_ctx *ctx)
 	if (task) {
 		init_completion(&work.completion);
 		init_task_work(&work.task_work, io_sqpoll_cancel_cb);
-		WARN_ON_ONCE(task_work_add(task, &work.task_work, TWA_SIGNAL));
+		io_task_work_add_head(&sqd->park_task_work, &work.task_work);
 		wake_up_process(task);
 	}
 	io_sq_thread_unpark(sqd);