diff --git a/arch/sparc64/mm/init.c b/arch/sparc64/mm/init.c
index 9bbd0bf64af0f44fabfc1f2857256f71b7556a5f..a63939347b3d4ed4d90da6603770a1ad59585b21 100644
--- a/arch/sparc64/mm/init.c
+++ b/arch/sparc64/mm/init.c
@@ -639,9 +639,10 @@ void get_new_mmu_context(struct mm_struct *mm)
 {
 	unsigned long ctx, new_ctx;
 	unsigned long orig_pgsz_bits;
+	unsigned long flags;
 	int new_version;
 
-	spin_lock(&ctx_alloc_lock);
+	spin_lock_irqsave(&ctx_alloc_lock, flags);
 	orig_pgsz_bits = (mm->context.sparc64_ctx_val & CTX_PGSZ_MASK);
 	ctx = (tlb_context_cache + 1) & CTX_NR_MASK;
 	new_ctx = find_next_zero_bit(mmu_context_bmap, 1 << CTX_NR_BITS, ctx);
@@ -677,7 +678,7 @@ void get_new_mmu_context(struct mm_struct *mm)
 out:
 	tlb_context_cache = new_ctx;
 	mm->context.sparc64_ctx_val = new_ctx | orig_pgsz_bits;
-	spin_unlock(&ctx_alloc_lock);
+	spin_unlock_irqrestore(&ctx_alloc_lock, flags);
 
 	if (unlikely(new_version))
 		smp_new_mmu_context_version();
diff --git a/arch/sparc64/mm/tsb.c b/arch/sparc64/mm/tsb.c
index 534ac2819892f96daebef6fd4d46b86b7ddc8c4f..f36799b7152ce47470c93125127bb3a000281492 100644
--- a/arch/sparc64/mm/tsb.c
+++ b/arch/sparc64/mm/tsb.c
@@ -354,6 +354,7 @@ void tsb_grow(struct mm_struct *mm, unsigned long rss, gfp_t gfp_flags)
 
 int init_new_context(struct task_struct *tsk, struct mm_struct *mm)
 {
+	spin_lock_init(&mm->context.lock);
 
 	mm->context.sparc64_ctx_val = 0UL;
 
diff --git a/include/asm-sparc64/mmu_context.h b/include/asm-sparc64/mmu_context.h
index 4be40c58e3c15dc9592cc69302c1b07e5acf9637..ca36ea96f64bc8694baed2ae40feb1557bff6545 100644
--- a/include/asm-sparc64/mmu_context.h
+++ b/include/asm-sparc64/mmu_context.h
@@ -67,14 +67,14 @@ extern void __flush_tlb_mm(unsigned long, unsigned long);
 /* Switch the current MM context.  Interrupts are disabled.  */
 static inline void switch_mm(struct mm_struct *old_mm, struct mm_struct *mm, struct task_struct *tsk)
 {
-	unsigned long ctx_valid;
+	unsigned long ctx_valid, flags;
 	int cpu;
 
-	spin_lock(&mm->context.lock);
+	spin_lock_irqsave(&mm->context.lock, flags);
 	ctx_valid = CTX_VALID(mm->context);
 	if (!ctx_valid)
 		get_new_mmu_context(mm);
-	spin_unlock(&mm->context.lock);
+	spin_unlock_irqrestore(&mm->context.lock, flags);
 
 	if (!ctx_valid || (old_mm != mm)) {
 		load_secondary_context(mm);