diff --git a/include/linux/skbuff.h b/include/linux/skbuff.h
index 08d1c8e705405e44a9557bd40a3b9bfa4c90f639..9f7e01f2be831744b4b84cbccd13dccf5946660c 100644
--- a/include/linux/skbuff.h
+++ b/include/linux/skbuff.h
@@ -3449,6 +3449,7 @@ int skb_vlan_pop(struct sk_buff *skb);
 int skb_vlan_push(struct sk_buff *skb, __be16 vlan_proto, u16 vlan_tci);
 int skb_mpls_push(struct sk_buff *skb, __be32 mpls_lse, __be16 mpls_proto);
 int skb_mpls_pop(struct sk_buff *skb, __be16 next_proto);
+int skb_mpls_update_lse(struct sk_buff *skb, __be32 mpls_lse);
 struct sk_buff *pskb_extract(struct sk_buff *skb, int off, int to_copy,
 			     gfp_t gfp);
 
diff --git a/net/core/skbuff.c b/net/core/skbuff.c
index 8c00be4d89194222a40ef9d9ce1ec93fae7ace42..93443a01ab39ecf865426cef381d2c11c7967f1f 100644
--- a/net/core/skbuff.c
+++ b/net/core/skbuff.c
@@ -5531,6 +5531,39 @@ int skb_mpls_pop(struct sk_buff *skb, __be16 next_proto)
 }
 EXPORT_SYMBOL_GPL(skb_mpls_pop);
 
+/**
+ * skb_mpls_update_lse() - modify outermost MPLS header and update csum
+ *
+ * @skb: buffer
+ * @mpls_lse: new MPLS label stack entry to update to
+ *
+ * Expects skb->data at mac header.
+ *
+ * Returns 0 on success, -errno otherwise.
+ */
+int skb_mpls_update_lse(struct sk_buff *skb, __be32 mpls_lse)
+{
+	int err;
+
+	if (unlikely(!eth_p_mpls(skb->protocol)))
+		return -EINVAL;
+
+	err = skb_ensure_writable(skb, skb->mac_len + MPLS_HLEN);
+	if (unlikely(err))
+		return err;
+
+	if (skb->ip_summed == CHECKSUM_COMPLETE) {
+		__be32 diff[] = { ~mpls_hdr(skb)->label_stack_entry, mpls_lse };
+
+		skb->csum = csum_partial((char *)diff, sizeof(diff), skb->csum);
+	}
+
+	mpls_hdr(skb)->label_stack_entry = mpls_lse;
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(skb_mpls_update_lse);
+
 /**
  * alloc_skb_with_frags - allocate skb with page frags
  *
diff --git a/net/openvswitch/actions.c b/net/openvswitch/actions.c
index 62715bb8d6110bad9a8483d273ccc9161cdfa8fa..3572e11b6f21ed09ab476d39b1c86bcfb15c5fa2 100644
--- a/net/openvswitch/actions.c
+++ b/net/openvswitch/actions.c
@@ -193,19 +193,12 @@ static int set_mpls(struct sk_buff *skb, struct sw_flow_key *flow_key,
 	__be32 lse;
 	int err;
 
-	err = skb_ensure_writable(skb, skb->mac_len + MPLS_HLEN);
-	if (unlikely(err))
-		return err;
-
 	stack = mpls_hdr(skb);
 	lse = OVS_MASKED(stack->label_stack_entry, *mpls_lse, *mask);
-	if (skb->ip_summed == CHECKSUM_COMPLETE) {
-		__be32 diff[] = { ~(stack->label_stack_entry), lse };
-
-		skb->csum = csum_partial((char *)diff, sizeof(diff), skb->csum);
-	}
+	err = skb_mpls_update_lse(skb, lse);
+	if (err)
+		return err;
 
-	stack->label_stack_entry = lse;
 	flow_key->mpls.top_lse = lse;
 	return 0;
 }