Merge tag 'armsoc-soc' of git://git.kernel.org/pub/scm/linux/kernel/git/arm/arm-soc
[linux-2.6-block.git] / net / openvswitch / conntrack.c
index 02fc343feb665f1404703a4b2815752014e362fa..284aca2a252df5cf6fabb0afd931a0c7c3d53217 100644 (file)
 #include <linux/tcp.h>
 #include <linux/udp.h>
 #include <linux/sctp.h>
+#include <linux/static_key.h>
 #include <net/ip.h>
+#include <net/genetlink.h>
 #include <net/netfilter/nf_conntrack_core.h>
+#include <net/netfilter/nf_conntrack_count.h>
 #include <net/netfilter/nf_conntrack_helper.h>
 #include <net/netfilter/nf_conntrack_labels.h>
 #include <net/netfilter/nf_conntrack_seqadj.h>
@@ -76,6 +79,31 @@ struct ovs_conntrack_info {
 #endif
 };
 
+#if    IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
+/* A limit of 0 means "no limit": ovs_ct_check_limit() returns success
+ * without counting when it sees this value.
+ */
+#define OVS_CT_LIMIT_UNLIMITED 0
+/* Value default_limit starts with and is reset to when the default zone
+ * limit is deleted.
+ */
+#define OVS_CT_LIMIT_DEFAULT OVS_CT_LIMIT_UNLIMITED
+/* Number of hash buckets in ovs_ct_limit_info->limits.  The bucket index
+ * is computed by masking with (CT_LIMIT_HASH_BUCKETS - 1), so this has to
+ * stay a power of two.
+ */
+#define CT_LIMIT_HASH_BUCKETS 512
+/* Off by default; switched on by a successful OVS_CT_LIMIT_CMD_SET and
+ * tested in ovs_ct_commit() before the limit check is run.
+ */
+static DEFINE_STATIC_KEY_FALSE(ovs_ct_limit_enabled);
+
+/* One explicitly configured per-zone connection limit. */
+struct ovs_ct_limit {
+       /* Elements in ovs_ct_limit_info->limits hash table */
+       struct hlist_node hlist_node;
+       /* Used by kfree_rcu() when the entry is replaced or deleted. */
+       struct rcu_head rcu;
+       /* Conntrack zone id this entry applies to (also the hash key). */
+       u16 zone;
+       /* Limit reported by ct_limit_get() for this zone. */
+       u32 limit;
+};
+
+/* Connection-limit state hung off struct ovs_net (ovs_net->ct_limit_info). */
+struct ovs_ct_limit_info {
+       /* Limit returned by ct_limit_get() for zones without an entry. */
+       u32 default_limit;
+       /* Array of CT_LIMIT_HASH_BUCKETS list heads of struct ovs_ct_limit. */
+       struct hlist_head *limits;
+       /* Handle returned by nf_conncount_init(); passed to
+        * nf_conncount_count() and nf_conncount_destroy().
+        */
+       struct nf_conncount_data *data;
+};
+
+/* Netlink policy shared by the SET, DEL and GET commands.  The only
+ * attribute declared is OVS_CT_LIMIT_ATTR_ZONE_LIMIT; the command handlers
+ * read its payload as a sequence of struct ovs_zone_limit records.
+ */
+static const struct nla_policy ct_limit_policy[OVS_CT_LIMIT_ATTR_MAX + 1] = {
+       [OVS_CT_LIMIT_ATTR_ZONE_LIMIT] = { .type = NLA_NESTED, },
+};
+#endif
+
 static bool labels_nonzero(const struct ovs_key_ct_labels *labels);
 
 static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info);
@@ -1036,6 +1064,89 @@ static bool labels_nonzero(const struct ovs_key_ct_labels *labels)
        return false;
 }
 
