diff --git a/fs/dax.c b/fs/dax.c
index 7ae6df7ea1d2d04962ef4554a6a2fa1efb977006..bcfb14bfc1e49eb36d507c53ccd1c3f4f23135f9 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -569,8 +569,20 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 	if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE)
 		goto fallback;
 
+	sector = bh.b_blocknr << (blkbits - 9);
+
 	if (buffer_unwritten(&bh) || buffer_new(&bh)) {
 		int i;
+
+		length = bdev_direct_access(bh.b_bdev, sector, &kaddr, &pfn,
+						bh.b_size);
+		if (length < 0) {
+			result = VM_FAULT_SIGBUS;
+			goto out;
+		}
+		if ((length < PMD_SIZE) || (pfn & PG_PMD_COLOUR))
+			goto fallback;
+
 		for (i = 0; i < PTRS_PER_PMD; i++)
 			clear_pmem(kaddr + i * PAGE_SIZE, PAGE_SIZE);
 		wmb_pmem();
@@ -623,7 +635,6 @@ int __dax_pmd_fault(struct vm_area_struct *vma, unsigned long address,
 		result = VM_FAULT_NOPAGE;
 		spin_unlock(ptl);
 	} else {
-		sector = bh.b_blocknr << (blkbits - 9);
 		length = bdev_direct_access(bh.b_bdev, sector, &kaddr, &pfn,
 						bh.b_size);
 		if (length < 0) {