diff --git a/arch/x86/include/asm/mem_encrypt.h b/arch/x86/include/asm/mem_encrypt.h
index 2b024741bce91a60b66474f2f95d4348acb3e914..c9459a4c3c680b754365cd8531b5c25f51b5fb54 100644
--- a/arch/x86/include/asm/mem_encrypt.h
+++ b/arch/x86/include/asm/mem_encrypt.h
@@ -42,6 +42,9 @@ void __init sme_early_init(void);
 void __init sme_encrypt_kernel(void);
 void __init sme_enable(struct boot_params *bp);
 
+int __init early_set_memory_decrypted(unsigned long vaddr, unsigned long size);
+int __init early_set_memory_encrypted(unsigned long vaddr, unsigned long size);
+
 /* Architecture __weak replacement functions */
 void __init mem_encrypt_init(void);
 
@@ -70,6 +73,11 @@ static inline void __init sme_enable(struct boot_params *bp) { }
 static inline bool sme_active(void) { return false; }
 static inline bool sev_active(void) { return false; }
 
+static inline int __init
+early_set_memory_decrypted(unsigned long vaddr, unsigned long size) { return 0; }
+static inline int __init
+early_set_memory_encrypted(unsigned long vaddr, unsigned long size) { return 0; }
+
 #endif	/* CONFIG_AMD_MEM_ENCRYPT */
 
 /*
diff --git a/arch/x86/mm/mem_encrypt.c b/arch/x86/mm/mem_encrypt.c
index d29b7831a0535ba9baf88e1d7fcd2b18bf6166b7..5049b8bad3e781085ba08cccd3eb6f379c322f05 100644
--- a/arch/x86/mm/mem_encrypt.c
+++ b/arch/x86/mm/mem_encrypt.c
@@ -30,6 +30,8 @@
 #include <asm/msr.h>
 #include <asm/cmdline.h>
 
+#include "mm_internal.h"
+
 static char sme_cmdline_arg[] __initdata = "mem_encrypt";
 static char sme_cmdline_on[]  __initdata = "on";
 static char sme_cmdline_off[] __initdata = "off";
@@ -260,6 +262,134 @@ static void sev_free(struct device *dev, size_t size, void *vaddr,
 	swiotlb_free_coherent(dev, size, vaddr, dma_handle);
 }
 
+static void __init __set_clr_pte_enc(pte_t *kpte, int level, bool enc)
+{
+	pgprot_t old_prot, new_prot;
+	unsigned long pfn, pa, size;
+	pte_t new_pte;
+
+	switch (level) {
+	case PG_LEVEL_4K:
+		pfn = pte_pfn(*kpte);
+		old_prot = pte_pgprot(*kpte);
+		break;
+	case PG_LEVEL_2M:
+		pfn = pmd_pfn(*(pmd_t *)kpte);
+		old_prot = pmd_pgprot(*(pmd_t *)kpte);
+		break;
+	case PG_LEVEL_1G:
+		pfn = pud_pfn(*(pud_t *)kpte);
+		old_prot = pud_pgprot(*(pud_t *)kpte);
+		break;
+	default:
+		return;
+	}
+
+	new_prot = old_prot;
+	if (enc)
+		pgprot_val(new_prot) |= _PAGE_ENC;
+	else
+		pgprot_val(new_prot) &= ~_PAGE_ENC;
+
+	/* If prot is same then do nothing. */
+	if (pgprot_val(old_prot) == pgprot_val(new_prot))
+		return;
+
+	pa = pfn << page_level_shift(level);
+	size = page_level_size(level);
+
+	/*
+	 * We are going to perform in-place en-/decryption and change the
+	 * physical page attribute from C=1 to C=0 or vice versa. Flush the
+	 * caches to ensure that data gets accessed with the correct C-bit.
+	 */
+	clflush_cache_range(__va(pa), size);
+
+	/* Encrypt/decrypt the contents in-place */
+	if (enc)
+		sme_early_encrypt(pa, size);
+	else
+		sme_early_decrypt(pa, size);
+
+	/* Change the page encryption mask. */
+	new_pte = pfn_pte(pfn, new_prot);
+	set_pte_atomic(kpte, new_pte);
+}
+
+static int __init early_set_memory_enc_dec(unsigned long vaddr,
+					   unsigned long size, bool enc)
+{
+	unsigned long vaddr_end, vaddr_next;
+	unsigned long psize, pmask;
+	int split_page_size_mask;
+	int level, ret;
+	pte_t *kpte;
+
+	vaddr_next = vaddr;
+	vaddr_end = vaddr + size;
+
+	for (; vaddr < vaddr_end; vaddr = vaddr_next) {
+		kpte = lookup_address(vaddr, &level);
+		if (!kpte || pte_none(*kpte)) {
+			ret = 1;
+			goto out;
+		}
+
+		if (level == PG_LEVEL_4K) {
+			__set_clr_pte_enc(kpte, level, enc);
+			vaddr_next = (vaddr & PAGE_MASK) + PAGE_SIZE;
+			continue;
+		}
+
+		psize = page_level_size(level);
+		pmask = page_level_mask(level);
+
+		/*
+		 * Check whether we can change the large page in one go.
+		 * We request a split when the address is not aligned and
+		 * the number of pages to set/clear encryption bit is smaller
+		 * than the number of pages in the large page.
+		 */
+		if (vaddr == (vaddr & pmask) &&
+		    ((vaddr_end - vaddr) >= psize)) {
+			__set_clr_pte_enc(kpte, level, enc);
+			vaddr_next = (vaddr & pmask) + psize;
+			continue;
+		}
+
+		/*
+		 * The virtual address is part of a larger page, create the next
+		 * level page table mapping (4K or 2M). If it is part of a 2M
+		 * page then we request a split of the large page into 4K
+		 * chunks. A 1GB large page is split into 2M pages, resp.
+		 */
+		if (level == PG_LEVEL_2M)
+			split_page_size_mask = 0;
+		else
+			split_page_size_mask = 1 << PG_LEVEL_2M;
+
+		kernel_physical_mapping_init(__pa(vaddr & pmask),
+					     __pa((vaddr_end & pmask) + psize),
+					     split_page_size_mask);
+	}
+
+	ret = 0;
+
+out:
+	__flush_tlb_all();
+	return ret;
+}
+
+int __init early_set_memory_decrypted(unsigned long vaddr, unsigned long size)
+{
+	return early_set_memory_enc_dec(vaddr, size, false);
+}
+
+int __init early_set_memory_encrypted(unsigned long vaddr, unsigned long size)
+{
+	return early_set_memory_enc_dec(vaddr, size, true);
+}
+
 /*
  * SME and SEV are very similar but they are not the same, so there are
  * times that the kernel will need to distinguish between SME and SEV. The