diff --git a/fs/io_uring.c b/fs/io_uring.c
index 70ceb8ed5950033c678b541dbbd62985c9f2e695..5c2de43b99f53f084d486b91004ee8871576606f 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -1929,17 +1929,44 @@ static int io_req_task_work_add(struct io_kiocb *req)
 	return ret;
 }
 
-static void io_req_task_work_add_fallback(struct io_kiocb *req,
-					  task_work_func_t cb)
+static bool io_run_task_work_head(struct callback_head **work_head)
+{
+	struct callback_head *work, *next;
+	bool executed = false;
+
+	do {
+		work = xchg(work_head, NULL);
+		if (!work)
+			break;
+
+		do {
+			next = work->next;
+			work->func(work);
+			work = next;
+			cond_resched();
+		} while (work);
+		executed = true;
+	} while (1);
+
+	return executed;
+}
+
+static void io_task_work_add_head(struct callback_head **work_head,
+				  struct callback_head *task_work)
 {
-	struct io_ring_ctx *ctx = req->ctx;
 	struct callback_head *head;
 
-	init_task_work(&req->task_work, cb);
 	do {
-		head = READ_ONCE(ctx->exit_task_work);
-		req->task_work.next = head;
-	} while (cmpxchg(&ctx->exit_task_work, head, &req->task_work) != head);
+		head = READ_ONCE(*work_head);
+		task_work->next = head;
+	} while (cmpxchg(work_head, head, task_work) != head);
+}
+
+static void io_req_task_work_add_fallback(struct io_kiocb *req,
+					  task_work_func_t cb)
+{
+	init_task_work(&req->task_work, cb);
+	io_task_work_add_head(&req->ctx->exit_task_work, &req->task_work);
 }
 
 static void __io_req_task_cancel(struct io_kiocb *req, int error)
@@ -8471,26 +8498,9 @@ static int io_unregister_personality(struct io_ring_ctx *ctx, unsigned id)
 	return -EINVAL;
 }
 
-static bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
+static inline bool io_run_ctx_fallback(struct io_ring_ctx *ctx)
 {
-	struct callback_head *work, *next;
-	bool executed = false;
-
-	do {
-		work = xchg(&ctx->exit_task_work, NULL);
-		if (!work)
-			break;
-
-		do {
-			next = work->next;
-			work->func(work);
-			work = next;
-			cond_resched();
-		} while (work);
-		executed = true;
-	} while (1);
-
-	return executed;
+	return io_run_task_work_head(&ctx->exit_task_work);
 }
 
 struct io_tctx_exit {