diff --git a/arch/x86/mm/hugetlbpage.c b/arch/x86/mm/hugetlbpage.c
index f6679a7fb8ca492482f18421d78f438cc9cfba2b..b91e48512425f6f210e9406cbfc666395bad6ad1 100644
--- a/arch/x86/mm/hugetlbpage.c
+++ b/arch/x86/mm/hugetlbpage.c
@@ -56,9 +56,16 @@ static int vma_shareable(struct vm_area_struct *vma, unsigned long addr)
 }
 
 /*
- * search for a shareable pmd page for hugetlb.
+ * Search for a shareable pmd page for hugetlb. In any case calls pmd_alloc()
+ * and returns the corresponding pte. While this is not necessary for the
+ * !shared pmd case because we can allocate the pmd later as well, it makes the
+ * code much cleaner. pmd allocation is essential for the shared case because
+ * pud has to be populated inside the same i_mmap_mutex section - otherwise
+ * racing tasks could either miss the sharing (see huge_pte_offset) or select a
+ * bad pmd for sharing.
  */
-static void huge_pmd_share(struct mm_struct *mm, unsigned long addr, pud_t *pud)
+static pte_t *
+huge_pmd_share(struct mm_struct *mm, unsigned long addr, pud_t *pud)
 {
 	struct vm_area_struct *vma = find_vma(mm, addr);
 	struct address_space *mapping = vma->vm_file->f_mapping;
@@ -68,9 +75,10 @@ static void huge_pmd_share(struct mm_struct *mm, unsigned long addr, pud_t *pud)
 	struct vm_area_struct *svma;
 	unsigned long saddr;
 	pte_t *spte = NULL;
+	pte_t *pte;
 
 	if (!vma_shareable(vma, addr))
-		return;
+		return (pte_t *)pmd_alloc(mm, pud, addr);
 
 	mutex_lock(&mapping->i_mmap_mutex);
 	vma_prio_tree_foreach(svma, &iter, &mapping->i_mmap, idx, idx) {
@@ -97,7 +105,9 @@ static void huge_pmd_share(struct mm_struct *mm, unsigned long addr, pud_t *pud)
 		put_page(virt_to_page(spte));
 	spin_unlock(&mm->page_table_lock);
 out:
+	pte = (pte_t *)pmd_alloc(mm, pud, addr);
 	mutex_unlock(&mapping->i_mmap_mutex);
+	return pte;
 }
 
 /*
@@ -142,8 +152,9 @@ pte_t *huge_pte_alloc(struct mm_struct *mm,
 		} else {
 			BUG_ON(sz != PMD_SIZE);
 			if (pud_none(*pud))
-				huge_pmd_share(mm, addr, pud);
-			pte = (pte_t *) pmd_alloc(mm, pud, addr);
+				pte = huge_pmd_share(mm, addr, pud);
+			else
+				pte = (pte_t *)pmd_alloc(mm, pud, addr);
 		}
 	}
 	BUG_ON(pte && !pte_none(*pte) && !pte_huge(*pte));