diff --git a/fs/io_uring.c b/fs/io_uring.c
index a9d094f7060f439548a02bf4463d6892b461924c..58dd104811062f28b6001bde9e78ea6f88989575 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -2186,10 +2186,10 @@ static bool __tctx_task_work(struct io_uring_task *tctx)
 	if (wq_list_empty(&tctx->task_list))
 		return false;
 
-	spin_lock(&tctx->task_lock);
+	spin_lock_irq(&tctx->task_lock);
 	list = tctx->task_list;
 	INIT_WQ_LIST(&tctx->task_list);
-	spin_unlock(&tctx->task_lock);
+	spin_unlock_irq(&tctx->task_lock);
 
 	node = list.first;
 	while (node) {
@@ -2236,13 +2236,14 @@ static int io_task_work_add(struct task_struct *tsk, struct io_kiocb *req,
 {
 	struct io_uring_task *tctx = tsk->io_uring;
 	struct io_wq_work_node *node, *prev;
+	unsigned long flags;
 	int ret;
 
 	WARN_ON_ONCE(!tctx);
 
-	spin_lock(&tctx->task_lock);
+	spin_lock_irqsave(&tctx->task_lock, flags);
 	wq_list_add_tail(&req->io_task_work.node, &tctx->task_list);
-	spin_unlock(&tctx->task_lock);
+	spin_unlock_irqrestore(&tctx->task_lock, flags);
 
 	/* task_work already pending, we're done */
 	if (test_bit(0, &tctx->task_state) ||
@@ -2257,7 +2258,7 @@ static int io_task_work_add(struct task_struct *tsk, struct io_kiocb *req,
 	 * in the list, it got run and we're fine.
 	 */
 	ret = 0;
-	spin_lock(&tctx->task_lock);
+	spin_lock_irqsave(&tctx->task_lock, flags);
 	wq_list_for_each(node, prev, &tctx->task_list) {
 		if (&req->io_task_work.node == node) {
 			wq_list_del(&tctx->task_list, node, prev);
@@ -2265,7 +2266,7 @@ static int io_task_work_add(struct task_struct *tsk, struct io_kiocb *req,
 			break;
 		}
 	}
-	spin_unlock(&tctx->task_lock);
+	spin_unlock_irqrestore(&tctx->task_lock, flags);
 	clear_bit(0, &tctx->task_state);
 	return ret;
 }