diff --git a/fs/io-wq.c b/fs/io-wq.c
index 0b65a912b0369f0427d6ca784dab78a14ee73c8d..47c5f3aeb46007eae5a031e825a45a699ace2770 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -903,13 +903,15 @@ void io_wq_cancel_all(struct io_wq *wq)
 struct io_cb_cancel_data {
 	work_cancel_fn *fn;
 	void *data;
+	int nr_running;
+	int nr_pending;
+	bool cancel_all;
 };
 
 static bool io_wq_worker_cancel(struct io_worker *worker, void *data)
 {
 	struct io_cb_cancel_data *match = data;
 	unsigned long flags;
-	bool ret = false;
 
 	/*
 	 * Hold the lock to avoid ->cur_work going out of scope, caller
@@ -920,74 +922,90 @@ static bool io_wq_worker_cancel(struct io_worker *worker, void *data)
 	    !(worker->cur_work->flags & IO_WQ_WORK_NO_CANCEL) &&
 	    match->fn(worker->cur_work, match->data)) {
 		send_sig(SIGINT, worker->task, 1);
-		ret = true;
+		match->nr_running++;
 	}
 	spin_unlock_irqrestore(&worker->lock, flags);
 
-	return ret;
+	return match->nr_running && !match->cancel_all;
 }
 
-static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe,
-					    struct io_cb_cancel_data *match)
+static void io_wqe_cancel_pending_work(struct io_wqe *wqe,
+				       struct io_cb_cancel_data *match)
 {
 	struct io_wq_work_node *node, *prev;
 	struct io_wq_work *work;
 	unsigned long flags;
-	bool found = false;
 
-	/*
-	 * First check pending list, if we're lucky we can just remove it
-	 * from there. CANCEL_OK means that the work is returned as-new,
-	 * no completion will be posted for it.
-	 */
+retry:
 	spin_lock_irqsave(&wqe->lock, flags);
 	wq_list_for_each(node, prev, &wqe->work_list) {
 		work = container_of(node, struct io_wq_work, list);
+		if (!match->fn(work, match->data))
+			continue;
 
-		if (match->fn(work, match->data)) {
-			wq_list_del(&wqe->work_list, node, prev);
-			found = true;
-			break;
-		}
-	}
-	spin_unlock_irqrestore(&wqe->lock, flags);
-
-	if (found) {
+		wq_list_del(&wqe->work_list, node, prev);
+		spin_unlock_irqrestore(&wqe->lock, flags);
 		io_run_cancel(work, wqe);
-		return IO_WQ_CANCEL_OK;
+		match->nr_pending++;
+		if (!match->cancel_all)
+			return;
+
+		/* not safe to continue after unlock */
+		goto retry;
 	}
+	spin_unlock_irqrestore(&wqe->lock, flags);
+}
 
-	/*
-	 * Now check if a free (going busy) or busy worker has the work
-	 * currently running. If we find it there, we'll return CANCEL_RUNNING
-	 * as an indication that we attempt to signal cancellation. The
-	 * completion will run normally in this case.
-	 */
+static void io_wqe_cancel_running_work(struct io_wqe *wqe,
+				       struct io_cb_cancel_data *match)
+{
 	rcu_read_lock();
-	found = io_wq_for_each_worker(wqe, io_wq_worker_cancel, match);
+	io_wq_for_each_worker(wqe, io_wq_worker_cancel, match);
 	rcu_read_unlock();
-	return found ? IO_WQ_CANCEL_RUNNING : IO_WQ_CANCEL_NOTFOUND;
 }
 
 enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
