diff --git a/drivers/block/xen-blkback/xenbus.c b/drivers/block/xen-blkback/xenbus.c
index d6a6adfd5159d5c6f2587fa10fce6f9d9a5e1a0b..4c5d99f8781361867344f3f49ddc9f825bcc285a 100644
--- a/drivers/block/xen-blkback/xenbus.c
+++ b/drivers/block/xen-blkback/xenbus.c
@@ -190,6 +190,9 @@ static int xen_blkif_map(struct xen_blkif_ring *ring, grant_ref_t *gref,
 {
 	int err;
 	struct xen_blkif *blkif = ring->blkif;
+	const struct blkif_common_sring *sring_common;
+	RING_IDX rsp_prod, req_prod;
+	unsigned int size;
 
 	/* Already connected through? */
 	if (ring->irq)
@@ -200,46 +203,62 @@ static int xen_blkif_map(struct xen_blkif_ring *ring, grant_ref_t *gref,
 	if (err < 0)
 		return err;
 
+	sring_common = (struct blkif_common_sring *)ring->blk_ring;
+	rsp_prod = READ_ONCE(sring_common->rsp_prod);
+	req_prod = READ_ONCE(sring_common->req_prod);
+
 	switch (blkif->blk_protocol) {
 	case BLKIF_PROTOCOL_NATIVE:
 	{
-		struct blkif_sring *sring;
-		sring = (struct blkif_sring *)ring->blk_ring;
-		BACK_RING_INIT(&ring->blk_rings.native, sring,
-			       XEN_PAGE_SIZE * nr_grefs);
+		struct blkif_sring *sring_native =
+			(struct blkif_sring *)ring->blk_ring;
+
+		BACK_RING_ATTACH(&ring->blk_rings.native, sring_native,
+				 rsp_prod, XEN_PAGE_SIZE * nr_grefs);
+		size = __RING_SIZE(sring_native, XEN_PAGE_SIZE * nr_grefs);
 		break;
 	}
 	case BLKIF_PROTOCOL_X86_32:
 	{
-		struct blkif_x86_32_sring *sring_x86_32;
-		sring_x86_32 = (struct blkif_x86_32_sring *)ring->blk_ring;
-		BACK_RING_INIT(&ring->blk_rings.x86_32, sring_x86_32,
-			       XEN_PAGE_SIZE * nr_grefs);
+		struct blkif_x86_32_sring *sring_x86_32 =
+			(struct blkif_x86_32_sring *)ring->blk_ring;
+
+		BACK_RING_ATTACH(&ring->blk_rings.x86_32, sring_x86_32,
+				 rsp_prod, XEN_PAGE_SIZE * nr_grefs);
+		size = __RING_SIZE(sring_x86_32, XEN_PAGE_SIZE * nr_grefs);
 		break;
 	}
 	case BLKIF_PROTOCOL_X86_64:
 	{
-		struct blkif_x86_64_sring *sring_x86_64;
-		sring_x86_64 = (struct blkif_x86_64_sring *)ring->blk_ring;
-		BACK_RING_INIT(&ring->blk_rings.x86_64, sring_x86_64,
-			       XEN_PAGE_SIZE * nr_grefs);
+		struct blkif_x86_64_sring *sring_x86_64 =
+			(struct blkif_x86_64_sring *)ring->blk_ring;
+
+		BACK_RING_ATTACH(&ring->blk_rings.x86_64, sring_x86_64,
+				 rsp_prod, XEN_PAGE_SIZE * nr_grefs);
+		size = __RING_SIZE(sring_x86_64, XEN_PAGE_SIZE * nr_grefs);
 		break;
 	}
 	default:
 		BUG();
 	}
 
+	err = -EIO;
+	if (req_prod - rsp_prod > size)
+		goto fail;
+
 	err = bind_interdomain_evtchn_to_irqhandler(blkif->domid, evtchn,
 						    xen_blkif_be_int, 0,
 						    "blkif-backend", ring);
-	if (err < 0) {
-		xenbus_unmap_ring_vfree(blkif->be->dev, ring->blk_ring);
-		ring->blk_rings.common.sring = NULL;
-		return err;
-	}
+	if (err < 0)
+		goto fail;
 	ring->irq = err;
 
 	return 0;
+
+fail:
+	xenbus_unmap_ring_vfree(blkif->be->dev, ring->blk_ring);
+	ring->blk_rings.common.sring = NULL;
+	return err;
 }
 
 static int xen_blkif_disconnect(struct xen_blkif *blkif)
@@ -1131,7 +1150,8 @@ static struct xenbus_driver xen_blkbk_driver = {
 	.ids  = xen_blkbk_ids,
 	.probe = xen_blkbk_probe,
 	.remove = xen_blkbk_remove,
-	.otherend_changed = frontend_changed
+	.otherend_changed = frontend_changed,
+	.allow_rebind = true,
 };
 
 int xen_blkif_xenbus_init(void)
diff --git a/drivers/block/xen-blkfront.c b/drivers/block/xen-blkfront.c
index a74d03913822df88989b9b0198da99dea490eef1..c02be06c529950ee89a08bc985cab6531f432342 100644
--- a/drivers/block/xen-blkfront.c
+++ b/drivers/block/xen-blkfront.c
@@ -1113,8 +1113,8 @@ static int xlvbd_alloc_gendisk(blkif_sector_t capacity,
 	if (!VDEV_IS_EXTENDED(info->vdevice)) {
 		err = xen_translate_vdev(info->vdevice, &minor, &offset);
 		if (err)
-			return err;		
- 		nr_parts = PARTS_PER_DISK;
+			return err;
+		nr_parts = PARTS_PER_DISK;
 	} else {
 		minor = BLKIF_MINOR_EXT(info->vdevice);
 		nr_parts = PARTS_PER_EXT_DISK;
diff --git a/drivers/xen/grant-table.c b/drivers/xen/grant-table.c
index 49b381e104efaf64469c75e35668e10efac4ba4d..7b36b51cdb9f978657ec1e6fa00c5d6ee2b13b6b 100644
--- a/drivers/xen/grant-table.c
+++ b/drivers/xen/grant-table.c
@@ -664,7 +664,6 @@ static int grow_gnttab_list(unsigned int more_frames)
 	unsigned int nr_glist_frames, new_nr_glist_frames;
 	unsigned int grefs_per_frame;
 
-	BUG_ON(gnttab_interface == NULL);
 	grefs_per_frame = gnttab_interface->grefs_per_grant_frame;
 
 	new_nr_grant_frames = nr_grant_frames + more_frames;
@@ -1160,7 +1159,6 @@ EXPORT_SYMBOL_GPL(gnttab_unmap_refs_sync);
 
 static unsigned int nr_status_frames(unsigned int nr_grant_frames)
 {
-	BUG_ON(gnttab_interface == NULL);
 	return gnttab_frames(nr_grant_frames, SPP);
 }
 
@@ -1388,7 +1386,6 @@ static int gnttab_expand(unsigned int req_entries)
 	int rc;
 	unsigned int cur, extra;
 
-	BUG_ON(gnttab_interface == NULL);
 	cur = nr_grant_frames;
 	extra = ((req_entries + gnttab_interface->grefs_per_grant_frame - 1) /
 		 gnttab_interface->grefs_per_grant_frame);
@@ -1423,7 +1420,6 @@ int gnttab_init(void)
 	/* Determine the maximum number of frames required for the
 	 * grant reference free list on the current hypervisor.
 	 */
-	BUG_ON(gnttab_interface == NULL);
 	max_nr_glist_frames = (max_nr_grant_frames *
 			       gnttab_interface->grefs_per_grant_frame / RPP);
 
diff --git a/drivers/xen/xenbus/xenbus.h b/drivers/xen/xenbus/xenbus.h
index d75a2385b37c773773beec66b16541b8cb0acfe1..5f5b8a7d5b80b998425dfcd5cc900d5ba4855c2e 100644
--- a/drivers/xen/xenbus/xenbus.h
+++ b/drivers/xen/xenbus/xenbus.h
@@ -116,8 +116,6 @@ int xenbus_probe_devices(struct xen_bus_type *bus);
 
 void xenbus_dev_changed(const char *node, struct xen_bus_type *bus);
 
-void xenbus_dev_shutdown(struct device *_dev);
-
 int xenbus_dev_suspend(struct device *dev);
 int xenbus_dev_resume(struct device *dev);
 int xenbus_dev_cancel(struct device *dev);
diff --git a/drivers/xen/xenbus/xenbus_probe.c b/drivers/xen/xenbus/xenbus_probe.c
index c21be6e9d38a6c91ae75fb4e08a28588c1164df6..378486b79f96aeec8e308f5630f9ffc168ae16a6 100644
--- a/drivers/xen/xenbus/xenbus_probe.c
+++ b/drivers/xen/xenbus/xenbus_probe.c
@@ -255,7 +255,6 @@ fail_put:
 	module_put(drv->driver.owner);
 fail:
 	xenbus_dev_error(dev, err, "xenbus_dev_probe on %s", dev->nodename);
-	xenbus_switch_state(dev, XenbusStateClosed);
 	return err;
 }
 EXPORT_SYMBOL_GPL(xenbus_dev_probe);
@@ -276,34 +275,20 @@ int xenbus_dev_remove(struct device *_dev)
 
 	free_otherend_details(dev);
 
-	xenbus_switch_state(dev, XenbusStateClosed);
+	/*
+	 * If the toolstack has forced the device state to closing then set
+	 * the state to closed now to allow it to be cleaned up.
+	 * Similarly, if the driver does not support re-bind, set the
+	 * closed.
+	 */
+	if (!drv->allow_rebind ||
+	    xenbus_read_driver_state(dev->nodename) == XenbusStateClosing)
+		xenbus_switch_state(dev, XenbusStateClosed);
+
 	return 0;
 }
 EXPORT_SYMBOL_GPL(xenbus_dev_remove);
 
-void xenbus_dev_shutdown(struct device *_dev)
-{
-	struct xenbus_device *dev = to_xenbus_device(_dev);
-	unsigned long timeout = 5*HZ;
-
-	DPRINTK("%s", dev->nodename);
-
-	get_device(&dev->dev);
-	if (dev->state != XenbusStateConnected) {
-		pr_info("%s: %s: %s != Connected, skipping\n",
-			__func__, dev->nodename, xenbus_strstate(dev->state));
-		goto out;
-	}
-	xenbus_switch_state(dev, XenbusStateClosing);
-	timeout = wait_for_completion_timeout(&dev->down, timeout);
-	if (!timeout)
-		pr_info("%s: %s timeout closing device\n",
-			__func__, dev->nodename);
- out:
-	put_device(&dev->dev);
-}
-EXPORT_SYMBOL_GPL(xenbus_dev_shutdown);
-
 int xenbus_register_driver_common(struct xenbus_driver *drv,
 				  struct xen_bus_type *bus,
 				  struct module *owner, const char *mod_name)
diff --git a/drivers/xen/xenbus/xenbus_probe_backend.c b/drivers/xen/xenbus/xenbus_probe_backend.c
index b0bed4faf44cc85a918a4fdb8a3929b846c13ec8..14876faff3b03ed33c7855db863d9d904e94a329 100644
--- a/drivers/xen/xenbus/xenbus_probe_backend.c
+++ b/drivers/xen/xenbus/xenbus_probe_backend.c
@@ -198,7 +198,6 @@ static struct xen_bus_type xenbus_backend = {
 		.uevent		= xenbus_uevent_backend,
 		.probe		= xenbus_dev_probe,
 		.remove		= xenbus_dev_remove,
-		.shutdown	= xenbus_dev_shutdown,
 		.dev_groups	= xenbus_dev_groups,
 	},
 };
diff --git a/drivers/xen/xenbus/xenbus_probe_frontend.c b/drivers/xen/xenbus/xenbus_probe_frontend.c
index a7d90a719cea6727259dad81af433bf92779aed4..8a1650bbe18ffc426acc88b65397b99f1c714d63 100644
--- a/drivers/xen/xenbus/xenbus_probe_frontend.c
+++ b/drivers/xen/xenbus/xenbus_probe_frontend.c
@@ -126,6 +126,28 @@ static int xenbus_frontend_dev_probe(struct device *dev)
 	return xenbus_dev_probe(dev);
 }
 
