diff --git a/drivers/net/ethernet/mellanox/mlx5/core/esw/sample.c b/drivers/net/ethernet/mellanox/mlx5/core/esw/sample.c
index 37e33670bb2484706ce7827bbb47fb85984c910a..79a0e49b279960f8dadca9f821703b579eb516a2 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/esw/sample.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/esw/sample.c
@@ -1,6 +1,8 @@
 // SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB
 /* Copyright (c) 2021 Mellanox Technologies. */
 
+#include <linux/skbuff.h>
+#include <net/psample.h>
 #include "esw/sample.h"
 #include "eswitch.h"
 #include "en_tc.h"
@@ -12,6 +14,8 @@ struct mlx5_esw_psample {
 	struct mlx5_flow_handle *termtbl_rule;
 	DECLARE_HASHTABLE(hashtbl, 8);
 	struct mutex ht_lock; /* protect hashtbl */
+	DECLARE_HASHTABLE(restore_hashtbl, 8);
+	struct mutex restore_lock; /* protect restore_hashtbl */
 };
 
 struct mlx5_sampler {
@@ -25,6 +29,15 @@ struct mlx5_sampler {
 
 struct mlx5_sample_flow {
 	struct mlx5_sampler *sampler;
+	struct mlx5_sample_restore *restore;
+};
+
+struct mlx5_sample_restore {
+	struct hlist_node hlist;
+	struct mlx5_modify_hdr *modify_hdr;
+	struct mlx5_flow_handle *rule;
+	u32 obj_id;
+	int count;
 };
 
 static int
@@ -192,6 +205,99 @@ sampler_put(struct mlx5_esw_psample *esw_psample, struct mlx5_sampler *sampler)
 	mutex_unlock(&esw_psample->ht_lock);
 }
 
+static struct mlx5_modify_hdr *
+sample_metadata_rule_get(struct mlx5_core_dev *mdev, u32 obj_id)
+{
+	struct mlx5e_tc_mod_hdr_acts mod_acts = {};
+	struct mlx5_modify_hdr *modify_hdr;
+	int err;
+
+	err = mlx5e_tc_match_to_reg_set(mdev, &mod_acts, MLX5_FLOW_NAMESPACE_FDB,
+					CHAIN_TO_REG, obj_id);
+	if (err)
+		goto err_set_regc0;
+
+	modify_hdr = mlx5_modify_header_alloc(mdev, MLX5_FLOW_NAMESPACE_FDB,
+					      mod_acts.num_actions,
+					      mod_acts.actions);
+	if (IS_ERR(modify_hdr)) {
+		err = PTR_ERR(modify_hdr);
+		goto err_modify_hdr;
+	}
+
+	dealloc_mod_hdr_actions(&mod_acts);
+	return modify_hdr;
+
+err_modify_hdr:
+	dealloc_mod_hdr_actions(&mod_acts);
+err_set_regc0:
+	return ERR_PTR(err);
+}
+
+static struct mlx5_sample_restore *
+sample_restore_get(struct mlx5_esw_psample *esw_psample, u32 obj_id)
+{
+	struct mlx5_core_dev *mdev = esw_psample->priv->mdev;
+	struct mlx5_eswitch *esw = mdev->priv.eswitch;
+	struct mlx5_sample_restore *restore;
+	struct mlx5_modify_hdr *modify_hdr;
+	int err;
+
+	mutex_lock(&esw_psample->restore_lock);
+	hash_for_each_possible(esw_psample->restore_hashtbl, restore, hlist, obj_id)
+		if (restore->obj_id == obj_id)
+			goto add_ref;
+
+	restore = kzalloc(sizeof(*restore), GFP_KERNEL);
+	if (!restore) {
+		err = -ENOMEM;
+		goto err_alloc;
+	}
+	restore->obj_id = obj_id;
+
+	modify_hdr = sample_metadata_rule_get(mdev, obj_id);
+	if (IS_ERR(modify_hdr)) {
+		err = PTR_ERR(modify_hdr);
+		goto err_modify_hdr;
+	}
+	restore->modify_hdr = modify_hdr;
+
+	restore->rule = esw_add_restore_rule(esw, obj_id);
+	if (IS_ERR(restore->rule)) {
+		err = PTR_ERR(restore->rule);
+		goto err_restore;
+	}
+
+	hash_add(esw_psample->restore_hashtbl, &restore->hlist, obj_id);
+add_ref:
+	restore->count++;
+	mutex_unlock(&esw_psample->restore_lock);
+	return restore;
+
+err_restore:
+	mlx5_modify_header_dealloc(mdev, restore->modify_hdr);
+err_modify_hdr:
+	kfree(restore);
+err_alloc:
+	mutex_unlock(&esw_psample->restore_lock);
+	return ERR_PTR(err);
+}
+
+static void
+sample_restore_put(struct mlx5_esw_psample *esw_psample, struct mlx5_sample_restore *restore)
+{
+	mutex_lock(&esw_psample->restore_lock);
+	if (--restore->count == 0)
+		hash_del(&restore->hlist);
+	mutex_unlock(&esw_psample->restore_lock);
+
+	if (!restore->count) {
+		mlx5_del_flow_rules(restore->rule);
+		mlx5_modify_header_dealloc(esw_psample->priv->mdev, restore->modify_hdr);
+		kfree(restore);
+	}
+}
+
 struct mlx5_esw_psample *
 mlx5_esw_sample_init(struct mlx5e_priv *priv)
 {
@@ -207,6 +313,7 @@ mlx5_esw_sample_init(struct mlx5e_priv *priv)
 		goto err_termtbl;
 
 	mutex_init(&esw_psample->ht_lock);
+	mutex_init(&esw_psample->restore_lock);
 
 	return esw_psample;
 
@@ -221,6 +328,7 @@ mlx5_esw_sample_cleanup(struct mlx5_esw_psample *esw_psample)
 	if (IS_ERR_OR_NULL(esw_psample))
 		return;
 
+	mutex_destroy(&esw_psample->restore_lock);
 	mutex_destroy(&esw_psample->ht_lock);
 	sampler_termtbl_destroy(esw_psample);
 	kfree(esw_psample);