net: genetlink: push attrbuf allocation and parsing to a separate function
authorJiri Pirko <jiri@mellanox.com>
Sat, 5 Oct 2019 18:04:35 +0000 (20:04 +0200)
committerDavid S. Miller <davem@davemloft.net>
Sun, 6 Oct 2019 13:44:46 +0000 (15:44 +0200)
To be re-usable by dumpit as well, push the code that is taking care of
attrbuf allocation and parting from doit into separate function.
Introduce a helper to free the buffer too.

Check family->maxattr too before calling kfree() to be symmetrical with
the allocation check.

Signed-off-by: Jiri Pirko <jiri@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/netlink/genetlink.c

index c785080e94012e08d3740f6d5ada22684d27b7bc..a98c94594508c30d28514f9c89bf636c38184750 100644 (file)
@@ -468,6 +468,45 @@ static void genl_dumpit_info_free(const struct genl_dumpit_info *info)
        kfree(info);
 }
 
+static struct nlattr **
+genl_family_rcv_msg_attrs_parse(const struct genl_family *family,
+                               struct nlmsghdr *nlh,
+                               struct netlink_ext_ack *extack,
+                               const struct genl_ops *ops,
+                               int hdrlen,
+                               enum genl_validate_flags no_strict_flag)
+{
+       enum netlink_validation validate = ops->validate & no_strict_flag ?
+                                          NL_VALIDATE_LIBERAL :
+                                          NL_VALIDATE_STRICT;
+       struct nlattr **attrbuf;
+       int err;
+
+       if (family->maxattr && family->parallel_ops) {
+               attrbuf = kmalloc_array(family->maxattr + 1,
+                                       sizeof(struct nlattr *), GFP_KERNEL);
+               if (!attrbuf)
+                       return ERR_PTR(-ENOMEM);
+       } else {
+               attrbuf = family->attrbuf;
+       }
+
+       err = __nlmsg_parse(nlh, hdrlen, attrbuf, family->maxattr,
+                           family->policy, validate, extack);
+       if (err && family->maxattr && family->parallel_ops) {
+               kfree(attrbuf);
+               return ERR_PTR(err);
+       }
+       return attrbuf;
+}
+
+static void genl_family_rcv_msg_attrs_free(const struct genl_family *family,
+                                          struct nlattr **attrbuf)
+{
+       if (family->maxattr && family->parallel_ops)
+               kfree(attrbuf);
+}
+
 static int genl_lock_start(struct netlink_callback *cb)
 {
        const struct genl_ops *ops = genl_dumpit_info(cb)->ops;
@@ -599,26 +638,11 @@ static int genl_family_rcv_msg_doit(const struct genl_family *family,
        if (!ops->doit)
                return -EOPNOTSUPP;
 
-       if (family->maxattr && family->parallel_ops) {
-               attrbuf = kmalloc_array(family->maxattr + 1,
-                                       sizeof(struct nlattr *),
-                                       GFP_KERNEL);
-               if (attrbuf == NULL)
-                       return -ENOMEM;
-       } else
-               attrbuf = family->attrbuf;
-
-       if (attrbuf) {
-               enum netlink_validation validate = NL_VALIDATE_STRICT;
-
-               if (ops->validate & GENL_DONT_VALIDATE_STRICT)
-                       validate = NL_VALIDATE_LIBERAL;
-
-               err = __nlmsg_parse(nlh, hdrlen, attrbuf, family->maxattr,
-                                   family->policy, validate, extack);
-               if (err < 0)
-                       goto out;
-       }
+       attrbuf = genl_family_rcv_msg_attrs_parse(family, nlh, extack,
+                                                 ops, hdrlen,
+                                                 GENL_DONT_VALIDATE_STRICT);
+       if (IS_ERR(attrbuf))
+               return PTR_ERR(attrbuf);
 
        info.snd_seq = nlh->nlmsg_seq;
        info.snd_portid = NETLINK_CB(skb).portid;
@@ -642,8 +666,7 @@ static int genl_family_rcv_msg_doit(const struct genl_family *family,
                family->post_doit(ops, skb, &info);
 
 out:
-       if (family->parallel_ops)
-               kfree(attrbuf);
+       genl_family_rcv_msg_attrs_free(family, attrbuf);
 
        return err;
 }