diff --git a/mm/memremap.c b/mm/memremap.c
index 73a206d0f64542e0f28a52da3b52ce3c7b4635ae..16b2fb482da11245092e732c858264ada92d2364 100644
--- a/mm/memremap.c
+++ b/mm/memremap.c
@@ -41,28 +41,24 @@ EXPORT_SYMBOL_GPL(memremap_compat_align);
 DEFINE_STATIC_KEY_FALSE(devmap_managed_key);
 EXPORT_SYMBOL(devmap_managed_key);
 
-static void devmap_managed_enable_put(void)
+static void devmap_managed_enable_put(struct dev_pagemap *pgmap)
 {
-	static_branch_dec(&devmap_managed_key);
+	if (pgmap->type == MEMORY_DEVICE_PRIVATE ||
+	    pgmap->type == MEMORY_DEVICE_FS_DAX)
+		static_branch_dec(&devmap_managed_key);
 }
 
-static int devmap_managed_enable_get(struct dev_pagemap *pgmap)
+static void devmap_managed_enable_get(struct dev_pagemap *pgmap)
 {
-	if (pgmap->type == MEMORY_DEVICE_PRIVATE &&
-	    (!pgmap->ops || !pgmap->ops->page_free)) {
-		WARN(1, "Missing page_free method\n");
-		return -EINVAL;
-	}
-
-	static_branch_inc(&devmap_managed_key);
-	return 0;
+	if (pgmap->type == MEMORY_DEVICE_PRIVATE ||
+	    pgmap->type == MEMORY_DEVICE_FS_DAX)
+		static_branch_inc(&devmap_managed_key);
 }
 #else
-static int devmap_managed_enable_get(struct dev_pagemap *pgmap)
+static void devmap_managed_enable_get(struct dev_pagemap *pgmap)
 {
-	return -EINVAL;
 }
-static void devmap_managed_enable_put(void)
+static void devmap_managed_enable_put(struct dev_pagemap *pgmap)
 {
 }
 #endif /* CONFIG_DEV_PAGEMAP_OPS */
@@ -169,7 +165,7 @@ void memunmap_pages(struct dev_pagemap *pgmap)
 		pageunmap_range(pgmap, i);
 
 	WARN_ONCE(pgmap->altmap.alloc, "failed to free all reserved pages\n");
-	devmap_managed_enable_put();
+	devmap_managed_enable_put(pgmap);
 }
 EXPORT_SYMBOL_GPL(memunmap_pages);
 
@@ -307,7 +303,6 @@ void *memremap_pages(struct dev_pagemap *pgmap, int nid)
 		.pgprot = PAGE_KERNEL,
 	};
 	const int nr_range = pgmap->nr_range;
-	bool need_devmap_managed = true;
 	int error, i;
 
 	if (WARN_ONCE(!nr_range, "nr_range must be specified\n"))
@@ -323,6 +318,10 @@ void *memremap_pages(struct dev_pagemap *pgmap, int nid)
 			WARN(1, "Missing migrate_to_ram method\n");
 			return ERR_PTR(-EINVAL);
 		}
+		if (!pgmap->ops->page_free) {
+			WARN(1, "Missing page_free method\n");
+			return ERR_PTR(-EINVAL);
+		}
 		if (!pgmap->owner) {
 			WARN(1, "Missing owner\n");
 			return ERR_PTR(-EINVAL);
@@ -336,11 +335,9 @@ void *memremap_pages(struct dev_pagemap *pgmap, int nid)
 		}
 		break;
 	case MEMORY_DEVICE_GENERIC:
-		need_devmap_managed = false;
 		break;
 	case MEMORY_DEVICE_PCI_P2PDMA:
 		params.pgprot = pgprot_noncached(params.pgprot);
-		need_devmap_managed = false;
 		break;
 	default:
 		WARN(1, "Invalid pgmap type %d\n", pgmap->type);
@@ -364,11 +361,7 @@ void *memremap_pages(struct dev_pagemap *pgmap, int nid)
 		}
 	}
 
-	if (need_devmap_managed) {
-		error = devmap_managed_enable_get(pgmap);
-		if (error)
-			return ERR_PTR(error);
-	}
+	devmap_managed_enable_get(pgmap);
 
 	/*
 	 * Clear the pgmap nr_range as it will be incremented for each