diff --git a/arch/arm64/include/asm/proc-fns.h b/arch/arm64/include/asm/proc-fns.h
index 9a8fd84f8fb2b67fe4bc2d845e7f6b584cbf0799..941c375616e2072d3b9fbf6442426c3491f00c01 100644
--- a/arch/arm64/include/asm/proc-fns.h
+++ b/arch/arm64/include/asm/proc-fns.h
@@ -39,7 +39,11 @@ extern u64 cpu_do_resume(phys_addr_t ptr, u64 idmap_ttbr);
 
 #include <asm/memory.h>
 
-#define cpu_switch_mm(pgd,mm) cpu_do_switch_mm(virt_to_phys(pgd),mm)
+#define cpu_switch_mm(pgd,mm)				\
+do {							\
+	BUG_ON(pgd == swapper_pg_dir);			\
+	cpu_do_switch_mm(virt_to_phys(pgd),mm);		\
+} while (0)
 
 #define cpu_get_pgd()					\
 ({							\
diff --git a/arch/arm64/kernel/efi.c b/arch/arm64/kernel/efi.c
index 2b8d70164428010911cff2b66791c5d6c4b1dad4..ab21e0d58278825e5dfc0199cd18b853f07232c1 100644
--- a/arch/arm64/kernel/efi.c
+++ b/arch/arm64/kernel/efi.c
@@ -337,7 +337,11 @@ core_initcall(arm64_dmi_init);
 
 static void efi_set_pgd(struct mm_struct *mm)
 {
-	cpu_switch_mm(mm->pgd, mm);
+	if (mm == &init_mm)
+		cpu_set_reserved_ttbr0();
+	else
+		cpu_switch_mm(mm->pgd, mm);
+
 	flush_tlb_all();
 	if (icache_is_aivivt())
 		__flush_icache_all();