diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c
index d964e60c730eafb79a8a9437325e14c8d6288e40..eacaaf803914cdc77372d9e661d8d1bb63096be3 100644
--- a/net/sched/cls_flower.c
+++ b/net/sched/cls_flower.c
@@ -61,16 +61,18 @@ struct fl_flow_mask_range {
 struct fl_flow_mask {
 	struct fl_flow_key key;
 	struct fl_flow_mask_range range;
-	struct rcu_head	rcu;
+	struct rhash_head ht_node;
+	struct rhashtable ht;
+	struct rhashtable_params filter_ht_params;
+	struct flow_dissector dissector;
+	struct list_head filters;
+	struct rcu_head rcu;
+	struct list_head list;
 };
 
 struct cls_fl_head {
 	struct rhashtable ht;
-	struct fl_flow_mask mask;
-	struct flow_dissector dissector;
-	bool mask_assigned;
-	struct list_head filters;
-	struct rhashtable_params ht_params;
+	struct list_head masks;
 	union {
 		struct work_struct work;
 		struct rcu_head	rcu;
@@ -79,6 +81,7 @@ struct cls_fl_head {
 };
 
 struct cls_fl_filter {
+	struct fl_flow_mask *mask;
 	struct rhash_head ht_node;
 	struct fl_flow_key mkey;
 	struct tcf_exts exts;
@@ -94,6 +97,13 @@ struct cls_fl_filter {
 	struct net_device *hw_dev;
 };
 
+static const struct rhashtable_params mask_ht_params = {
+	.key_offset = offsetof(struct fl_flow_mask, key),
+	.key_len = sizeof(struct fl_flow_key),
+	.head_offset = offsetof(struct fl_flow_mask, ht_node),
+	.automatic_shrinking = true,
+};
+
 static unsigned short int fl_mask_range(const struct fl_flow_mask *mask)
 {
 	return mask->range.end - mask->range.start;
@@ -103,13 +113,19 @@ static void fl_mask_update_range(struct fl_flow_mask *mask)
 {
 	const u8 *bytes = (const u8 *) &mask->key;
 	size_t size = sizeof(mask->key);
-	size_t i, first = 0, last = size - 1;
+	size_t i, first = 0, last;
 
-	for (i = 0; i < sizeof(mask->key); i++) {
+	for (i = 0; i < size; i++) {
+		if (bytes[i]) {
+			first = i;
+			break;
+		}
+	}
+	last = first;
+	for (i = size - 1; i != first; i--) {
 		if (bytes[i]) {
-			if (!first && i)
-				first = i;
 			last = i;
+			break;
 		}
 	}
 	mask->range.start = rounddown(first, sizeof(long));
@@ -140,12 +156,11 @@ static void fl_clear_masked_range(struct fl_flow_key *key,
 	memset(fl_key_get_start(key, mask), 0, fl_mask_range(mask));
 }
 
-static struct cls_fl_filter *fl_lookup(struct cls_fl_head *head,
+static struct cls_fl_filter *fl_lookup(struct fl_flow_mask *mask,
 				       struct fl_flow_key *mkey)
 {
-	return rhashtable_lookup_fast(&head->ht,
-				      fl_key_get_start(mkey, &head->mask),
-				      head->ht_params);
+	return rhashtable_lookup_fast(&mask->ht, fl_key_get_start(mkey, mask),
+				      mask->filter_ht_params);
 }
 
 static int fl_classify(struct sk_buff *skb, const struct tcf_proto *tp,
@@ -153,28 +168,28 @@ static int fl_classify(struct sk_buff *skb, const struct tcf_proto *tp,
 {
 	struct cls_fl_head *head = rcu_dereference_bh(tp->root);
 	struct cls_fl_filter *f;
+	struct fl_flow_mask *mask;
 	struct fl_flow_key skb_key;
 	struct fl_flow_key skb_mkey;
 
-	if (!atomic_read(&head->ht.nelems))
-		return -1;
-
-	fl_clear_masked_range(&skb_key, &head->mask);
+	list_for_each_entry_rcu(mask, &head->masks, list) {
+		fl_clear_masked_range(&skb_key, mask);
 
-	skb_key.indev_ifindex = skb->skb_iif;
-	/* skb_flow_dissect() does not set n_proto in case an unknown protocol,
-	 * so do it rather here.
-	 */
-	skb_key.basic.n_proto = skb->protocol;
-	skb_flow_dissect_tunnel_info(skb, &head->dissector, &skb_key);
-	skb_flow_dissect(skb, &head->dissector, &skb_key, 0);
+		skb_key.indev_ifindex = skb->skb_iif;
+		/* skb_flow_dissect() does not set n_proto in case an unknown
+		 * protocol, so do it rather here.
+		 */
+		skb_key.basic.n_proto = skb->protocol;
+		skb_flow_dissect_tunnel_info(skb, &mask->dissector, &skb_key);
+		skb_flow_dissect(skb, &mask->dissector, &skb_key, 0);
 
-	fl_set_masked_key(&skb_mkey, &skb_key, &head->mask);
+		fl_set_masked_key(&skb_mkey, &skb_key, mask);
 
-	f = fl_lookup(head, &skb_mkey);
-	if (f && !tc_skip_sw(f->flags)) {
-		*res = f->res;
-		return tcf_exts_exec(skb, &f->exts, res);
+		f = fl_lookup(mask, &skb_mkey);
+		if (f && !tc_skip_sw(f->flags)) {
+			*res = f->res;
+			return tcf_exts_exec(skb, &f->exts, res);
+		}
 	}
 	return -1;
 }
@@ -187,11 +202,28 @@ static int fl_init(struct tcf_proto *tp)
 	if (!head)
 		return -ENOBUFS;
 
-	INIT_LIST_HEAD_RCU(&head->filters);
+	INIT_LIST_HEAD_RCU(&head->masks);
 	rcu_assign_pointer(tp->root, head);
 	idr_init(&head->handle_idr);
 
-	return 0;
+	return rhashtable_init(&head->ht, &mask_ht_params);
+}
+
+static bool fl_mask_put(struct cls_fl_head *head, struct fl_flow_mask *mask,
+			bool async)
+{
+	if (!list_empty(&mask->filters))
+		return false;
+
+	rhashtable_remove_fast(&head->ht, &mask->ht_node, mask_ht_params);
+	rhashtable_destroy(&mask->ht);
+	list_del_rcu(&mask->list);
+	if (async)
+		kfree_rcu(mask, rcu);
+	else
+		kfree(mask);
+
+	return true;
 }
 
 static void __fl_destroy_filter(struct cls_fl_filter *f)
@@ -234,8 +266,6 @@ static void fl_hw_destroy_filter(struct tcf_proto *tp, struct cls_fl_filter *f,
 }
 
 static int fl_hw_replace_filter(struct tcf_proto *tp,
-				struct flow_dissector *dissector,
-				struct fl_flow_key *mask,
 				struct cls_fl_filter *f,
 				struct netlink_ext_ack *extack)
 {
@@ -247,8 +277,8 @@ static int fl_hw_replace_filter(struct tcf_proto *tp,
 	tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack);
 	cls_flower.command = TC_CLSFLOWER_REPLACE;
 	cls_flower.cookie = (unsigned long) f;
-	cls_flower.dissector = dissector;
-	cls_flower.mask = mask;
+	cls_flower.dissector = &f->mask->dissector;
+	cls_flower.mask = &f->mask->key;
 	cls_flower.key = &f->mkey;
 	cls_flower.exts = &f->exts;
 	cls_flower.classid = f->res.classid;
@@ -283,28 +313,31 @@ static void fl_hw_update_stats(struct tcf_proto *tp, struct cls_fl_filter *f)
 			 &cls_flower, false);
 }
 
-static void __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f,
+static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f,
 			struct netlink_ext_ack *extack)
 {
 	struct cls_fl_head *head = rtnl_dereference(tp->root);
+	bool async = tcf_exts_get_net(&f->exts);
+	bool last;
 
 	idr_remove(&head->handle_idr, f->handle);
 	list_del_rcu(&f->list);
+	last = fl_mask_put(head, f->mask, async);
 	if (!tc_skip_hw(f->flags))
 		fl_hw_destroy_filter(tp, f, extack);
 	tcf_unbind_filter(tp, &f->res);
-	if (tcf_exts_get_net(&f->exts))
+	if (async)
 		call_rcu(&f->rcu, fl_destroy_filter);
 	else
 		__fl_destroy_filter(f);
+
+	return last;
 }
 
 static void fl_destroy_sleepable(struct work_struct *work)
 {
 	struct cls_fl_head *head = container_of(work, struct cls_fl_head,
 						work);
-	if (head->mask_assigned)
-		rhashtable_destroy(&head->ht);
 	kfree(head);
 	module_put(THIS_MODULE);
 }
@@ -320,10 +353,15 @@ static void fl_destroy_rcu(struct rcu_head *rcu)
 static void fl_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack)
 {
 	struct cls_fl_head *head = rtnl_dereference(tp->root);
+	struct fl_flow_mask *mask, *next_mask;
 	struct cls_fl_filter *f, *next;
 
-	list_for_each_entry_safe(f, next, &head->filters, list)
-		__fl_delete(tp, f, extack);
+	list_for_each_entry_safe(mask, next_mask, &head->masks, list) {
+		list_for_each_entry_safe(f, next, &mask->filters, list) {
+			if (__fl_delete(tp, f, extack))
+				break;
+		}
+	}
 	idr_destroy(&head->handle_idr);
 
 	__module_get(THIS_MODULE);
@@ -715,14 +753,14 @@ static int fl_set_key(struct net *net, struct nlattr **tb,
 	return ret;
 }
 
-static bool fl_mask_eq(struct fl_flow_mask *mask1,
-		       struct fl_flow_mask *mask2)
+static void fl_mask_copy(struct fl_flow_mask *dst,
+			 struct fl_flow_mask *src)
 {
-	const long *lmask1 = fl_key_get_start(&mask1->key, mask1);
-	const long *lmask2 = fl_key_get_start(&mask2->key, mask2);
+	const void *psrc = fl_key_get_start(&src->key, src);
+	void *pdst = fl_key_get_start(&dst->key, src);
 
-	return !memcmp(&mask1->range, &mask2->range, sizeof(mask1->range)) &&
-	       !memcmp(lmask1, lmask2, fl_mask_range(mask1));
+	memcpy(pdst, psrc, fl_mask_range(src));
+	dst->range = src->range;
 }
 
 static const struct rhashtable_params fl_ht_params = {
@@ -731,14 +769,13 @@ static const struct rhashtable_params fl_ht_params = {
 	.automatic_shrinking = true,
 };
 
-static int fl_init_hashtable(struct cls_fl_head *head,
-			     struct fl_flow_mask *mask)
+static int fl_init_mask_hashtable(struct fl_flow_mask *mask)
 {
-	head->ht_params = fl_ht_params;
-	head->ht_params.key_len = fl_mask_range(mask);
-	head->ht_params.key_offset += mask->range.start;
+	mask->filter_ht_params = fl_ht_params;
+	mask->filter_ht_params.key_len = fl_mask_range(mask);
+	mask->filter_ht_params.key_offset += mask->range.start;
 
-	return rhashtable_init(&head->ht, &head->ht_params);
+	return rhashtable_init(&mask->ht, &mask->filter_ht_params);
 }
 
 #define FL_KEY_MEMBER_OFFSET(member) offsetof(struct fl_flow_key, member)
@@ -761,8 +798,7 @@ static int fl_init_hashtable(struct cls_fl_head *head,
 			FL_KEY_SET(keys, cnt, id, member);			\
 	} while(0);
 
-static void fl_init_dissector(struct cls_fl_head *head,
-			      struct fl_flow_mask *mask)
+static void fl_init_dissector(struct fl_flow_mask *mask)
 {
 	struct flow_dissector_key keys[FLOW_DISSECTOR_KEY_MAX];
 	size_t cnt = 0;
@@ -802,31 +838,66 @@ static void fl_init_dissector(struct cls_fl_head *head,
 	FL_KEY_SET_IF_MASKED(&mask->key, keys, cnt,
 			     FLOW_DISSECTOR_KEY_ENC_PORTS, enc_tp);
 
-	skb_flow_dissector_init(&head->dissector, keys, cnt);
+	skb_flow_dissector_init(&mask->dissector, keys, cnt);
+}
+
+static struct fl_flow_mask *fl_create_new_mask(struct cls_fl_head *head,
+					       struct fl_flow_mask *mask)
+{
+	struct fl_flow_mask *newmask;
+	int err;
+
+	newmask = kzalloc(sizeof(*newmask), GFP_KERNEL);
+	if (!newmask)
+		return ERR_PTR(-ENOMEM);
+
+	fl_mask_copy(newmask, mask);
+
+	err = fl_init_mask_hashtable(newmask);
+	if (err)
+		goto errout_free;
+
+	fl_init_dissector(newmask);
+
+	INIT_LIST_HEAD_RCU(&newmask->filters);
+
+	err = rhashtable_insert_fast(&head->ht, &newmask->ht_node,
+				     mask_ht_params);
+	if (err)
+		goto errout_destroy;
+
+	list_add_tail_rcu(&newmask->list, &head->masks);
+
+	return newmask;
+
+errout_destroy:
+	rhashtable_destroy(&newmask->ht);
+errout_free:
+	kfree(newmask);
+
+	return ERR_PTR(err);
 }
 
 static int fl_check_assign_mask(struct cls_fl_head *head,
+				struct cls_fl_filter *fnew,
+				struct cls_fl_filter *fold,
 				struct fl_flow_mask *mask)
 {
-	int err;
+	struct fl_flow_mask *newmask;
 
-	if (head->mask_assigned) {
-		if (!fl_mask_eq(&head->mask, mask))
+	fnew->mask = rhashtable_lookup_fast(&head->ht, mask, mask_ht_params);
+	if (!fnew->mask) {
+		if (fold)
 			return -EINVAL;
-		else
-			return 0;
-	}
 
-	/* Mask is not assigned yet. So assign it and init hashtable
-	 * according to that.
-	 */
-	err = fl_init_hashtable(head, mask);
-	if (err)
-		return err;
-	memcpy(&head->mask, mask, sizeof(head->mask));
-	head->mask_assigned = true;
+		newmask = fl_create_new_mask(head, mask);
+		if (IS_ERR(newmask))
+			return PTR_ERR(newmask);
 
-	fl_init_dissector(head, mask);
+		fnew->mask = newmask;
+	} else if (fold && fold->mask == fnew->mask) {
+		return -EINVAL;
+	}
 
 	return 0;
 }
@@ -924,30 +995,26 @@ static int fl_change(struct net *net, struct sk_buff *in_skb,
 	if (err)
 		goto errout_idr;
 
-	err = fl_check_assign_mask(head, &mask);
+	err = fl_check_assign_mask(head, fnew, fold, &mask);
 	if (err)
 		goto errout_idr;
 
 	if (!tc_skip_sw(fnew->flags)) {
-		if (!fold && fl_lookup(head, &fnew->mkey)) {
+		if (!fold && fl_lookup(fnew->mask, &fnew->mkey)) {
 			err = -EEXIST;
-			goto errout_idr;
+			goto errout_mask;
 		}
 
-		err = rhashtable_insert_fast(&head->ht, &fnew->ht_node,
-					     head->ht_params);
+		err = rhashtable_insert_fast(&fnew->mask->ht, &fnew->ht_node,
+					     fnew->mask->filter_ht_params);
 		if (err)
-			goto errout_idr;
+			goto errout_mask;
 	}
 
 	if (!tc_skip_hw(fnew->flags)) {
-		err = fl_hw_replace_filter(tp,
-					   &head->dissector,
-					   &mask.key,
-					   fnew,
-					   extack);
+		err = fl_hw_replace_filter(tp, fnew, extack);
 		if (err)
-			goto errout_idr;
+			goto errout_mask;
 	}
 
 	if (!tc_in_hw(fnew->flags))
@@ -955,8 +1022,9 @@ static int fl_change(struct net *net, struct sk_buff *in_skb,
 
 	if (fold) {
 		if (!tc_skip_sw(fold->flags))
-			rhashtable_remove_fast(&head->ht, &fold->ht_node,
-					       head->ht_params);
+			rhashtable_remove_fast(&fold->mask->ht,
+					       &fold->ht_node,
+					       fold->mask->filter_ht_params);
 		if (!tc_skip_hw(fold->flags))
 			fl_hw_destroy_filter(tp, fold, NULL);
 	}
@@ -970,12 +1038,15 @@ static int fl_change(struct net *net, struct sk_buff *in_skb,
 		tcf_exts_get_net(&fold->exts);
 		call_rcu(&fold->rcu, fl_destroy_filter);
 	} else {
-		list_add_tail_rcu(&fnew->list, &head->filters);
+		list_add_tail_rcu(&fnew->list, &fnew->mask->filters);
 	}
 
 	kfree(tb);
 	return 0;
 
+errout_mask:
+	fl_mask_put(head, fnew->mask, false);
+
 errout_idr:
 	if (fnew->handle)
 		idr_remove(&head->handle_idr, fnew->handle);
@@ -994,10 +1065,10 @@ static int fl_delete(struct tcf_proto *tp, void *arg, bool *last,
 	struct cls_fl_filter *f = arg;
 
 	if (!tc_skip_sw(f->flags))
-		rhashtable_remove_fast(&head->ht, &f->ht_node,
-				       head->ht_params);
+		rhashtable_remove_fast(&f->mask->ht, &f->ht_node,
+				       f->mask->filter_ht_params);
 	__fl_delete(tp, f, extack);
-	*last = list_empty(&head->filters);
+	*last = list_empty(&head->masks);
 	return 0;
 }
 
@@ -1005,16 +1076,19 @@ static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg)
 {
 	struct cls_fl_head *head = rtnl_dereference(tp->root);
 	struct cls_fl_filter *f;
-
-	list_for_each_entry_rcu(f, &head->filters, list) {
-		if (arg->count < arg->skip)
-			goto skip;
-		if (arg->fn(tp, f, arg) < 0) {
-			arg->stop = 1;
-			break;
-		}
+	struct fl_flow_mask *mask;
+
+	list_for_each_entry_rcu(mask, &head->masks, list) {
+		list_for_each_entry_rcu(f, &mask->filters, list) {
+			if (arg->count < arg->skip)
+				goto skip;
+			if (arg->fn(tp, f, arg) < 0) {
+				arg->stop = 1;
+				break;
+			}
 skip:
-		arg->count++;
+			arg->count++;
+		}
 	}
 }
 
@@ -1150,7 +1224,6 @@ static int fl_dump_key_flags(struct sk_buff *skb, u32 flags_key, u32 flags_mask)
 static int fl_dump(struct net *net, struct tcf_proto *tp, void *fh,
 		   struct sk_buff *skb, struct tcmsg *t)
 {
-	struct cls_fl_head *head = rtnl_dereference(tp->root);
 	struct cls_fl_filter *f = fh;
 	struct nlattr *nest;
 	struct fl_flow_key *key, *mask;
@@ -1169,7 +1242,7 @@ static int fl_dump(struct net *net, struct tcf_proto *tp, void *fh,
 		goto nla_put_failure;
 
 	key = &f->key;
-	mask = &head->mask.key;
+	mask = &f->mask->key;
 
 	if (mask->indev_ifindex) {
 		struct net_device *dev;