-				  void *data)
+				  void *data, bool cancel_all)
 {
 	struct io_cb_cancel_data match = {
-		.fn	= cancel,
-		.data	= data,
+		.fn		= cancel,
+		.data		= data,
+		.cancel_all	= cancel_all,
 	};
-	enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
 	int node;
 
+	/*
+	 * First check pending list, if we're lucky we can just remove it
+	 * from there. CANCEL_OK means that the work is returned as-new,
+	 * no completion will be posted for it.
+	 */
 	for_each_node(node) {
 		struct io_wqe *wqe = wq->wqes[node];
 
-		ret = io_wqe_cancel_work(wqe, &match);
-		if (ret != IO_WQ_CANCEL_NOTFOUND)
-			break;
+		io_wqe_cancel_pending_work(wqe, &match);
+		if (match.nr_pending && !match.cancel_all)
+			return IO_WQ_CANCEL_OK;
 	}
 
-	return ret;
+	/*
+	 * Now check if a free (going busy) or busy worker has the work
+	 * currently running. If we find it there, we'll return CANCEL_RUNNING
+	 * as an indication that we attempt to signal cancellation. The
+	 * completion will run normally in this case.
+	 */
+	for_each_node(node) {
+		struct io_wqe *wqe = wq->wqes[node];
+
+		io_wqe_cancel_running_work(wqe, &match);
+		if (match.nr_running && !match.cancel_all)
+			return IO_WQ_CANCEL_RUNNING;
+	}
+
+	if (match.nr_running)
+		return IO_WQ_CANCEL_RUNNING;
+	if (match.nr_pending)
+		return IO_WQ_CANCEL_OK;
+	return IO_WQ_CANCEL_NOTFOUND;
 }
 
 static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data)
@@ -997,21 +1015,7 @@ static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data)
 
 enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork)
 {
-	return io_wq_cancel_cb(wq, io_wq_io_cb_cancel_data, (void *)cwork);
-}
-
-static bool io_wq_pid_match(struct io_wq_work *work, void *data)
-{
-	pid_t pid = (pid_t) (unsigned long) data;
-
-	return work->task_pid == pid;
-}
-
-enum io_wq_cancel io_wq_cancel_pid(struct io_wq *wq, pid_t pid)
-{
-	void *data = (void *) (unsigned long) pid;
-
-	return io_wq_cancel_cb(wq, io_wq_pid_match, data);
+	return io_wq_cancel_cb(wq, io_wq_io_cb_cancel_data, (void *)cwork, false);
 }
 
 struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
diff --git a/fs/io-wq.h b/fs/io-wq.h
index 8e138fa88b9fcc848e7c82bad29f5678f70d08fa..071f1a9978002675c48294ad518a725c52c0bd05 100644
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -90,7 +90,6 @@ struct io_wq_work {
 	const struct cred *creds;
 	struct fs_struct *fs;
 	unsigned flags;
-	pid_t task_pid;
 };
 
 static inline struct io_wq_work *wq_next_work(struct io_wq_work *work)
@@ -125,12 +124,11 @@ static inline bool io_wq_is_hashed(struct io_wq_work *work)
 
 void io_wq_cancel_all(struct io_wq *wq);
 enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork);
-enum io_wq_cancel io_wq_cancel_pid(struct io_wq *wq, pid_t pid);
 
 typedef bool (work_cancel_fn)(struct io_wq_work *, void *);
 
 enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