+static void xenbus_frontend_dev_shutdown(struct device *_dev)
+{
+	struct xenbus_device *dev = to_xenbus_device(_dev);
+	unsigned long timeout = 5*HZ;
+
+	DPRINTK("%s", dev->nodename);
+
+	get_device(&dev->dev);
+	if (dev->state != XenbusStateConnected) {
+		pr_info("%s: %s: %s != Connected, skipping\n",
+			__func__, dev->nodename, xenbus_strstate(dev->state));
+		goto out;
+	}
+	xenbus_switch_state(dev, XenbusStateClosing);
+	timeout = wait_for_completion_timeout(&dev->down, timeout);
+	if (!timeout)
+		pr_info("%s: %s timeout closing device\n",
+			__func__, dev->nodename);
+ out:
+	put_device(&dev->dev);
+}
+
 static const struct dev_pm_ops xenbus_pm_ops = {
 	.suspend	= xenbus_dev_suspend,
 	.resume		= xenbus_frontend_dev_resume,
@@ -146,7 +168,7 @@ static struct xen_bus_type xenbus_frontend = {
 		.uevent		= xenbus_uevent_frontend,
 		.probe		= xenbus_frontend_dev_probe,
 		.remove		= xenbus_dev_remove,
-		.shutdown	= xenbus_dev_shutdown,
+		.shutdown	= xenbus_frontend_dev_shutdown,
 		.dev_groups	= xenbus_dev_groups,
 
 		.pm		= &xenbus_pm_ops,
diff --git a/include/xen/interface/io/ring.h b/include/xen/interface/io/ring.h
index 3f40501fc60b1d9f95a7455ed0d29f97d8bd9571..2af7a1cd665893d9ff509730c73e0e0852a6881a 100644
--- a/include/xen/interface/io/ring.h
+++ b/include/xen/interface/io/ring.h
@@ -125,35 +125,24 @@ struct __name##_back_ring {						\
     memset((_s)->pad, 0, sizeof((_s)->pad));				\
 } while(0)
 
-#define FRONT_RING_INIT(_r, _s, __size) do {				\
-    (_r)->req_prod_pvt = 0;						\
-    (_r)->rsp_cons = 0;							\
+#define FRONT_RING_ATTACH(_r, _s, _i, __size) do {			\
+    (_r)->req_prod_pvt = (_i);						\
+    (_r)->rsp_cons = (_i);						\
     (_r)->nr_ents = __RING_SIZE(_s, __size);				\
     (_r)->sring = (_s);							\
 } while (0)
 