+#if    IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
+/* Map @zone to the list head in @info->limits that holds its entry.  The
+ * low bits of the zone id select the bucket.
+ */
+static struct hlist_head *ct_limit_hash_bucket(
+       const struct ovs_ct_limit_info *info, u16 zone)
+{
+       unsigned int bucket = zone & (CT_LIMIT_HASH_BUCKETS - 1);
+
+       return info->limits + bucket;
+}
+
+/* Install @new_ct_limit in the limits table of @info.  If an entry for the
+ * same zone is already present it is replaced in place and released with
+ * kfree_rcu(); otherwise the new entry is added at the head of its bucket.
+ *
+ * Call with ovs_mutex
+ */
+static void ct_limit_set(const struct ovs_ct_limit_info *info,
+                        struct ovs_ct_limit *new_ct_limit)
+{
+       struct hlist_head *bucket;
+       struct ovs_ct_limit *old;
+
+       bucket = ct_limit_hash_bucket(info, new_ct_limit->zone);
+       hlist_for_each_entry_rcu(old, bucket, hlist_node) {
+               if (old->zone != new_ct_limit->zone)
+                       continue;
+
+               hlist_replace_rcu(&old->hlist_node,
+                                 &new_ct_limit->hlist_node);
+               kfree_rcu(old, rcu);
+               return;
+       }
+
+       hlist_add_head_rcu(&new_ct_limit->hlist_node, bucket);
+}
+
+/* Remove the entry for @zone from the limits table of @info, if there is
+ * one, and release it with kfree_rcu().  Nothing happens when the zone has
+ * no entry.
+ *
+ * Call with ovs_mutex
+ */
+static void ct_limit_del(const struct ovs_ct_limit_info *info, u16 zone)
+{
+       struct hlist_head *bucket = ct_limit_hash_bucket(info, zone);
+       struct ovs_ct_limit *entry;
+       struct hlist_node *tmp;
+
+       hlist_for_each_entry_safe(entry, tmp, bucket, hlist_node) {
+               if (entry->zone != zone)
+                       continue;
+
+               hlist_del_rcu(&entry->hlist_node);
+               kfree_rcu(entry, rcu);
+               return;
+       }
+}
+
+/* Return the limit configured for @zone, or @info->default_limit when the
+ * zone has no entry of its own.
+ *
+ * Call with RCU read lock
+ */
+static u32 ct_limit_get(const struct ovs_ct_limit_info *info, u16 zone)
+{
+       struct hlist_head *bucket = ct_limit_hash_bucket(info, zone);
+       struct ovs_ct_limit *entry;
+
+       hlist_for_each_entry_rcu(entry, bucket, hlist_node) {
+               if (entry->zone != zone)
+                       continue;
+
+               return entry->limit;
+       }
+
+       return info->default_limit;
+}
+
+/* Check the connection described by @tuple against the limit of the zone
+ * in @info.
+ *
+ * Returns 0 when the zone is unlimited or the count reported by
+ * nf_conncount_count() does not exceed the limit, -ENOMEM otherwise.
+ */
+static int ovs_ct_check_limit(struct net *net,
+                             const struct ovs_conntrack_info *info,
+                             const struct nf_conntrack_tuple *tuple)
+{
+       struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
+       const struct ovs_ct_limit_info *limits = ovs_net->ct_limit_info;
+       /* The conncount key is the zone id widened to u32. */
+       u32 conncount_key = info->zone.id;
+       u32 zone_limit;
+
+       zone_limit = ct_limit_get(limits, info->zone.id);
+       if (zone_limit == OVS_CT_LIMIT_UNLIMITED)
+               return 0;
+
+       if (nf_conncount_count(net, limits->data, &conncount_key, tuple,
+                              &info->zone) > zone_limit)
+               return -ENOMEM;
+
+       return 0;
+}
+#endif
+
 /* Lookup connection and confirm if unconfirmed. */
 static int ovs_ct_commit(struct net *net, struct sw_flow_key *key,
                         const struct ovs_conntrack_info *info,
@@ -1054,6 +1165,21 @@ static int ovs_ct_commit(struct net *net, struct sw_flow_key *key,
        if (!ct)
                return 0;
 
+#if    IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
+       if (static_branch_unlikely(&ovs_ct_limit_enabled)) {
+               if (!nf_ct_is_confirmed(ct)) {
+                       err = ovs_ct_check_limit(net, info,
+                               &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
+                       if (err) {
+                               net_warn_ratelimited("openvswitch: zone: %u "
+                                       "execeeds conntrack limit\n",
+                                       info->zone.id);
+                               return err;
+                       }
+               }
+       }
+#endif
+
        /* Set the conntrack event mask if given.  NEW and DELETE events have
         * their own groups, but the NFNLGRP_CONNTRACK_UPDATE group listener
         * typically would receive many kinds of updates.  Setting the event
@@ -1655,7 +1781,420 @@ static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info)
                nf_ct_tmpl_free(ct_info->ct);
 }
 
-void ovs_ct_init(struct net *net)
+#if    IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
+/* Allocate and initialise the connection-limit state for @net and attach
+ * it to @ovs_net.
+ *
+ * The state is built in a local pointer and only stored in
+ * ovs_net->ct_limit_info once every step has succeeded, so a failed init
+ * never leaves that field pointing at freed memory.
+ *
+ * Returns 0 on success, -ENOMEM if an allocation fails, or the error
+ * reported by nf_conncount_init().
+ */
+static int ovs_ct_limit_init(struct net *net, struct ovs_net *ovs_net)
+{
+       struct ovs_ct_limit_info *info;
+       int i, err;
+
+       info = kmalloc(sizeof(*info), GFP_KERNEL);
+       if (!info)
+               return -ENOMEM;
+
+       info->default_limit = OVS_CT_LIMIT_DEFAULT;
+       info->limits = kmalloc_array(CT_LIMIT_HASH_BUCKETS,
+                                    sizeof(struct hlist_head), GFP_KERNEL);
+       if (!info->limits) {
+               err = -ENOMEM;
+               goto err_free_info;
+       }
+
+       for (i = 0; i < CT_LIMIT_HASH_BUCKETS; i++)
+               INIT_HLIST_HEAD(&info->limits[i]);
+
+       /* The conncount key is a single u32 (the zone id). */
+       info->data = nf_conncount_init(net, NFPROTO_INET, sizeof(u32));
+       if (IS_ERR(info->data)) {
+               err = PTR_ERR(info->data);
+               pr_err("openvswitch: failed to init nf_conncount %d\n", err);
+               goto err_free_limits;
+       }
+
+       ovs_net->ct_limit_info = info;
+       return 0;
+
+err_free_limits:
+       kfree(info->limits);
+err_free_info:
+       kfree(info);
+       return err;
+}
+
+/* Tear down the connection-limit state attached to @ovs_net: destroy the
+ * nf_conncount handle, release every per-zone entry, then free the bucket
+ * array and the info structure itself.
+ */
+static void ovs_ct_limit_exit(struct net *net, struct ovs_net *ovs_net)
+{
+       const struct ovs_ct_limit_info *info = ovs_net->ct_limit_info;
+       int i;
+
+       nf_conncount_destroy(net, NFPROTO_INET, info->data);
+       for (i = 0; i < CT_LIMIT_HASH_BUCKETS; ++i) {
+               struct hlist_head *head = &info->limits[i];
+               struct ovs_ct_limit *ct_limit;
+               struct hlist_node *next;
+
+               /* This runs outside an RCU read-side critical section, so
+                * an entry handed to kfree_rcu() may be freed before the
+                * iterator reads its next pointer.  The _safe variant
+                * fetches the successor before the body runs.
+                */
+               hlist_for_each_entry_safe(ct_limit, next, head, hlist_node)
+                       kfree_rcu(ct_limit, rcu);
+       }
+       kfree(info->limits);
+       kfree(info);
+}
+
+/* Allocate a reply message for command @cmd and write its generic netlink
+ * and OVS headers.  The dp_ifindex of the request is copied into the reply.
+ *
+ * On success returns the new skb and stores the reply's OVS header in
+ * @ovs_reply_header.  On failure returns ERR_PTR(-ENOMEM) if the skb could
+ * not be allocated or ERR_PTR(-EMSGSIZE) if the header did not fit.
+ */
+static struct sk_buff *
+ovs_ct_limit_cmd_reply_start(struct genl_info *info, u8 cmd,
+                            struct ovs_header **ovs_reply_header)
+{
+       struct ovs_header *request_hdr = info->userhdr;
+       struct ovs_header *reply_hdr;
+       struct sk_buff *reply;
+
+       reply = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
+       if (!reply)
+               return ERR_PTR(-ENOMEM);
+
+       reply_hdr = genlmsg_put(reply, info->snd_portid, info->snd_seq,
+                               &dp_ct_limit_genl_family, 0, cmd);
+       *ovs_reply_header = reply_hdr;
+       if (!reply_hdr) {
+               nlmsg_free(reply);
+               return ERR_PTR(-EMSGSIZE);
+       }
+
+       reply_hdr->dp_ifindex = request_hdr->dp_ifindex;
+       return reply;
+}
+
+/* Validate that @zone_id fits in 16 bits.  On success the value is stored
+ * in @pzone and true is returned; otherwise @pzone is left untouched and
+ * false is returned.
+ */
+static bool check_zone_id(int zone_id, u16 *pzone)
+{
+       if (zone_id < 0 || zone_id > 65535)
+               return false;
+
+       *pzone = (u16)zone_id;
+       return true;
+}
+
+/* Apply every struct ovs_zone_limit record carried in @nla_zone_limit to
+ * @info.  A record for OVS_ZONE_LIMIT_DEFAULT_ZONE updates the default
+ * limit; a record with an out-of-range zone id is logged and skipped; any
+ * other record installs (or replaces) the per-zone entry.  Each update is
+ * made under ovs_lock().
+ *
+ * Returns 0, or -ENOMEM if an entry cannot be allocated (records already
+ * processed stay applied).
+ */
+static int ovs_ct_limit_set_zone_limit(struct nlattr *nla_zone_limit,
+                                      struct ovs_ct_limit_info *info)
+{
+       const size_t step = NLA_ALIGN(sizeof(struct ovs_zone_limit));
+       struct ovs_zone_limit *req = nla_data(nla_zone_limit);
+       int left = NLA_ALIGN(nla_len(nla_zone_limit));
+       u16 zone;
+
+       while (left >= sizeof(*req)) {
+               if (unlikely(req->zone_id == OVS_ZONE_LIMIT_DEFAULT_ZONE)) {
+                       ovs_lock();
+                       info->default_limit = req->limit;
+                       ovs_unlock();
+               } else if (unlikely(!check_zone_id(req->zone_id, &zone))) {
+                       OVS_NLERR(true, "zone id is out of range");
+               } else {
+                       struct ovs_ct_limit *entry;
+
+                       entry = kmalloc(sizeof(*entry), GFP_KERNEL);
+                       if (!entry)
+                               return -ENOMEM;
+
+                       entry->zone = zone;
+                       entry->limit = req->limit;
+
+                       ovs_lock();
+                       ct_limit_set(info, entry);
+                       ovs_unlock();
+               }
+
+               /* Advance to the next record in the attribute payload. */
+               left -= step;
+               req = (struct ovs_zone_limit *)((u8 *)req + step);
+       }
+
+       if (left)
+               OVS_NLERR(true, "set zone limit has %d unknown bytes", left);
+
+       return 0;
+}
+
+/* Remove the limits named by the struct ovs_zone_limit records carried in
+ * @nla_zone_limit.  A record for OVS_ZONE_LIMIT_DEFAULT_ZONE resets the
+ * default limit to OVS_CT_LIMIT_DEFAULT; a record with an out-of-range
+ * zone id is logged and skipped; any other record deletes that zone's
+ * entry.  Each update is made under ovs_lock().
+ *
+ * Always returns 0.
+ */
+static int ovs_ct_limit_del_zone_limit(struct nlattr *nla_zone_limit,
+                                      struct ovs_ct_limit_info *info)
+{
+       const size_t step = NLA_ALIGN(sizeof(struct ovs_zone_limit));
+       struct ovs_zone_limit *req = nla_data(nla_zone_limit);
+       int left = NLA_ALIGN(nla_len(nla_zone_limit));
+       u16 zone;
+
+       while (left >= sizeof(*req)) {
+               if (unlikely(req->zone_id == OVS_ZONE_LIMIT_DEFAULT_ZONE)) {
+                       ovs_lock();
+                       info->default_limit = OVS_CT_LIMIT_DEFAULT;
+                       ovs_unlock();
+               } else if (unlikely(!check_zone_id(req->zone_id, &zone))) {
+                       OVS_NLERR(true, "zone id is out of range");
+               } else {
+                       ovs_lock();
+                       ct_limit_del(info, zone);
+                       ovs_unlock();
+               }
+
+               /* Advance to the next record in the attribute payload. */
+               left -= step;
+               req = (struct ovs_zone_limit *)((u8 *)req + step);
+       }
+
+       if (left)
+               OVS_NLERR(true, "del zone limit has %d unknown bytes", left);
+
+       return 0;
+}
+
+/* Append a struct ovs_zone_limit record describing the default limit of
+ * @info to @reply.
+ *
+ * Returns 0 on success or the error from nla_put_nohdr().
+ */
+static int ovs_ct_limit_get_default_limit(struct ovs_ct_limit_info *info,
+                                         struct sk_buff *reply)
+{
+       /* The whole structure is copied into a reply sent to user space.
+        * The designated initializer zeroes every member that is not named
+        * here (notably .count), so no uninitialised stack bytes leak out.
+        */
+       struct ovs_zone_limit zone_limit = {
+               .zone_id = OVS_ZONE_LIMIT_DEFAULT_ZONE,
+               .limit = info->default_limit,
+       };
+
+       return nla_put_nohdr(reply, sizeof(zone_limit), &zone_limit);
+}
+
+/* Append one struct ovs_zone_limit record for @zone_id to @reply.  The
+ * record carries @limit and the count returned by nf_conncount_count()
+ * for that zone when called with a NULL tuple.
+ *
+ * Returns the result of nla_put_nohdr().
+ */
+static int __ovs_ct_limit_get_zone_limit(struct net *net,
+                                        struct nf_conncount_data *data,
+                                        u16 zone_id, u32 limit,
+                                        struct sk_buff *reply)
+{
+       struct ovs_zone_limit record;
+       struct nf_conntrack_zone zone;
+       u32 key = zone_id;
+
+       nf_ct_zone_init(&zone, zone_id, NF_CT_DEFAULT_ZONE_DIR, 0);
+
+       record.zone_id = zone_id;
+       record.limit = limit;
+       record.count = nf_conncount_count(net, data, &key, NULL, &zone);
+
+       return nla_put_nohdr(reply, sizeof(record), &record);
+}
+
+/* Answer a GET request that names specific zones.  For every struct
+ * ovs_zone_limit record in @nla_zone_limit a record is appended to @reply:
+ * the default limit for OVS_ZONE_LIMIT_DEFAULT_ZONE, otherwise the limit
+ * and current count of the requested zone.  Records with an out-of-range
+ * zone id are logged and skipped.
+ *
+ * Returns 0 on success or the first error hit while filling @reply.
+ */
+static int ovs_ct_limit_get_zone_limit(struct net *net,
+                                      struct nlattr *nla_zone_limit,
+                                      struct ovs_ct_limit_info *info,
+                                      struct sk_buff *reply)
+{
+       const size_t step = NLA_ALIGN(sizeof(struct ovs_zone_limit));
+       struct ovs_zone_limit *req = nla_data(nla_zone_limit);
+       int left = NLA_ALIGN(nla_len(nla_zone_limit));
+       u16 zone;
+
+       while (left >= sizeof(*req)) {
+               int err = 0;
+
+               if (unlikely(req->zone_id == OVS_ZONE_LIMIT_DEFAULT_ZONE)) {
+                       err = ovs_ct_limit_get_default_limit(info, reply);
+               } else if (unlikely(!check_zone_id(req->zone_id, &zone))) {
+                       OVS_NLERR(true, "zone id is out of range");
+               } else {
+                       u32 limit;
+
+                       rcu_read_lock();
+                       limit = ct_limit_get(info, zone);
+                       rcu_read_unlock();
+
+                       err = __ovs_ct_limit_get_zone_limit(net, info->data,
+                                                           zone, limit,
+                                                           reply);
+               }
+               if (err)
+                       return err;
+
+               /* Advance to the next record in the attribute payload. */
+               left -= step;
+               req = (struct ovs_zone_limit *)((u8 *)req + step);
+       }
+
+       if (left)
+               OVS_NLERR(true, "get zone limit has %d unknown bytes", left);
+
+       return 0;
+}
+
+/* Answer a GET request that names no zones: append the default limit to
+ * @reply, followed by one record for every entry in the limits table.  The
+ * table is walked under rcu_read_lock().
+ *
+ * Returns 0 on success or the first error hit while filling @reply.
+ */
+static int ovs_ct_limit_get_all_zone_limit(struct net *net,
+                                          struct ovs_ct_limit_info *info,
+                                          struct sk_buff *reply)
+{
+       struct ovs_ct_limit *entry;
+       int bucket, err;
+
+       err = ovs_ct_limit_get_default_limit(info, reply);
+       if (err)
+               return err;
+
+       rcu_read_lock();
+       /* Stop at the first bucket in which a record failed to fit. */
+       for (bucket = 0; bucket < CT_LIMIT_HASH_BUCKETS && !err; ++bucket) {
+               hlist_for_each_entry_rcu(entry, &info->limits[bucket],
+                                        hlist_node) {
+                       err = __ovs_ct_limit_get_zone_limit(net, info->data,
+                                                           entry->zone,
+                                                           entry->limit,
+                                                           reply);
+                       if (err)
+                               break;
+               }
+       }
+       rcu_read_unlock();
+
+       return err;
+}
+
+/* Generic netlink handler for OVS_CT_LIMIT_CMD_SET.  Applies the limits in
+ * the OVS_CT_LIMIT_ATTR_ZONE_LIMIT attribute, turns on the
+ * ovs_ct_limit_enabled static key and sends back an empty reply.
+ *
+ * Returns -EINVAL when the attribute is missing, the error from building
+ * the reply or applying the limits, or the result of genlmsg_reply().
+ */
+static int ovs_ct_limit_cmd_set(struct sk_buff *skb, struct genl_info *info)
+{
+       struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id);
+       struct ovs_ct_limit_info *limits = ovs_net->ct_limit_info;
+       struct nlattr *zone_attr = info->attrs[OVS_CT_LIMIT_ATTR_ZONE_LIMIT];
+       struct ovs_header *reply_hdr;
+       struct sk_buff *reply;
+       int err;
+
+       reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_SET,
+                                            &reply_hdr);
+       if (IS_ERR(reply))
+               return PTR_ERR(reply);
+
+       if (zone_attr)
+               err = ovs_ct_limit_set_zone_limit(zone_attr, limits);
+       else
+               err = -EINVAL;
+
+       if (err) {
+               nlmsg_free(reply);
+               return err;
+       }
+
+       static_branch_enable(&ovs_ct_limit_enabled);
+
+       genlmsg_end(reply, reply_hdr);
+       return genlmsg_reply(reply, info);
+}
+
+/* Generic netlink handler for OVS_CT_LIMIT_CMD_DEL.  Removes the limits
+ * named in the OVS_CT_LIMIT_ATTR_ZONE_LIMIT attribute and sends back an
+ * empty reply.
+ *
+ * Returns -EINVAL when the attribute is missing, the error from building
+ * the reply or removing the limits, or the result of genlmsg_reply().
+ */
+static int ovs_ct_limit_cmd_del(struct sk_buff *skb, struct genl_info *info)
+{
+       struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id);
+       struct ovs_ct_limit_info *limits = ovs_net->ct_limit_info;
+       struct nlattr *zone_attr = info->attrs[OVS_CT_LIMIT_ATTR_ZONE_LIMIT];
+       struct ovs_header *reply_hdr;
+       struct sk_buff *reply;
+       int err;
+
+       reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_DEL,
+                                            &reply_hdr);
+       if (IS_ERR(reply))
+               return PTR_ERR(reply);
+
+       if (zone_attr)
+               err = ovs_ct_limit_del_zone_limit(zone_attr, limits);
+       else
+               err = -EINVAL;
+
+       if (err) {
+               nlmsg_free(reply);
+               return err;
+       }
+
+       genlmsg_end(reply, reply_hdr);
+       return genlmsg_reply(reply, info);
+}
+
+/* Generic netlink handler for OVS_CT_LIMIT_CMD_GET.  Builds a reply whose
+ * OVS_CT_LIMIT_ATTR_ZONE_LIMIT nest holds one struct ovs_zone_limit record
+ * per reported zone: the zones named in the request when the attribute is
+ * present, otherwise the default limit followed by every configured zone.
+ *
+ * Returns -EMSGSIZE when the nest cannot be opened in the reply, the error
+ * from building or filling the reply, or the result of genlmsg_reply().
+ */
+static int ovs_ct_limit_cmd_get(struct sk_buff *skb, struct genl_info *info)
+{
+       struct nlattr **a = info->attrs;
+       struct nlattr *nla_reply;
+       struct sk_buff *reply;
+       struct ovs_header *ovs_reply_header;
+       struct net *net = sock_net(skb->sk);
+       struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
+       struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info;
+       int err;
+
+       reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_GET,
+                                            &ovs_reply_header);
+       if (IS_ERR(reply))
+               return PTR_ERR(reply);
+
+       /* nla_nest_start() returns NULL when the nest header does not fit;
+        * bail out instead of handing NULL to nla_nest_end() below.
+        */
+       nla_reply = nla_nest_start(reply, OVS_CT_LIMIT_ATTR_ZONE_LIMIT);
+       if (!nla_reply) {
+               err = -EMSGSIZE;
+               goto exit_err;
+       }
+
+       if (a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) {
+               err = ovs_ct_limit_get_zone_limit(
+                       net, a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT], ct_limit_info,
+                       reply);
+               if (err)
+                       goto exit_err;
+       } else {
+               err = ovs_ct_limit_get_all_zone_limit(net, ct_limit_info,
+                                                     reply);
+               if (err)
+                       goto exit_err;
+       }
+
+       nla_nest_end(reply, nla_reply);
+       genlmsg_end(reply, ovs_reply_header);
+       return genlmsg_reply(reply, info);
+
+exit_err:
+       nlmsg_free(reply);
+       return err;
+}
+
+/* Command table of the conntrack-limit generic netlink family.  All three
+ * commands share ct_limit_policy; SET and DEL are restricted with
+ * GENL_ADMIN_PERM while GET carries no permission flag.
+ */
+static struct genl_ops ct_limit_genl_ops[] = {
+       { .cmd = OVS_CT_LIMIT_CMD_SET,
+               .flags = GENL_ADMIN_PERM, /* Requires CAP_NET_ADMIN
+                                          * privilege. */
+               .policy = ct_limit_policy,
+               .doit = ovs_ct_limit_cmd_set,
+       },
+       { .cmd = OVS_CT_LIMIT_CMD_DEL,
+               .flags = GENL_ADMIN_PERM, /* Requires CAP_NET_ADMIN
+                                          * privilege. */
+               .policy = ct_limit_policy,
+               .doit = ovs_ct_limit_cmd_del,
+       },
+       { .cmd = OVS_CT_LIMIT_CMD_GET,
+               .flags = 0,               /* OK for unprivileged users. */
+               .policy = ct_limit_policy,
+               .doit = ovs_ct_limit_cmd_get,
+       },
+};
+
+/* The single multicast group registered with the conntrack-limit family. */
+static const struct genl_multicast_group ovs_ct_limit_multicast_group = {
+       .name = OVS_CT_LIMIT_MCGROUP,
+};
+
+/* Generic netlink family for configuring and querying per-zone conntrack
+ * limits.  Every message carries a struct ovs_header as its user header.
+ */
+struct genl_family dp_ct_limit_genl_family __ro_after_init = {
+       .hdrsize = sizeof(struct ovs_header),
+       .name = OVS_CT_LIMIT_FAMILY,
+       .version = OVS_CT_LIMIT_VERSION,
+       .maxattr = OVS_CT_LIMIT_ATTR_MAX,
+       .netnsok = true,
+       /* The SET/DEL handlers take ovs_lock() themselves around each
+        * table update.
+        */
+       .parallel_ops = true,
+       .ops = ct_limit_genl_ops,
+       .n_ops = ARRAY_SIZE(ct_limit_genl_ops),
+       .mcgrps = &ovs_ct_limit_multicast_group,
+       .n_mcgrps = 1,
+       .module = THIS_MODULE,
+};
+#endif
+
+int ovs_ct_init(struct net *net)
 {
        unsigned int n_bits = sizeof(struct ovs_key_ct_labels) * BITS_PER_BYTE;
        struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
@@ -1666,12 +2205,22 @@ void ovs_ct_init(struct net *net)
        } else {
                ovs_net->xt_label = true;
        }
+
+#if    IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
+       return ovs_ct_limit_init(net, ovs_net);
+#else
+       return 0;
+#endif
 }
 
+/* Release the conntrack state kept in @net's ovs_net: the connection-limit
+ * state when CONFIG_NETFILTER_CONNCOUNT is enabled, then
+ * nf_connlabels_put() if xt_label was set.
+ */
 void ovs_ct_exit(struct net *net)
 {
        struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
 
+#if    IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
+       ovs_ct_limit_exit(net, ovs_net);
+#endif
+
        if (ovs_net->xt_label)
                nf_connlabels_put(net);
 }