-					void *data);
+					void *data, bool cancel_all);
 
 struct task_struct *io_wq_get_task(struct io_wq *wq);
 
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 155f3d830ddb36bf7894d606bcb6ed58fafa0676..a78201b9617965acf762c1480c36e9dfb3d71eba 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -541,6 +541,7 @@ enum {
 	REQ_F_NO_FILE_TABLE_BIT,
 	REQ_F_QUEUE_TIMEOUT_BIT,
 	REQ_F_WORK_INITIALIZED_BIT,
+	REQ_F_TASK_PINNED_BIT,
 
 	/* not a real bit, just to check we're not overflowing the space */
 	__REQ_F_LAST_BIT,
@@ -598,6 +599,8 @@ enum {
 	REQ_F_QUEUE_TIMEOUT	= BIT(REQ_F_QUEUE_TIMEOUT_BIT),
 	/* io_wq_work is initialized */
 	REQ_F_WORK_INITIALIZED	= BIT(REQ_F_WORK_INITIALIZED_BIT),
+	/* req->task is refcounted */
+	REQ_F_TASK_PINNED	= BIT(REQ_F_TASK_PINNED_BIT),
 };
 
 struct async_poll {
@@ -910,6 +913,21 @@ struct sock *io_uring_get_socket(struct file *file)
 }
 EXPORT_SYMBOL(io_uring_get_socket);
 
+static void io_get_req_task(struct io_kiocb *req)
+{
+	if (req->flags & REQ_F_TASK_PINNED)
+		return;
+	get_task_struct(req->task);
+	req->flags |= REQ_F_TASK_PINNED;
+}
+
+/* not idempotent -- it doesn't clear REQ_F_TASK_PINNED */
+static void __io_put_req_task(struct io_kiocb *req)
+{
+	if (req->flags & REQ_F_TASK_PINNED)
+		put_task_struct(req->task);
+}
+
 static void io_file_put_work(struct work_struct *work);
 
 /*
@@ -1045,8 +1063,6 @@ static inline void io_req_work_grab_env(struct io_kiocb *req,
 		}
 		spin_unlock(&current->fs->lock);
 	}
-	if (!req->work.task_pid)
-		req->work.task_pid = task_pid_vnr(current);
 }
 
 static inline void io_req_work_drop_env(struct io_kiocb *req)
@@ -1087,6 +1103,7 @@ static inline void io_prep_async_work(struct io_kiocb *req,
 			req->work.flags |= IO_WQ_WORK_UNBOUND;
 	}
 
+	io_req_init_async(req);
 	io_req_work_grab_env(req, def);
 
 	*link = io_prep_linked_timeout(req);
@@ -1398,9 +1415,7 @@ static void __io_req_aux_free(struct io_kiocb *req)
 	kfree(req->io);
 	if (req->file)
 		io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
-	if (req->task)
-		put_task_struct(req->task);
-
+	__io_put_req_task(req);
 	io_req_work_drop_env(req);
 }
 
@@ -1727,6 +1742,18 @@ static int io_put_kbuf(struct io_kiocb *req)
 	return cflags;
 }
 
+static void io_iopoll_queue(struct list_head *again)
+{
+	struct io_kiocb *req;
+
+	do {
+		req = list_first_entry(again, struct io_kiocb, list);
+		list_del(&req->list);
+		refcount_inc(&req->refs);
+		io_queue_async_work(req);
+	} while (!list_empty(again));
+}
+
 /*
  * Find and free completed poll iocbs
  */
@@ -1735,12 +1762,21 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
 {
 	struct req_batch rb;
 	struct io_kiocb *req;
+	LIST_HEAD(again);
+
+	/* order with ->result store in io_complete_rw_iopoll() */
+	smp_rmb();
 
 	rb.to_free = rb.need_iter = 0;
 	while (!list_empty(done)) {
 		int cflags = 0;
 
 		req = list_first_entry(done, struct io_kiocb, list);
+		if (READ_ONCE(req->result) == -EAGAIN) {
+			req->iopoll_completed = 0;
+			list_move_tail(&req->list, &again);
+			continue;
+		}
 		list_del(&req->list);
 
 		if (req->flags & REQ_F_BUFFER_SELECTED)
@@ -1758,18 +1794,9 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
 	if (ctx->flags & IORING_SETUP_SQPOLL)
 		io_cqring_ev_posted(ctx);
 	io_free_req_many(ctx, &rb);
-}
 
-static void io_iopoll_queue(struct list_head *again)
-{
-	struct io_kiocb *req;
-
-	do {
-		req = list_first_entry(again, struct io_kiocb, list);
-		list_del(&req->list);
-		refcount_inc(&req->refs);
-		io_queue_async_work(req);
-	} while (!list_empty(again));
+	if (!list_empty(&again))
+		io_iopoll_queue(&again);
 }
 
 static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
@@ -1777,7 +1804,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
 {
 	struct io_kiocb *req, *tmp;
 	LIST_HEAD(done);
-	LIST_HEAD(again);
 	bool spin;
 	int ret;
 
@@ -1803,13 +1829,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
 		if (!list_empty(&done))
 			break;
 
-		if (req->result == -EAGAIN) {
-			list_move_tail(&req->list, &again);
-			continue;
-		}
-		if (!list_empty(&again))
-			break;
-
 		ret = kiocb->ki_filp->f_op->iopoll(kiocb, spin);
 		if (ret < 0)
 			break;
@@ -1822,9 +1841,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
 	if (!list_empty(&done))
 		io_iopoll_complete(ctx, nr_events, &done);
 
-	if (!list_empty(&again))
-		io_iopoll_queue(&again);
-
 	return ret;
 }
 
@@ -1973,11 +1989,15 @@ static void io_complete_rw_iopoll(struct kiocb *kiocb, long res, long res2)
 	if (kiocb->ki_flags & IOCB_WRITE)
 		kiocb_end_write(req);
 
-	if (res != req->result)
+	if (res != -EAGAIN && res != req->result)
 		req_set_fail_links(req);
-	req->result = res;
-	if (res != -EAGAIN)
+
+	WRITE_ONCE(req->result, res);
+	/* order with io_poll_complete() checking ->result */
+	if (res != -EAGAIN) {
+		smp_wmb();
 		WRITE_ONCE(req->iopoll_completed, 1);
+	}
 }
 
 /*
@@ -2650,8 +2670,8 @@ static int io_read(struct io_kiocb *req, bool force_nonblock)
 		}
 	}
 out_free:
-	kfree(iovec);
-	req->flags &= ~REQ_F_NEED_CLEANUP;
+	if (!(req->flags & REQ_F_NEED_CLEANUP))
+		kfree(iovec);
 	return ret;
 }
 
@@ -2773,8 +2793,8 @@ static int io_write(struct io_kiocb *req, bool force_nonblock)
 		}
 	}
 out_free:
-	req->flags &= ~REQ_F_NEED_CLEANUP;
-	kfree(iovec);
+	if (!(req->flags & REQ_F_NEED_CLEANUP))
+		kfree(iovec);
 	return ret;
 }
 
@@ -4236,6 +4256,28 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
 	__io_queue_proc(&pt->req->apoll->poll, pt, head);
 }
 
+static void io_sq_thread_drop_mm(struct io_ring_ctx *ctx)
+{
+	struct mm_struct *mm = current->mm;
+
+	if (mm) {
+		kthread_unuse_mm(mm);
+		mmput(mm);
+	}
+}
+
+static int io_sq_thread_acquire_mm(struct io_ring_ctx *ctx,
+				   struct io_kiocb *req)
+{
+	if (io_op_defs[req->opcode].needs_mm && !current->mm) {
+		if (unlikely(!mmget_not_zero(ctx->sqo_mm)))
+			return -EFAULT;
+		kthread_use_mm(ctx->sqo_mm);
+	}
+
+	return 0;
+}
+
 static void io_async_task_func(struct callback_head *cb)
 {
 	struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
@@ -4270,11 +4312,16 @@ static void io_async_task_func(struct callback_head *cb)
 
 	if (!canceled) {
 		__set_current_state(TASK_RUNNING);
+		if (io_sq_thread_acquire_mm(ctx, req)) {
+			io_cqring_add_event(req, -EFAULT);
+			goto end_req;
+		}
 		mutex_lock(&ctx->uring_lock);
 		__io_queue_sqe(req, NULL);
 		mutex_unlock(&ctx->uring_lock);
 	} else {
 		io_cqring_ev_posted(ctx);
+end_req:
 		req_set_fail_links(req);
 		io_double_put_req(req);
 	}
@@ -4366,8 +4413,7 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
 		memcpy(&apoll->work, &req->work, sizeof(req->work));
 	had_io = req->io != NULL;
 
-	get_task_struct(current);
-	req->task = current;
+	io_get_req_task(req);
 	req->apoll = apoll;
 	INIT_HLIST_NODE(&req->hash_node);
 
@@ -4555,8 +4601,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
 	events = READ_ONCE(sqe->poll_events);
 	poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP;
 
-	get_task_struct(current);
-	req->task = current;
+	io_get_req_task(req);
 	return 0;
 }
 
@@ -4772,7 +4817,7 @@ static int io_async_cancel_one(struct io_ring_ctx *ctx, void *sqe_addr)
 	enum io_wq_cancel cancel_ret;
 	int ret = 0;
 
-	cancel_ret = io_wq_cancel_cb(ctx->io_wq, io_cancel_cb, sqe_addr);
+	cancel_ret = io_wq_cancel_cb(ctx->io_wq, io_cancel_cb, sqe_addr, false);
 	switch (cancel_ret) {
 	case IO_WQ_CANCEL_OK:
 		ret = 0;
@@ -5817,17 +5862,14 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 	req->flags = 0;
 	/* one is dropped after submission, the other at completion */
 	refcount_set(&req->refs, 2);
-	req->task = NULL;
+	req->task = current;
 	req->result = 0;
 
 	if (unlikely(req->opcode >= IORING_OP_LAST))
 		return -EINVAL;
 
-	if (io_op_defs[req->opcode].needs_mm && !current->mm) {
-		if (unlikely(!mmget_not_zero(ctx->sqo_mm)))
-			return -EFAULT;
-		kthread_use_mm(ctx->sqo_mm);
-	}
+	if (unlikely(io_sq_thread_acquire_mm(ctx, req)))
+		return -EFAULT;
 
 	sqe_flags = READ_ONCE(sqe->flags);
 	/* enforce forwards compatibility on users */
@@ -5936,16 +5978,6 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
 	return submitted;
 }
 
