diff --git a/arch/x86/kernel/tboot.c b/arch/x86/kernel/tboot.c
index ae64f98ec2ab6a213202bb15b16d34abc6780806..4c09ba1102047c66db21e340c6a3f4bcf528c7b1 100644
--- a/arch/x86/kernel/tboot.c
+++ b/arch/x86/kernel/tboot.c
@@ -93,6 +93,7 @@ static struct mm_struct tboot_mm = {
 	.pgd            = swapper_pg_dir,
 	.mm_users       = ATOMIC_INIT(2),
 	.mm_count       = ATOMIC_INIT(1),
+	.write_protect_seq = SEQCNT_ZERO(tboot_mm.write_protect_seq),
 	MMAP_LOCK_INITIALIZER(init_mm)
 	.page_table_lock =  __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
 	.mmlist         = LIST_HEAD_INIT(init_mm.mmlist),
diff --git a/drivers/firmware/efi/efi.c b/drivers/firmware/efi/efi.c
index 6c6eec044a978a03821dbb0f0b3e2a6ed1814aa4..df3f9bcab581c430df52347c2eba41950811d07e 100644
--- a/drivers/firmware/efi/efi.c
+++ b/drivers/firmware/efi/efi.c
@@ -57,6 +57,7 @@ struct mm_struct efi_mm = {
 	.mm_rb			= RB_ROOT,
 	.mm_users		= ATOMIC_INIT(2),
 	.mm_count		= ATOMIC_INIT(1),
+	.write_protect_seq      = SEQCNT_ZERO(efi_mm.write_protect_seq),
 	MMAP_LOCK_INITIALIZER(efi_mm)
 	.page_table_lock	= __SPIN_LOCK_UNLOCKED(efi_mm.page_table_lock),
 	.mmlist			= LIST_HEAD_INIT(efi_mm.mmlist),
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 5a9238f6caad97b40be7acb40f896b99d2f6f4ea..915f4f100383b5a36d6f64ecb58dddf392fa064c 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -14,6 +14,7 @@
 #include <linux/uprobes.h>
 #include <linux/page-flags-layout.h>
 #include <linux/workqueue.h>
+#include <linux/seqlock.h>
 
 #include <asm/mmu.h>
 
@@ -446,6 +447,13 @@ struct mm_struct {
 		 */
 		atomic_t has_pinned;
 
+		/**
+		 * @write_protect_seq: Locked when any thread is write
+		 * protecting pages mapped by this mm to enforce a later COW,
+		 * for instance during page table copying for fork().
+		 */
+		seqcount_t write_protect_seq;
+
 #ifdef CONFIG_MMU
 		atomic_long_t pgtables_bytes;	/* PTE page table pages */
 #endif
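
For reference, the even/odd protocol that seqcount_t implements: a writer increments the count to an odd value on entry and back to even on exit, while a reader snapshots the count, does its work, and retries if the snapshot was odd or the count has since changed. Below is a minimal userspace sketch of that protocol, using C11 atomics with conservative (seq_cst) ordering; it is an illustration only, not the kernel's <linux/seqlock.h> implementation.

/*
 * Minimal userspace sketch of the even/odd seqcount protocol.
 * Illustration only; not the kernel's <linux/seqlock.h> code.
 */
#include <stdatomic.h>
#include <stdio.h>

struct seqcount {
	atomic_uint sequence;	/* odd while a writer is active */
};

static unsigned read_begin(struct seqcount *s)
{
	return atomic_load(&s->sequence);
}

static int read_retry(struct seqcount *s, unsigned seq)
{
	/* Retry if a writer was active at the snapshot or ran since. */
	return (seq & 1) || atomic_load(&s->sequence) != seq;
}

static void write_begin(struct seqcount *s)
{
	atomic_fetch_add(&s->sequence, 1);	/* even -> odd */
}

static void write_end(struct seqcount *s)
{
	atomic_fetch_add(&s->sequence, 1);	/* odd -> even */
}

int main(void)
{
	struct seqcount sc = { .sequence = 0 };
	unsigned seq = read_begin(&sc);

	write_begin(&sc);	/* stands in for a concurrent writer */
	write_end(&sc);

	printf("retry: %d\n", read_retry(&sc, seq));	/* prints retry: 1 */
	return 0;
}
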
diff --git a/kernel/fork.c b/kernel/fork.c
index 6d266388d3804ce75e2983cce13f91fb20d53c5b..dc55f68a6ee36dfd730e757d1ff172e43a3227be 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -1007,6 +1007,7 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 	mm->vmacache_seqnum = 0;
 	atomic_set(&mm->mm_users, 1);
 	atomic_set(&mm->mm_count, 1);
+	seqcount_init(&mm->write_protect_seq);
 	mmap_init_lock(mm);
 	INIT_LIST_HEAD(&mm->mmlist);
 	mm->core_state = NULL;
diff --git a/mm/gup.c b/mm/gup.c
index c7e24301860abb7e7e62266130f1045e189d1488..9c6a2f5001c5c2c9c5bdeb29e033a773b9f3fbba 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -2684,11 +2684,18 @@ static unsigned long lockless_pages_from_mm(unsigned long start,
 {
 	unsigned long flags;
 	int nr_pinned = 0;
+	unsigned seq;
 
 	if (!IS_ENABLED(CONFIG_HAVE_FAST_GUP) ||
 	    !gup_fast_permitted(start, end))
 		return 0;
 
+	if (gup_flags & FOLL_PIN) {
+		seq = raw_read_seqcount(&current->mm->write_protect_seq);
+		if (seq & 1)
+			return 0;
+	}
+
 	/*
 	 * Disable interrupts. The nested form is used, in order to allow full,
 	 * general purpose use of this routine.
@@ -2703,6 +2710,17 @@ static unsigned long lockless_pages_from_mm(unsigned long start,
 	local_irq_save(flags);
 	gup_pgd_range(start, end, gup_flags, pages, &nr_pinned);
 	local_irq_restore(flags);
+
+	/*
+	 * When pinning pages for DMA there could be a concurrent write protect
+	 * from fork() via copy_page_range(); in that case, always fail fast GUP.
+	 */
+	if (gup_flags & FOLL_PIN) {
+		if (read_seqcount_retry(&current->mm->write_protect_seq, seq)) {
+			unpin_user_pages(pages, nr_pinned);
+			return 0;
+		}
+	}
 	return nr_pinned;
 }
 
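The read side above follows a snapshot/validate/undo shape: bail out early if a writer is mid-update (odd count), and after the speculative page walk, unpin everything if the count moved. Below is a standalone sketch of that shape; pin_pages_speculatively() and unpin_pages() are hypothetical stand-ins for gup_pgd_range() and unpin_user_pages().

/*
 * Standalone sketch of the FOLL_PIN read-side pattern; the two
 * helpers are hypothetical stand-ins for the real page walk.
 */
#include <stdatomic.h>
#include <stdio.h>

static atomic_uint write_protect_seq;	/* even: idle, odd: fork() active */

static int pin_pages_speculatively(void) { return 16; }
static void unpin_pages(int nr) { (void)nr; }

static int lockless_pin(void)
{
	unsigned seq = atomic_load(&write_protect_seq);
	int nr_pinned;

	/* A writer is mid-update: don't even start the walk. */
	if (seq & 1)
		return 0;

	nr_pinned = pin_pages_speculatively();

	/*
	 * The count moved, so fork() may have write-protected some of
	 * the pages just pinned: undo and let the caller fall back to
	 * the slow, mmap_lock-taking path.
	 */
	if (atomic_load(&write_protect_seq) != seq) {
		unpin_pages(nr_pinned);
		return 0;
	}
	return nr_pinned;
}

int main(void)
{
	printf("pinned %d pages\n", lockless_pin());
	return 0;
}
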
diff --git a/mm/init-mm.c b/mm/init-mm.c
index 3a613c85f9ede21e197ec7cf00d249b4085d4df4..153162669f806289dd2a0129a19edc954fb37567 100644
--- a/mm/init-mm.c
+++ b/mm/init-mm.c
@@ -31,6 +31,7 @@ struct mm_struct init_mm = {
 	.pgd		= swapper_pg_dir,
 	.mm_users	= ATOMIC_INIT(2),
 	.mm_count	= ATOMIC_INIT(1),
+	.write_protect_seq = SEQCNT_ZERO(init_mm.write_protect_seq),
 	MMAP_LOCK_INITIALIZER(init_mm)
 	.page_table_lock =  __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
 	.arg_lock	=  __SPIN_LOCK_UNLOCKED(init_mm.arg_lock),
diff --git a/mm/memory.c b/mm/memory.c
index c48f8df6e502683432097861bc31b010e08f263d..50632c4366b8ab5d66dfd162847bf7ecfc9c6d97 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -1171,6 +1171,15 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
 		mmu_notifier_range_init(&range, MMU_NOTIFY_PROTECTION_PAGE,
 					0, src_vma, src_mm, addr, end);
 		mmu_notifier_invalidate_range_start(&range);
+		/*
+		 * Disabling preemption is not needed for the write side, as
+		 * the read side doesn't spin, but goes to the mmap_lock.
+		 *
+		 * Use the raw variant of the seqcount_t write API to avoid
+		 * lockdep complaining about preemptibility.
+		 */
+		mmap_assert_write_locked(src_mm);
+		raw_write_seqcount_begin(&src_mm->write_protect_seq);
 	}
 
 	ret = 0;
@@ -1187,8 +1196,10 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
 		}
 	} while (dst_pgd++, src_pgd++, addr = next, addr != end);
 
-	if (is_cow)
+	if (is_cow) {
+		raw_write_seqcount_end(&src_mm->write_protect_seq);
 		mmu_notifier_invalidate_range_end(&range);
+	}
 	return ret;
 }
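
The write side needs only the two bare increments because copy_page_range() runs with the mmap write lock held (asserted above), so writers are already serialized; that is also why the raw_write_seqcount_*() variants are safe here without disabling preemption. Below is a standalone sketch of the same shape; protect_parent_ptes() is a hypothetical placeholder for the COW write-protect pass.

/*
 * Standalone sketch of the write side: two bare increments around the
 * write-protect pass, safe without extra locking because an outer
 * lock (the mmap write lock in the kernel) serializes writers.
 */
#include <stdatomic.h>

static atomic_uint write_protect_seq;

static void protect_parent_ptes(void)
{
	/* write-protect the parent's COW-able PTEs here */
}

static void cow_copy(void)
{
	atomic_fetch_add(&write_protect_seq, 1);	/* even -> odd */
	protect_parent_ptes();
	atomic_fetch_add(&write_protect_seq, 1);	/* odd -> even */
}

int main(void)
{
	cow_copy();
	return 0;
}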