diff --git a/include/linux/mm.h b/include/linux/mm.h
index 91c08f6f0dc96dbb7474d3349f62b5d3f723fe80..80001de019ba33d86b90b9922b39722270cb0449 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -905,6 +905,27 @@ static inline void set_page_links(struct page *page, enum zone_type zone,
 #endif
 }
 
+#ifdef CONFIG_MEMCG
+static inline struct mem_cgroup *page_memcg(struct page *page)
+{
+	return page->mem_cgroup;
+}
+
+static inline void set_page_memcg(struct page *page, struct mem_cgroup *memcg)
+{
+	page->mem_cgroup = memcg;
+}
+#else
+static inline struct mem_cgroup *page_memcg(struct page *page)
+{
+	return NULL;
+}
+
+static inline void set_page_memcg(struct page *page, struct mem_cgroup *memcg)
+{
+}
+#endif
+
 /*
  * Some inline functions in vmstat.h depend on page_zone()
  */
diff --git a/mm/migrate.c b/mm/migrate.c
index 7452a00bbb50c134b529c1d024dfc53fcfca093b..842ecd7aaf7fa6ac1371f6137dc155c91851505c 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -740,6 +740,15 @@ static int move_to_new_page(struct page *newpage, struct page *page,
 	if (PageSwapBacked(page))
 		SetPageSwapBacked(newpage);
 
+	/*
+	 * Indirectly called below, migrate_page_copy() copies PG_dirty and thus
+	 * needs newpage's memcg set to transfer memcg dirty page accounting.
+	 * So perform memcg migration in two steps:
+	 * 1. set newpage->mem_cgroup (here)
+	 * 2. clear page->mem_cgroup (below)
+	 */
+	set_page_memcg(newpage, page_memcg(page));
+
 	mapping = page_mapping(page);
 	if (!mapping)
 		rc = migrate_page(mapping, newpage, page, mode);
@@ -756,9 +765,10 @@ static int move_to_new_page(struct page *newpage, struct page *page,
 		rc = fallback_migrate_page(mapping, newpage, page, mode);
 
 	if (rc != MIGRATEPAGE_SUCCESS) {
+		set_page_memcg(newpage, NULL);
 		newpage->mapping = NULL;
 	} else {
-		mem_cgroup_migrate(page, newpage, false);
+		set_page_memcg(page, NULL);
 		if (page_was_mapped)
 			remove_migration_ptes(page, newpage);
 		page->mapping = NULL;