diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
index 6cc5036ac83cface8941f2e58817c98a4eb80084..f6325f1a89e878ed69591b3ee8f96b0bec852a60 100644
--- a/drivers/pci/controller/pci-hyperv.c
+++ b/drivers/pci/controller/pci-hyperv.c
@@ -1073,6 +1073,7 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 	struct pci_bus *pbus;
 	struct pci_dev *pdev;
 	struct cpumask *dest;
+	unsigned long flags;
 	struct compose_comp_ctxt comp;
 	struct tran_int_desc *int_desc;
 	struct {
@@ -1164,14 +1165,15 @@ static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 		 * the channel callback directly when channel->target_cpu is
 		 * the current CPU. When the higher level interrupt code
 		 * calls us with interrupt enabled, let's add the
-		 * local_bh_disable()/enable() to avoid race.
+		 * local_irq_save()/restore() to avoid race:
+		 * hv_pci_onchannelcallback() can also run in tasklet.
 		 */
-		local_bh_disable();
+		local_irq_save(flags);
 
 		if (hbus->hdev->channel->target_cpu == smp_processor_id())
 			hv_pci_onchannelcallback(hbus);
 
-		local_bh_enable();
+		local_irq_restore(flags);
 
 		if (hpdev->state == hv_pcichild_ejecting) {
 			dev_err_once(&hbus->hdev->device,