diff --git a/arch/s390/include/asm/pgtable.h b/arch/s390/include/asm/pgtable.h
index 0f0de30e3e3fccbf2259e59176ecffab97a6e1e2..ac01463038f1e9b8236a32d50d0a37024b882a19 100644
--- a/arch/s390/include/asm/pgtable.h
+++ b/arch/s390/include/asm/pgtable.h
@@ -646,7 +646,7 @@ static inline pgste_t pgste_update_all(pte_t *ptep, pgste_t pgste)
 	unsigned long address, bits;
 	unsigned char skey;
 
-	if (!pte_present(*ptep))
+	if (pte_val(*ptep) & _PAGE_INVALID)
 		return pgste;
 	address = pte_val(*ptep) & PAGE_MASK;
 	skey = page_get_storage_key(address);
@@ -680,7 +680,7 @@ static inline pgste_t pgste_update_young(pte_t *ptep, pgste_t pgste)
 #ifdef CONFIG_PGSTE
 	int young;
 
-	if (!pte_present(*ptep))
+	if (pte_val(*ptep) & _PAGE_INVALID)
 		return pgste;
 	/* Get referenced bit from storage key */
 	young = page_reset_referenced(pte_val(*ptep) & PAGE_MASK);
@@ -706,7 +706,7 @@ static inline void pgste_set_key(pte_t *ptep, pgste_t pgste, pte_t entry)
 	unsigned long address;
 	unsigned long okey, nkey;
 
-	if (!pte_present(entry))
+	if (pte_val(entry) & _PAGE_INVALID)
 		return;
 	address = pte_val(entry) & PAGE_MASK;
 	okey = nkey = page_get_storage_key(address);
@@ -1098,6 +1098,9 @@ static inline pte_t ptep_modify_prot_start(struct mm_struct *mm,
 	pte = *ptep;
 	if (!mm_exclusive(mm))
 		__ptep_ipte(address, ptep);
+
+	if (mm_has_pgste(mm))
+		pgste = pgste_update_all(&pte, pgste);
 	return pte;
 }
 
@@ -1105,9 +1108,13 @@ static inline void ptep_modify_prot_commit(struct mm_struct *mm,
 					   unsigned long address,
 					   pte_t *ptep, pte_t pte)
 {
+	pgste_t pgste;
+
 	if (mm_has_pgste(mm)) {
+		pgste = *(pgste_t *)(ptep + PTRS_PER_PTE);
+		pgste_set_key(ptep, pgste, pte);
 		pgste_set_pte(ptep, pte);
-		pgste_set_unlock(ptep, *(pgste_t *)(ptep + PTRS_PER_PTE));
+		pgste_set_unlock(ptep, pgste);
 	} else
 		*ptep = pte;
 }