diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index fc011e13213b242daff25e84b75dae9a79290f12..c9d756b7ee9eb83a2865e2c5c8b6f31691a2f454 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -37,6 +37,8 @@ module_param_named(nointxmask, nointxmask, bool, S_IRUGO | S_IWUSR);
 MODULE_PARM_DESC(nointxmask,
 		  "Disable support for PCI 2.3 style INTx masking.  If this resolves problems for specific devices, report lspci -vvvxxx to linux-pci@vger.kernel.org so the device can be fixed automatically via the broken_intx_masking flag.");
 
+static DEFINE_MUTEX(driver_lock);
+
 static int vfio_pci_enable(struct vfio_pci_device *vdev)
 {
 	struct pci_dev *pdev = vdev->pdev;
@@ -163,23 +165,29 @@ static void vfio_pci_release(void *device_data)
 {
 	struct vfio_pci_device *vdev = device_data;
 
-	if (atomic_dec_and_test(&vdev->refcnt)) {
+	mutex_lock(&driver_lock);
+
+	if (!(--vdev->refcnt)) {
 		vfio_spapr_pci_eeh_release(vdev->pdev);
 		vfio_pci_disable(vdev);
 	}
 
+	mutex_unlock(&driver_lock);
+
 	module_put(THIS_MODULE);
 }
 
 static int vfio_pci_open(void *device_data)
 {
 	struct vfio_pci_device *vdev = device_data;
-	int ret;
+	int ret = 0;
 
 	if (!try_module_get(THIS_MODULE))
 		return -ENODEV;
 
-	if (atomic_inc_return(&vdev->refcnt) == 1) {
+	mutex_lock(&driver_lock);
+
+	if (!vdev->refcnt) {
 		ret = vfio_pci_enable(vdev);
 		if (ret)
 			goto error;
@@ -190,10 +198,11 @@ static int vfio_pci_open(void *device_data)
 			goto error;
 		}
 	}
-
-	return 0;
+	vdev->refcnt++;
 error:
-	module_put(THIS_MODULE);
+	mutex_unlock(&driver_lock);
+	if (ret)
+		module_put(THIS_MODULE);
 	return ret;
 }
 
@@ -849,7 +858,6 @@ static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 	vdev->irq_type = VFIO_PCI_NUM_IRQS;
 	mutex_init(&vdev->igate);
 	spin_lock_init(&vdev->irqlock);
-	atomic_set(&vdev->refcnt, 0);
 
 	ret = vfio_add_group_dev(&pdev->dev, &vfio_pci_ops, vdev);
 	if (ret) {
@@ -864,12 +872,15 @@ static void vfio_pci_remove(struct pci_dev *pdev)
 {
 	struct vfio_pci_device *vdev;
 
+	mutex_lock(&driver_lock);
+
 	vdev = vfio_del_group_dev(&pdev->dev);
-	if (!vdev)
-		return;
+	if (vdev) {
+		iommu_group_put(pdev->dev.iommu_group);
+		kfree(vdev);
+	}
 
-	iommu_group_put(pdev->dev.iommu_group);
-	kfree(vdev);
+	mutex_unlock(&driver_lock);
 }
 
 static pci_ers_result_t vfio_pci_aer_err_detected(struct pci_dev *pdev,
diff --git a/drivers/vfio/pci/vfio_pci_private.h b/drivers/vfio/pci/vfio_pci_private.h
index 9c6d5d0f3b02db8cfb1acdd418b4a5abb44e12e5..31e7a30196ab463cf7d4214eae7d5b279694ab06 100644
--- a/drivers/vfio/pci/vfio_pci_private.h
+++ b/drivers/vfio/pci/vfio_pci_private.h
@@ -55,7 +55,7 @@ struct vfio_pci_device {
 	bool			bardirty;
 	bool			has_vga;
 	struct pci_saved_state	*pci_saved_state;
-	atomic_t		refcnt;
+	int			refcnt;
 	struct eventfd_ctx	*err_trigger;
 };