diff --git a/drivers/vfio/pci/vfio_pci_private.h b/drivers/vfio/pci/vfio_pci_private.h
index 86a02aff8735ffa2a7223821971e76d47157ccd5..61ca8ab165dc16f930296c7fd20263cfe9680b17 100644
--- a/drivers/vfio/pci/vfio_pci_private.h
+++ b/drivers/vfio/pci/vfio_pci_private.h
@@ -33,12 +33,14 @@
 
 struct vfio_pci_ioeventfd {
 	struct list_head	next;
+	struct vfio_pci_device	*vdev;
 	struct virqfd		*virqfd;
 	void __iomem		*addr;
 	uint64_t		data;
 	loff_t			pos;
 	int			bar;
 	int			count;
+	bool			test_mem;
 };
 
 struct vfio_pci_irq_ctx {
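
The two new members give the ioeventfd path what the BAR read/write path already has: a way back to the device's memory_lock (vdev) and a record of whether the target BAR is MMIO (test_mem); port I/O needs no memory-enable check. The gate both paths rely on, __vfio_pci_memory_enabled(), is implemented elsewhere in vfio-pci and not shown in this diff; a minimal sketch of its shape, assuming the virtualized command word is cached in vdev->vconfig:

	/* Sketch only, not the verbatim helper.  Caller holds memory_lock.
	 * Memory decode counts as enabled when the device is powered and
	 * the user-visible command register has PCI_COMMAND_MEMORY set. */
	bool __vfio_pci_memory_enabled(struct vfio_pci_device *vdev)
	{
		u16 cmd = le16_to_cpu(*(__le16 *)&vdev->vconfig[PCI_COMMAND]);

		return vdev->pdev->current_state < PCI_D3hot &&
		       (cmd & PCI_COMMAND_MEMORY);
	}
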
diff --git a/drivers/vfio/pci/vfio_pci_rdwr.c b/drivers/vfio/pci/vfio_pci_rdwr.c
index 916b184df3a5b52637c00779a7a883cf0fbcdf77..9e353c484ace2676b5fa51f0c07f0b50208cd04e 100644
--- a/drivers/vfio/pci/vfio_pci_rdwr.c
+++ b/drivers/vfio/pci/vfio_pci_rdwr.c
@@ -37,17 +37,70 @@
 #define vfio_ioread8	ioread8
 #define vfio_iowrite8	iowrite8
 
+#define VFIO_IOWRITE(size) \
+static int vfio_pci_iowrite##size(struct vfio_pci_device *vdev,		\
+			bool test_mem, u##size val, void __iomem *io)	\
+{									\
+	if (test_mem) {							\
+		down_read(&vdev->memory_lock);				\
+		if (!__vfio_pci_memory_enabled(vdev)) {			\
+			up_read(&vdev->memory_lock);			\
+			return -EIO;					\
+		}							\
+	}								\
+									\
+	vfio_iowrite##size(val, io);					\
+									\
+	if (test_mem)							\
+		up_read(&vdev->memory_lock);				\
+									\
+	return 0;							\
+}
+
+VFIO_IOWRITE(8)
+VFIO_IOWRITE(16)
+VFIO_IOWRITE(32)
+#ifdef iowrite64
+VFIO_IOWRITE(64)
+#endif
+
+#define VFIO_IOREAD(size) \
+static int vfio_pci_ioread##size(struct vfio_pci_device *vdev,		\
+			bool test_mem, u##size *val, void __iomem *io)	\
+{									\
+	if (test_mem) {							\
+		down_read(&vdev->memory_lock);				\
+		if (!__vfio_pci_memory_enabled(vdev)) {			\
+			up_read(&vdev->memory_lock);			\
+			return -EIO;					\
+		}							\
+	}								\
+									\
+	*val = vfio_ioread##size(io);					\
+									\
+	if (test_mem)							\
+		up_read(&vdev->memory_lock);				\
+									\
+	return 0;							\
+}
+
+VFIO_IOREAD(8)
+VFIO_IOREAD(16)
+VFIO_IOREAD(32)
+
 /*
  * Read or write from an __iomem region (MMIO or I/O port) with an excluded
  * range which is inaccessible.  The excluded range drops writes and fills
  * reads with -1.  This is intended for handling MSI-X vector tables and
  * leftover space for ROM BARs.
  */
-static ssize_t do_io_rw(void __iomem *io, char __user *buf,
+static ssize_t do_io_rw(struct vfio_pci_device *vdev, bool test_mem,
+			void __iomem *io, char __user *buf,
 			loff_t off, size_t count, size_t x_start,
 			size_t x_end, bool iswrite)
 {
 	ssize_t done = 0;
+	int ret;
 
 	while (count) {
 		size_t fillable, filled;
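
For readability, VFIO_IOWRITE(32) above expands to the function below; the VFIO_IOREAD() instances are the mirror image, reading into *val instead:

	static int vfio_pci_iowrite32(struct vfio_pci_device *vdev,
				bool test_mem, u32 val, void __iomem *io)
	{
		if (test_mem) {
			down_read(&vdev->memory_lock);
			if (!__vfio_pci_memory_enabled(vdev)) {
				up_read(&vdev->memory_lock);
				return -EIO;
			}
		}

		vfio_iowrite32(val, io);	/* #defined above to iowrite32 */

		if (test_mem)
			up_read(&vdev->memory_lock);

		return 0;
	}
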
@@ -66,9 +119,15 @@ static ssize_t do_io_rw(void __iomem *io, char __user *buf,
 				if (copy_from_user(&val, buf, 4))
 					return -EFAULT;
 
-				vfio_iowrite32(val, io + off);
+				ret = vfio_pci_iowrite32(vdev, test_mem,
+							 val, io + off);
+				if (ret)
+					return ret;
 			} else {
-				val = vfio_ioread32(io + off);
+				ret = vfio_pci_ioread32(vdev, test_mem,
+							&val, io + off);
+				if (ret)
+					return ret;
 
 				if (copy_to_user(buf, &val, 4))
 					return -EFAULT;
@@ -82,9 +141,15 @@ static ssize_t do_io_rw(void __iomem *io, char __user *buf,
 				if (copy_from_user(&val, buf, 2))
 					return -EFAULT;
 
-				vfio_iowrite16(val, io + off);
+				ret = vfio_pci_iowrite16(vdev, test_mem,
+							 val, io + off);
+				if (ret)
+					return ret;
 			} else {
-				val = vfio_ioread16(io + off);
+				ret = vfio_pci_ioread16(vdev, test_mem,
+							&val, io + off);
+				if (ret)
+					return ret;
 
 				if (copy_to_user(buf, &val, 2))
 					return -EFAULT;
@@ -98,9 +163,15 @@ static ssize_t do_io_rw(void __iomem *io, char __user *buf,
 				if (copy_from_user(&val, buf, 1))
 					return -EFAULT;
 
-				vfio_iowrite8(val, io + off);
+				ret = vfio_pci_iowrite8(vdev, test_mem,
+							val, io + off);
+				if (ret)
+					return ret;
 			} else {
-				val = vfio_ioread8(io + off);
+				ret = vfio_pci_ioread8(vdev, test_mem,
+						       &val, io + off);
+				if (ret)
+					return ret;
 
 				if (copy_to_user(buf, &val, 1))
 					return -EFAULT;
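
The "drop writes, fill reads with -1" behavior described in the comment above do_io_rw() lives in the loop's fallback branch, which this diff does not touch. A sketch of that branch as a hypothetical helper (fill_excluded() is not a real kernel function):

	/* Called when off lands inside [x_start, x_end).  Returns how many
	 * bytes were consumed so the caller can advance off/buf/done. */
	static ssize_t fill_excluded(char __user *buf, loff_t off,
				     size_t count, size_t x_end, bool iswrite)
	{
		size_t filled = min(count, (size_t)(x_end - off));

		if (!iswrite) {			/* reads see all-ones */
			u8 val = 0xFF;
			size_t i;

			for (i = 0; i < filled; i++)
				if (copy_to_user(buf + i, &val, 1))
					return -EFAULT;
		}
		/* writes are silently dropped */

		return filled;
	}
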
@@ -178,14 +249,6 @@ ssize_t vfio_pci_bar_rw(struct vfio_pci_device *vdev, char __user *buf,
 
 	count = min(count, (size_t)(end - pos));
 
-	if (res->flags & IORESOURCE_MEM) {
-		down_read(&vdev->memory_lock);
-		if (!__vfio_pci_memory_enabled(vdev)) {
-			up_read(&vdev->memory_lock);
-			return -EIO;
-		}
-	}
-
 	if (bar == PCI_ROM_RESOURCE) {
 		/*
 		 * The ROM can fill less space than the BAR, so we start the
@@ -213,7 +276,8 @@ ssize_t vfio_pci_bar_rw(struct vfio_pci_device *vdev, char __user *buf,
 		x_end = vdev->msix_offset + vdev->msix_size;
 	}
 
-	done = do_io_rw(io, buf, pos, count, x_start, x_end, iswrite);
+	done = do_io_rw(vdev, res->flags & IORESOURCE_MEM, io, buf, pos,
+			count, x_start, x_end, iswrite);
 
 	if (done >= 0)
 		*ppos += done;
@@ -221,9 +285,6 @@ ssize_t vfio_pci_bar_rw(struct vfio_pci_device *vdev, char __user *buf,
 	if (bar == PCI_ROM_RESOURCE)
 		pci_unmap_rom(pdev, io);
 out:
-	if (res->flags & IORESOURCE_MEM)
-		up_read(&vdev->memory_lock);
-
 	return done;
 }
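
Dropping the function-wide lock taken here (and released under out:) in favor of the per-access helpers means memory_lock is no longer held across copy_from_user()/copy_to_user(). That matters because the user buffer may itself be an mmap of this same device: the copy can fault into vfio-pci's fault handler, which also takes memory_lock for read, and reader/writer fairness turns that recursion into a deadlock once a writer is queued. A sketch of the ordering being avoided (task names hypothetical):

	/*
	 * task A: read() on device fd       task B: disables memory decode
	 *   down_read(&vdev->memory_lock)
	 *                                     down_write(&vdev->memory_lock)
	 *                                       ...queued behind A...
	 *   copy_to_user() faults on an
	 *   mmap of this same device:
	 *     fault handler:
	 *       down_read(&vdev->memory_lock)  <- queued behind writer B
	 */
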
 
@@ -278,7 +339,12 @@ ssize_t vfio_pci_vga_rw(struct vfio_pci_device *vdev, char __user *buf,
 		return ret;
 	}
 
-	done = do_io_rw(iomem, buf, off, count, 0, 0, iswrite);
+	/*
+	 * VGA MMIO is a legacy, non-BAR resource that hopefully allows
+	 * probing, so we don't currently worry about access in relation
+	 * to the memory enable bit in the command register.
+	 */
+	done = do_io_rw(vdev, false, iomem, buf, off, count, 0, 0, iswrite);
 
 	vga_put(vdev->pdev, rsrc);
 
@@ -296,17 +362,21 @@ static int vfio_pci_ioeventfd_handler(void *opaque, void *unused)
 
 	switch (ioeventfd->count) {
 	case 1:
-		vfio_iowrite8(ioeventfd->data, ioeventfd->addr);
+		vfio_pci_iowrite8(ioeventfd->vdev, ioeventfd->test_mem,
+				  ioeventfd->data, ioeventfd->addr);
 		break;
 	case 2:
-		vfio_iowrite16(ioeventfd->data, ioeventfd->addr);
+		vfio_pci_iowrite16(ioeventfd->vdev, ioeventfd->test_mem,
+				   ioeventfd->data, ioeventfd->addr);
 		break;
 	case 4:
-		vfio_iowrite32(ioeventfd->data, ioeventfd->addr);
+		vfio_pci_iowrite32(ioeventfd->vdev, ioeventfd->test_mem,
+				   ioeventfd->data, ioeventfd->addr);
 		break;
 #ifdef iowrite64
 	case 8:
-		vfio_iowrite64(ioeventfd->data, ioeventfd->addr);
+		vfio_pci_iowrite64(ioeventfd->vdev, ioeventfd->test_mem,
+				   ioeventfd->data, ioeventfd->addr);
 		break;
 #endif
 	}
@@ -378,11 +448,13 @@ long vfio_pci_ioeventfd(struct vfio_pci_device *vdev, loff_t offset,
 		goto out_unlock;
 	}
 
+	ioeventfd->vdev = vdev;
 	ioeventfd->addr = vdev->barmap[bar] + pos;
 	ioeventfd->data = data;
 	ioeventfd->pos = pos;
 	ioeventfd->bar = bar;
 	ioeventfd->count = count;
+	ioeventfd->test_mem = vdev->pdev->resource[bar].flags & IORESOURCE_MEM;
 
 	ret = vfio_virqfd_enable(ioeventfd, vfio_pci_ioeventfd_handler,
 				 NULL, NULL, &ioeventfd->virqfd, fd);
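
Note that test_mem is latched once, at registration time, from the BAR's resource flags, so a port-I/O ioeventfd skips the check while an MMIO one honors memory decode on every trigger. For context, a minimal userspace sketch of registering an ioeventfd (uapi types from <linux/vfio.h>; the BAR2 doorbell at offset 0x40 is a made-up example, error handling omitted):

	#include <stdint.h>
	#include <sys/eventfd.h>
	#include <sys/ioctl.h>
	#include <linux/vfio.h>

	int arm_doorbell(int device_fd, uint32_t value)
	{
		struct vfio_region_info reg = {
			.argsz = sizeof(reg),
			.index = VFIO_PCI_BAR2_REGION_INDEX,
		};
		struct vfio_device_ioeventfd ie = {
			.argsz = sizeof(ie),
			.flags = VFIO_DEVICE_IOEVENTFD_32,	/* 4-byte write */
			.data  = value,
			.fd    = eventfd(0, EFD_CLOEXEC),
		};

		/* region offsets are discovered, never hardcoded */
		if (ioctl(device_fd, VFIO_DEVICE_GET_REGION_INFO, &reg))
			return -1;

		ie.offset = reg.offset + 0x40;

		/* every signal of ie.fd now triggers the 4-byte write in
		 * the kernel via vfio_pci_ioeventfd_handler() above */
		return ioctl(device_fd, VFIO_DEVICE_IOEVENTFD, &ie);
	}
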
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 6990fc711a80b4c5ff75ae918740d50307550410..c992973cc2d5fa864ab3f38bf544b61418c84a53 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -1424,13 +1424,16 @@ static int vfio_bus_type(struct device *dev, void *data)
 static int vfio_iommu_replay(struct vfio_iommu *iommu,
 			     struct vfio_domain *domain)
 {
-	struct vfio_domain *d;
+	struct vfio_domain *d = NULL;
 	struct rb_node *n;
 	unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
 	int ret;
 
 	/* Arbitrarily pick the first domain in the list for lookups */
-	d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
+	if (!list_empty(&iommu->domain_list))
+		d = list_first_entry(&iommu->domain_list,
+				     struct vfio_domain, next);
+
 	n = rb_first(&iommu->dma_list);
 
 	for (; n; n = rb_next(n)) {
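
On the NULL initialization above: domain_list can legitimately be empty here when every group attached so far is a non-IOMMU-backed mdev, since those are tracked as an external domain instead. In that case no vfio_dma can have iommu_mapped set, so d is never dereferenced; the WARN_ON(!d) below guards the combination that should be impossible.
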
@@ -1448,6 +1451,11 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 				phys_addr_t p;
 				dma_addr_t i;
 
+				if (WARN_ON(!d)) { /* mapped w/o a domain?! */
+					ret = -EINVAL;
+					goto unwind;
+				}
+
 				phys = iommu_iova_to_phys(d->domain, iova);
 
 				if (WARN_ON(!phys)) {
@@ -1477,7 +1485,7 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 				if (npage <= 0) {
 					WARN_ON(!npage);
 					ret = (int)npage;
-					return ret;
+					goto unwind;
 				}
 
 				phys = pfn << PAGE_SHIFT;
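
The inline vfio_unpin_pages_remote() in the iommu_map() failure path below is needed because the unwind walk discovers what to unpin by translating iovas through the new domain: pages pinned for the chunk that failed to map were never entered into the domain, so the walk cannot find them and they would otherwise stay pinned.
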
@@ -1486,14 +1494,67 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
 
 			ret = iommu_map(domain->domain, iova, phys,
 					size, dma->prot | domain->prot);
-			if (ret)
-				return ret;
+			if (ret) {
+				if (!dma->iommu_mapped)
+					vfio_unpin_pages_remote(dma, iova,
+							phys >> PAGE_SHIFT,
+							size >> PAGE_SHIFT,
+							true);
+				goto unwind;
+			}
 
 			iova += size;
 		}
+	}
+
+	/* All dmas are now mapped, defer to second tree walk for unwind */
+	for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
+		struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
+
 		dma->iommu_mapped = true;
 	}
+
 	return 0;
+
+unwind:
+	for (; n; n = rb_prev(n)) {
+		struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
+		dma_addr_t iova;
+
+		if (dma->iommu_mapped) {
+			iommu_unmap(domain->domain, dma->iova, dma->size);
+			continue;
+		}
+
+		iova = dma->iova;
+		while (iova < dma->iova + dma->size) {
+			phys_addr_t phys, p;
+			size_t size;
+			dma_addr_t i;
+
+			phys = iommu_iova_to_phys(domain->domain, iova);
+			if (!phys) {
+				iova += PAGE_SIZE;
+				continue;
+			}
+
+			size = PAGE_SIZE;
+			p = phys + size;
+			i = iova + size;
+			while (i < dma->iova + dma->size &&
+			       p == iommu_iova_to_phys(domain->domain, i)) {
+				size += PAGE_SIZE;
+				p += PAGE_SIZE;
+				i += PAGE_SIZE;
+			}
+
+			iommu_unmap(domain->domain, iova, size);
+			vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
+						size >> PAGE_SHIFT, true);
+		}
+	}
+
+	return ret;
 }
 
 /*
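
The unwind path's inner loop batches physically contiguous pages so that each run costs one iommu_unmap() plus one vfio_unpin_pages_remote() instead of one pair per page. A standalone userspace illustration of that coalescing, with a toy lookup table standing in for iommu_iova_to_phys() (not kernel code):

	#include <stdio.h>
	#include <stdint.h>

	#define PAGE_SIZE 4096ULL

	/* toy translation: iova page index -> phys addr, 0 == unmapped */
	static const uint64_t tbl[8] = {
		0x10000, 0x11000,		/* 2-page contiguous run */
		0x40000, 0x41000, 0x42000,	/* 3-page contiguous run */
		0,				/* hole */
		0x90000, 0x91000,		/* 2-page contiguous run */
	};

	static uint64_t iova_to_phys(uint64_t iova)
	{
		uint64_t idx = iova / PAGE_SIZE;

		return idx < 8 ? tbl[idx] : 0;
	}

	int main(void)
	{
		uint64_t iova = 0, end = 8 * PAGE_SIZE;

		while (iova < end) {
			uint64_t phys = iova_to_phys(iova), p, i, size;

			if (!phys) {		/* skip unmapped pages */
				iova += PAGE_SIZE;
				continue;
			}

			/* grow the run while the next iova page maps to
			 * the physically adjacent page, as the unwind
			 * loop above does */
			size = PAGE_SIZE;
			p = phys + size;
			i = iova + size;
			while (i < end && p == iova_to_phys(i)) {
				size += PAGE_SIZE;
				p += PAGE_SIZE;
				i += PAGE_SIZE;
			}

			/* the kernel would iommu_unmap() and unpin here */
			printf("run: iova %#llx..%#llx -> phys %#llx (%llu pages)\n",
			       (unsigned long long)iova,
			       (unsigned long long)(iova + size),
			       (unsigned long long)phys,
			       (unsigned long long)(size / PAGE_SIZE));

			iova += size;
		}

		return 0;
	}

With the table above this prints three runs (2, 3, and 2 pages), matching what the unwind path would unmap and unpin in three calls rather than seven.
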