diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 665a260c7e09948fa3b5e3516efca2890edef840..72cbf08d45fbc878d5198eda807b12dbdffa6268 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -596,6 +596,7 @@ void kvm_free_irq_source_id(struct kvm *kvm, int irq_source_id);
 
 #ifdef CONFIG_IOMMU_API
 int kvm_iommu_map_pages(struct kvm *kvm, struct kvm_memory_slot *slot);
+void kvm_iommu_unmap_pages(struct kvm *kvm, struct kvm_memory_slot *slot);
 int kvm_iommu_map_guest(struct kvm *kvm);
 int kvm_iommu_unmap_guest(struct kvm *kvm);
 int kvm_assign_device(struct kvm *kvm,
@@ -609,6 +610,11 @@ static inline int kvm_iommu_map_pages(struct kvm *kvm,
 	return 0;
 }
 
+static inline void kvm_iommu_unmap_pages(struct kvm *kvm,
+					 struct kvm_memory_slot *slot)
+{
+}
+
 static inline int kvm_iommu_map_guest(struct kvm *kvm)
 {
 	return -ENODEV;
diff --git a/virt/kvm/iommu.c b/virt/kvm/iommu.c
index a457d2138f49ac8ed860507de8c9068566368889..fec1723de9b4eb400b7c59c25bc97ab8e373df92 100644
--- a/virt/kvm/iommu.c
+++ b/virt/kvm/iommu.c
@@ -310,6 +310,11 @@ static void kvm_iommu_put_pages(struct kvm *kvm,
 	}
 }
 
+void kvm_iommu_unmap_pages(struct kvm *kvm, struct kvm_memory_slot *slot)
+{
+	kvm_iommu_put_pages(kvm, slot->base_gfn, slot->npages);
+}
+
 static int kvm_iommu_unmap_memslots(struct kvm *kvm)
 {
 	int idx;
@@ -320,7 +325,7 @@ static int kvm_iommu_unmap_memslots(struct kvm *kvm)
 	slots = kvm_memslots(kvm);
 
 	kvm_for_each_memslot(memslot, slots)
-		kvm_iommu_put_pages(kvm, memslot->base_gfn, memslot->npages);
+		kvm_iommu_unmap_pages(kvm, memslot);
 
 	srcu_read_unlock(&kvm->srcu, idx);
 
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 42b73930a6de6b2210708aca55903f87a9bab459..9739b533ca2e6954f75c98fc1bb307e641aedf94 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -808,12 +808,13 @@ int __kvm_set_memory_region(struct kvm *kvm,
 	if (r)
 		goto out_free;
 
-	/* map the pages in iommu page table */
+	/* map/unmap the pages in iommu page table */
 	if (npages) {
 		r = kvm_iommu_map_pages(kvm, &new);
 		if (r)
 			goto out_free;
-	}
+	} else
+		kvm_iommu_unmap_pages(kvm, &old);
 
 	r = -ENOMEM;
 	slots = kmemdup(kvm->memslots, sizeof(struct kvm_memslots),