diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index e9ba01336d4e72e4e56778be4b0f7e652683054d..f1542d6558c2c64fec4804a4bf07e2e1817a7280 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -1369,8 +1369,7 @@ void memcg_kmem_put_cache(struct kmem_cache *cachep);
 #ifdef CONFIG_MEMCG_KMEM
 int __memcg_kmem_charge(struct page *page, gfp_t gfp, int order);
 void __memcg_kmem_uncharge(struct page *page, int order);
-int __memcg_kmem_charge_memcg(struct page *page, gfp_t gfp, int order,
-			      struct mem_cgroup *memcg);
+int __memcg_kmem_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp, int order);
 void __memcg_kmem_uncharge_memcg(struct mem_cgroup *memcg,
 				 unsigned int nr_pages);
 
@@ -1407,11 +1406,11 @@ static inline void memcg_kmem_uncharge(struct page *page, int order)
 		__memcg_kmem_uncharge(page, order);
 }
 
-static inline int memcg_kmem_charge_memcg(struct page *page, gfp_t gfp,
-					  int order, struct mem_cgroup *memcg)
+static inline int memcg_kmem_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp,
+					  int order)
 {
 	if (memcg_kmem_enabled())
-		return __memcg_kmem_charge_memcg(page, gfp, order, memcg);
+		return __memcg_kmem_charge_memcg(memcg, gfp, order);
 	return 0;
 }
 
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index c1aa24a57e55713d327353d7b95ae8f542469bcb..896b6ebef6a26094621c422f2cc44ac1ef3b4359 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2882,15 +2882,13 @@ void memcg_kmem_put_cache(struct kmem_cache *cachep)
 
 /**
  * __memcg_kmem_charge_memcg: charge a kmem page
- * @page: page to charge
+ * @memcg: memory cgroup to charge
  * @gfp: reclaim mode
  * @order: allocation order
- * @memcg: memory cgroup to charge
  *
  * Returns 0 on success, an error code on failure.
  */
-int __memcg_kmem_charge_memcg(struct page *page, gfp_t gfp, int order,
-			    struct mem_cgroup *memcg)
+int __memcg_kmem_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp, int order)
 {
 	unsigned int nr_pages = 1 << order;
 	struct page_counter *counter;
@@ -2936,7 +2934,7 @@ int __memcg_kmem_charge(struct page *page, gfp_t gfp, int order)
 
 	memcg = get_mem_cgroup_from_current();
 	if (!mem_cgroup_is_root(memcg)) {
-		ret = __memcg_kmem_charge_memcg(page, gfp, order, memcg);
+		ret = __memcg_kmem_charge_memcg(memcg, gfp, order);
 		if (!ret) {
 			page->mem_cgroup = memcg;
 			__SetPageKmemcg(page);
diff --git a/mm/slab.h b/mm/slab.h
index 7e94700aa78c6e4fd9f26dcfa559688a7920fc3f..c4c93e991250d5fc1eed9f8655a7b577596c13f7 100644
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -365,7 +365,7 @@ static __always_inline int memcg_charge_slab(struct page *page,
 		return 0;
 	}
 
-	ret = memcg_kmem_charge_memcg(page, gfp, order, memcg);
+	ret = memcg_kmem_charge_memcg(memcg, gfp, order);
 	if (ret)
 		goto out;