diff --git a/net/dsa/dsa2.c b/net/dsa/dsa2.c
index 8b68dc2f570764f215dd8bac6bf7aff8742d15db..d3f1a760746382f38a656191c22d533a2c29fa58 100644
--- a/net/dsa/dsa2.c
+++ b/net/dsa/dsa2.c
@@ -32,10 +32,9 @@ static struct dsa_switch_tree *dsa_get_dst(unsigned int index)
 	struct dsa_switch_tree *dst;
 
 	list_for_each_entry(dst, &dsa_switch_trees, list)
-		if (dst->index == index) {
-			kref_get(&dst->refcount);
+		if (dst->index == index)
 			return dst;
-		}
+
 	return NULL;
 }
 
@@ -48,11 +47,6 @@ static void dsa_free_dst(struct kref *ref)
 	kfree(dst);
 }
 
-static void dsa_put_dst(struct dsa_switch_tree *dst)
-{
-	kref_put(&dst->refcount, dsa_free_dst);
-}
-
 static struct dsa_switch_tree *dsa_add_dst(unsigned int index)
 {
 	struct dsa_switch_tree *dst;
@@ -63,7 +57,10 @@ static struct dsa_switch_tree *dsa_add_dst(unsigned int index)
 	dst->index = index;
 	INIT_LIST_HEAD(&dst->list);
 	list_add_tail(&dsa_switch_trees, &dst->list);
+
+	/* Initialize the reference counter to the number of switches, not 1 */
 	kref_init(&dst->refcount);
+	refcount_set(&dst->refcount.refcount, 0);
 
 	return dst;
 }
@@ -739,10 +736,8 @@ static int _dsa_register_switch(struct dsa_switch *ds)
 			return -ENOMEM;
 	}
 
-	if (dst->ds[index]) {
-		err = -EBUSY;
-		goto out;
-	}
+	if (dst->ds[index])
+		return -EBUSY;
 
 	ds->dst = dst;
 	ds->index = index;
@@ -758,11 +753,9 @@ static int _dsa_register_switch(struct dsa_switch *ds)
 	if (err < 0)
 		goto out_del_dst;
 
-	if (err == 1) {
-		/* Not all switches registered yet */
-		err = 0;
-		goto out;
-	}
+	/* Not all switches registered yet */
+	if (err == 1)
+		return 0;
 
 	if (dst->applied) {
 		pr_info("DSA: Disjoint trees?\n");
@@ -779,13 +772,10 @@ static int _dsa_register_switch(struct dsa_switch *ds)
 		goto out_del_dst;
 	}
 
-	dsa_put_dst(dst);
 	return 0;
 
 out_del_dst:
 	dsa_dst_del_ds(dst, ds, ds->index);
-out:
-	dsa_put_dst(dst);
 
 	return err;
 }