-static inline void io_sq_thread_drop_mm(struct io_ring_ctx *ctx)
-{
-	struct mm_struct *mm = current->mm;
-
-	if (mm) {
-		kthread_unuse_mm(mm);
-		mmput(mm);
-	}
-}
-
 static int io_sq_thread(void *data)
 {
 	struct io_ring_ctx *ctx = data;
@@ -7331,7 +7363,17 @@ static void io_ring_exit_work(struct work_struct *work)
 	if (ctx->rings)
 		io_cqring_overflow_flush(ctx, true);
 
-	wait_for_completion(&ctx->ref_comp);
+	/*
+	 * If we're doing polled IO and end up having requests being
+	 * submitted async (out-of-line), then completions can come in while
+	 * we're waiting for refs to drop. We need to reap these manually,
+	 * as nobody else will be looking for them.
+	 */
+	while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20)) {
+		io_iopoll_reap_events(ctx);
+		if (ctx->rings)
+			io_cqring_overflow_flush(ctx, true);
+	}
 	io_ring_ctx_free(ctx);
 }
 
@@ -7365,9 +7407,22 @@ static int io_uring_release(struct inode *inode, struct file *file)
 	return 0;
 }
 
+static bool io_wq_files_match(struct io_wq_work *work, void *data)
+{
+	struct files_struct *files = data;
+
+	return work->files == files;
+}
+
 static void io_uring_cancel_files(struct io_ring_ctx *ctx,
 				  struct files_struct *files)
 {
+	if (list_empty_careful(&ctx->inflight_list))
+		return;
+
+	/* cancel all at once, should be faster than doing it one by one*/
+	io_wq_cancel_cb(ctx->io_wq, io_wq_files_match, files, true);
+
 	while (!list_empty_careful(&ctx->inflight_list)) {
 		struct io_kiocb *cancel_req = NULL, *req;
 		DEFINE_WAIT(wait);
@@ -7423,6 +7478,14 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
 	}
 }
 
+static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
+{
+	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+	struct task_struct *task = data;
+
+	return req->task == task;
+}
+
 static int io_uring_flush(struct file *file, void *data)
 {
 	struct io_ring_ctx *ctx = file->private_data;
@@ -7433,7 +7496,7 @@ static int io_uring_flush(struct file *file, void *data)
 	 * If the task is going away, cancel work it may have pending
 	 */
 	if (fatal_signal_pending(current) || (current->flags & PF_EXITING))
-		io_wq_cancel_pid(ctx->io_wq, task_pid_vnr(current));
+		io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, current, true);
 
 	return 0;
 }