diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 5c212bf29640d8bee4f069e6ec720e41587b34ba..3c082451ab1a00da8577978a61276c0341930e3d 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -404,6 +404,7 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 {
 	unsigned long pfn = 0;
 	long ret, pinned = 0, lock_acct = 0;
+	bool rsvd;
 	dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 
 	/* This code path is only user initiated */
@@ -414,23 +415,14 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 	if (ret)
 		return ret;
 
-	if (is_invalid_reserved_pfn(*pfn_base)) {
-		struct vm_area_struct *vma;
-
-		down_read(&current->mm->mmap_sem);
-		vma = find_vma_intersection(current->mm, vaddr, vaddr + 1);
-		pinned = min_t(long, npage, vma_pages(vma));
-		up_read(&current->mm->mmap_sem);
-		return pinned;
-	}
-
 	pinned++;
+	rsvd = is_invalid_reserved_pfn(*pfn_base);
 
 	/*
 	 * Reserved pages aren't counted against the user, externally pinned
 	 * pages are already counted against the user.
 	 */
-	if (!vfio_find_vpfn(dma, iova)) {
+	if (!rsvd && !vfio_find_vpfn(dma, iova)) {
 		if (!lock_cap && current->mm->locked_vm + 1 > limit) {
 			put_pfn(*pfn_base, dma->prot);
 			pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
@@ -450,12 +442,13 @@ static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 		if (ret)
 			break;
 
-		if (pfn != *pfn_base + pinned) {
+		if (pfn != *pfn_base + pinned ||
+		    rsvd != is_invalid_reserved_pfn(pfn)) {
 			put_pfn(pfn, dma->prot);
 			break;
 		}
 
-		if (!vfio_find_vpfn(dma, iova)) {
+		if (!rsvd && !vfio_find_vpfn(dma, iova)) {
 			if (!lock_cap &&
 			    current->mm->locked_vm + lock_acct + 1 > limit) {
 				put_pfn(pfn, dma->prot);
@@ -473,8 +466,10 @@ out:
 
 unpin_out:
 	if (ret) {
-		for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
-			put_pfn(pfn, dma->prot);
+		if (!rsvd) {
+			for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
+				put_pfn(pfn, dma->prot);
+		}
 
 		return ret;
 	}