net: sched: introduce reference counting for tcf_proto
authorVlad Buslov <vladbu@mellanox.com>
Mon, 11 Feb 2019 08:55:39 +0000 (10:55 +0200)
committerDavid S. Miller <davem@davemloft.net>
Tue, 12 Feb 2019 18:41:32 +0000 (13:41 -0500)
In order to remove dependency on rtnl lock and allow concurrent tcf_proto
modification, extend tcf_proto with reference counter. Implement helper
get/put functions for tcf proto and use them to modify cls API to always
take reference to tcf_proto while using it. Only release reference to
parent chain after releasing last reference to tp.

Signed-off-by: Vlad Buslov <vladbu@mellanox.com>
Acked-by: Jiri Pirko <jiri@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/sch_generic.h
net/sched/cls_api.c

index 85993d7efee6eb38dae70c269a7b9558617faa57..4372c08fc4d90a9bb4f62f59ff5d669eec2beb1f 100644 (file)
@@ -322,6 +322,7 @@ struct tcf_proto {
        void                    *data;
        const struct tcf_proto_ops      *ops;
        struct tcf_chain        *chain;
+       refcount_t              refcnt;
        struct rcu_head         rcu;
 };
 
index 3fce30ae9a9b3b333d0574cf9d6c57cb1fffc410..37c05b96898f03f1f2a42e56b414c8f290a724c3 100644 (file)
@@ -180,6 +180,7 @@ static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol,
        tp->protocol = protocol;
        tp->prio = prio;
        tp->chain = chain;
+       refcount_set(&tp->refcnt, 1);
 
        err = tp->ops->init(tp);
        if (err) {
@@ -193,14 +194,29 @@ errout:
        return ERR_PTR(err);
 }
 
+static void tcf_proto_get(struct tcf_proto *tp)
+{
+       refcount_inc(&tp->refcnt);
+}
+
+static void tcf_chain_put(struct tcf_chain *chain);
+
 static void tcf_proto_destroy(struct tcf_proto *tp,
                              struct netlink_ext_ack *extack)
 {
        tp->ops->destroy(tp, extack);
+       tcf_chain_put(tp->chain);
        module_put(tp->ops->owner);
        kfree_rcu(tp, rcu);
 }
 
+static void tcf_proto_put(struct tcf_proto *tp,
+                         struct netlink_ext_ack *extack)
+{
+       if (refcount_dec_and_test(&tp->refcnt))
+               tcf_proto_destroy(tp, extack);
+}
+
 #define ASSERT_BLOCK_LOCKED(block)                                     \
        lockdep_assert_held(&(block)->lock)
 
@@ -445,18 +461,18 @@ static void tcf_chain_put_explicitly_created(struct tcf_chain *chain)
 
 static void tcf_chain_flush(struct tcf_chain *chain)
 {
-       struct tcf_proto *tp;
+       struct tcf_proto *tp, *tp_next;
 
        mutex_lock(&chain->filter_chain_lock);
        tp = tcf_chain_dereference(chain->filter_chain, chain);
+       RCU_INIT_POINTER(chain->filter_chain, NULL);
        tcf_chain0_head_change(chain, NULL);
        mutex_unlock(&chain->filter_chain_lock);
 
        while (tp) {
-               RCU_INIT_POINTER(chain->filter_chain, tp->next);
-               tcf_proto_destroy(tp, NULL);
-               tp = rtnl_dereference(chain->filter_chain);
-               tcf_chain_put(chain);
+               tp_next = rcu_dereference_protected(tp->next, 1);
+               tcf_proto_put(tp, NULL);
+               tp = tp_next;
        }
 }
 
@@ -1500,9 +1516,9 @@ static void tcf_chain_tp_insert(struct tcf_chain *chain,
 {
        if (*chain_info->pprev == chain->filter_chain)
                tcf_chain0_head_change(chain, tp);
+       tcf_proto_get(tp);
        RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain, chain_info));
        rcu_assign_pointer(*chain_info->pprev, tp);
-       tcf_chain_hold(chain);
 }
 
 static void tcf_chain_tp_remove(struct tcf_chain *chain,
@@ -1514,7 +1530,6 @@ static void tcf_chain_tp_remove(struct tcf_chain *chain,
        if (tp == chain->filter_chain)
                tcf_chain0_head_change(chain, next);
        RCU_INIT_POINTER(*chain_info->pprev, next);
-       tcf_chain_put(chain);
 }
 
 static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain,
@@ -1541,7 +1556,12 @@ static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain,
                }
        }
        chain_info->pprev = pprev;
-       chain_info->next = tp ? tp->next : NULL;
+       if (tp) {
+               chain_info->next = tp->next;
+               tcf_proto_get(tp);
+       } else {
+               chain_info->next = NULL;
+       }
        return tp;
 }
 
@@ -1699,6 +1719,7 @@ replay:
        prio = TC_H_MAJ(t->tcm_info);
        prio_allocate = false;
        parent = t->tcm_parent;
+       tp = NULL;
        cl = 0;
 
        if (prio == 0) {
@@ -1816,6 +1837,12 @@ replay:
 errout:
        if (chain)
                tcf_chain_put(chain);
+       if (chain) {
+               if (tp && !IS_ERR(tp))
+                       tcf_proto_put(tp, NULL);
+               if (!tp_created)
+                       tcf_chain_put(chain);
+       }
        tcf_block_release(q, block);
        if (err == -EAGAIN)
                /* Replay the request. */
@@ -1946,8 +1973,11 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n,
        }
 
 errout:
-       if (chain)
+       if (chain) {
+               if (tp && !IS_ERR(tp))
+                       tcf_proto_put(tp, NULL);
                tcf_chain_put(chain);
+       }
        tcf_block_release(q, block);
        return err;
 
@@ -2038,8 +2068,11 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n,
        }
 
 errout:
-       if (chain)
+       if (chain) {
+               if (tp && !IS_ERR(tp))
+                       tcf_proto_put(tp, NULL);
                tcf_chain_put(chain);
+       }
        tcf_block_release(q, block);
        return err;
 }