net: netfilter: Deduplicate code in bpf_{xdp,skb}_ct_lookup
authorKumar Kartikeya Dwivedi <memxor@gmail.com>
Thu, 21 Jul 2022 13:42:38 +0000 (15:42 +0200)
committerAlexei Starovoitov <ast@kernel.org>
Fri, 22 Jul 2022 04:03:16 +0000 (21:03 -0700)
Move common checks inside the common function, and maintain the only
difference the two being how to obtain the struct net * from ctx.
No functional change intended.

Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Link: https://lore.kernel.org/r/20220721134245.2450-7-memxor@gmail.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
net/netfilter/nf_conntrack_bpf.c

index cf2096f65d0e690575ff2bb3930e7b16b12f458a..16304869264f1b00571c2fb352637de6d4d284e6 100644 (file)
@@ -57,16 +57,19 @@ enum {
 
 static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
                                          struct bpf_sock_tuple *bpf_tuple,
-                                         u32 tuple_len, u8 protonum,
-                                         s32 netns_id, u8 *dir)
+                                         u32 tuple_len, struct bpf_ct_opts *opts,
+                                         u32 opts_len)
 {
        struct nf_conntrack_tuple_hash *hash;
        struct nf_conntrack_tuple tuple;
        struct nf_conn *ct;
 
-       if (unlikely(protonum != IPPROTO_TCP && protonum != IPPROTO_UDP))
+       if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
+           opts_len != NF_BPF_CT_OPTS_SZ)
+               return ERR_PTR(-EINVAL);
+       if (unlikely(opts->l4proto != IPPROTO_TCP && opts->l4proto != IPPROTO_UDP))
                return ERR_PTR(-EPROTO);
-       if (unlikely(netns_id < BPF_F_CURRENT_NETNS))
+       if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS))
                return ERR_PTR(-EINVAL);
 
        memset(&tuple, 0, sizeof(tuple));
@@ -89,23 +92,22 @@ static struct nf_conn *__bpf_nf_ct_lookup(struct net *net,
                return ERR_PTR(-EAFNOSUPPORT);
        }
 
-       tuple.dst.protonum = protonum;
+       tuple.dst.protonum = opts->l4proto;
 
-       if (netns_id >= 0) {
-               net = get_net_ns_by_id(net, netns_id);
+       if (opts->netns_id >= 0) {
+               net = get_net_ns_by_id(net, opts->netns_id);
                if (unlikely(!net))
                        return ERR_PTR(-ENONET);
        }
 
        hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple);
-       if (netns_id >= 0)
+       if (opts->netns_id >= 0)
                put_net(net);
        if (!hash)
                return ERR_PTR(-ENOENT);
 
        ct = nf_ct_tuplehash_to_ctrack(hash);
-       if (dir)
-               *dir = NF_CT_DIRECTION(hash);
+       opts->dir = NF_CT_DIRECTION(hash);
 
        return ct;
 }
@@ -138,20 +140,11 @@ bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple,
        struct net *caller_net;
        struct nf_conn *nfct;
 
-       BUILD_BUG_ON(sizeof(struct bpf_ct_opts) != NF_BPF_CT_OPTS_SZ);
-
-       if (!opts)
-               return NULL;
-       if (!bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
-           opts__sz != NF_BPF_CT_OPTS_SZ) {
-               opts->error = -EINVAL;
-               return NULL;
-       }
        caller_net = dev_net(ctx->rxq->dev);
-       nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts->l4proto,
-                                 opts->netns_id, &opts->dir);
+       nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
        if (IS_ERR(nfct)) {
-               opts->error = PTR_ERR(nfct);
+               if (opts)
+                       opts->error = PTR_ERR(nfct);
                return NULL;
        }
        return nfct;
@@ -181,20 +174,11 @@ bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple,
        struct net *caller_net;
        struct nf_conn *nfct;
 
-       BUILD_BUG_ON(sizeof(struct bpf_ct_opts) != NF_BPF_CT_OPTS_SZ);
-
-       if (!opts)
-               return NULL;
-       if (!bpf_tuple || opts->reserved[0] || opts->reserved[1] ||
-           opts__sz != NF_BPF_CT_OPTS_SZ) {
-               opts->error = -EINVAL;
-               return NULL;
-       }
        caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk);
-       nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts->l4proto,
-                                 opts->netns_id, &opts->dir);
+       nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz);
        if (IS_ERR(nfct)) {
-               opts->error = PTR_ERR(nfct);
+               if (opts)
+                       opts->error = PTR_ERR(nfct);
                return NULL;
        }
        return nfct;