-#define BACK_RING_INIT(_r, _s, __size) do {				\
-    (_r)->rsp_prod_pvt = 0;						\
-    (_r)->req_cons = 0;							\
-    (_r)->nr_ents = __RING_SIZE(_s, __size);				\
-    (_r)->sring = (_s);							\
-} while (0)
+#define FRONT_RING_INIT(_r, _s, __size) FRONT_RING_ATTACH(_r, _s, 0, __size)
 
-/* Initialize to existing shared indexes -- for recovery */
-#define FRONT_RING_ATTACH(_r, _s, __size) do {				\
-    (_r)->sring = (_s);							\
-    (_r)->req_prod_pvt = (_s)->req_prod;				\
-    (_r)->rsp_cons = (_s)->rsp_prod;					\
+#define BACK_RING_ATTACH(_r, _s, _i, __size) do {			\
+    (_r)->rsp_prod_pvt = (_i);						\
+    (_r)->req_cons = (_i);						\
     (_r)->nr_ents = __RING_SIZE(_s, __size);				\
-} while (0)
-
-#define BACK_RING_ATTACH(_r, _s, __size) do {				\
     (_r)->sring = (_s);							\
-    (_r)->rsp_prod_pvt = (_s)->rsp_prod;				\
-    (_r)->req_cons = (_s)->req_prod;					\
-    (_r)->nr_ents = __RING_SIZE(_s, __size);				\
 } while (0)
 
+#define BACK_RING_INIT(_r, _s, __size) BACK_RING_ATTACH(_r, _s, 0, __size)
+
 /* How big is this ring? */
 #define RING_SIZE(_r)							\
     ((_r)->nr_ents)
diff --git a/include/xen/xenbus.h b/include/xen/xenbus.h
index 869c816d5f8c3097b09298a9d086e7f75fff540e..24228a102141e3981911bba20712ca53dcb2b219 100644
--- a/include/xen/xenbus.h
+++ b/include/xen/xenbus.h
@@ -93,6 +93,7 @@ struct xenbus_device_id
 struct xenbus_driver {
 	const char *name;       /* defaults to ids[0].devicetype */
 	const struct xenbus_device_id *ids;
+	bool allow_rebind; /* avoid setting xenstore closed during remove */
 	int (*probe)(struct xenbus_device *dev,
 		     const struct xenbus_device_id *id);
 	void (*otherend_changed)(struct xenbus_device *dev,