net: reintroduce missing rcu_assign_pointer() calls
[linux-2.6-block.git] / net / netfilter / nf_conntrack_netlink.c
index ef21b221f0363a8700fbbc3aadd316423c4f9d66..2a4834b83332afa2ebf87b37102cd2a9d6dcd8d1 100644 (file)
@@ -135,7 +135,7 @@ nla_put_failure:
 static inline int
 ctnetlink_dump_timeout(struct sk_buff *skb, const struct nf_conn *ct)
 {
-       long timeout = (ct->timeout.expires - jiffies) / HZ;
+       long timeout = ((long)ct->timeout.expires - (long)jiffies) / HZ;
 
        if (timeout < 0)
                timeout = 0;
@@ -203,25 +203,18 @@ nla_put_failure:
 }
 
 static int
-ctnetlink_dump_counters(struct sk_buff *skb, const struct nf_conn *ct,
-                       enum ip_conntrack_dir dir)
+dump_counters(struct sk_buff *skb, u64 pkts, u64 bytes,
+             enum ip_conntrack_dir dir)
 {
        enum ctattr_type type = dir ? CTA_COUNTERS_REPLY: CTA_COUNTERS_ORIG;
        struct nlattr *nest_count;
-       const struct nf_conn_counter *acct;
-
-       acct = nf_conn_acct_find(ct);
-       if (!acct)
-               return 0;
 
        nest_count = nla_nest_start(skb, type | NLA_F_NESTED);
        if (!nest_count)
                goto nla_put_failure;
 
-       NLA_PUT_BE64(skb, CTA_COUNTERS_PACKETS,
-                    cpu_to_be64(acct[dir].packets));
-       NLA_PUT_BE64(skb, CTA_COUNTERS_BYTES,
-                    cpu_to_be64(acct[dir].bytes));
+       NLA_PUT_BE64(skb, CTA_COUNTERS_PACKETS, cpu_to_be64(pkts));
+       NLA_PUT_BE64(skb, CTA_COUNTERS_BYTES, cpu_to_be64(bytes));
 
        nla_nest_end(skb, nest_count);
 
@@ -231,6 +224,27 @@ nla_put_failure:
        return -1;
 }
 
+static int
+ctnetlink_dump_counters(struct sk_buff *skb, const struct nf_conn *ct,
+                       enum ip_conntrack_dir dir, int type)
+{
+       struct nf_conn_counter *acct;
+       u64 pkts, bytes;
+
+       acct = nf_conn_acct_find(ct);
+       if (!acct)
+               return 0;
+
+       if (type == IPCTNL_MSG_CT_GET_CTRZERO) {
+               pkts = atomic64_xchg(&acct[dir].packets, 0);
+               bytes = atomic64_xchg(&acct[dir].bytes, 0);
+       } else {
+               pkts = atomic64_read(&acct[dir].packets);
+               bytes = atomic64_read(&acct[dir].bytes);
+       }
+       return dump_counters(skb, pkts, bytes, dir);
+}
+
 static int
 ctnetlink_dump_timestamp(struct sk_buff *skb, const struct nf_conn *ct)
 {
@@ -393,15 +407,15 @@ nla_put_failure:
 }
 
 static int
-ctnetlink_fill_info(struct sk_buff *skb, u32 pid, u32 seq,
-                   int event, struct nf_conn *ct)
+ctnetlink_fill_info(struct sk_buff *skb, u32 pid, u32 seq, u32 type,
+                   struct nf_conn *ct)
 {
        struct nlmsghdr *nlh;
        struct nfgenmsg *nfmsg;
        struct nlattr *nest_parms;
-       unsigned int flags = pid ? NLM_F_MULTI : 0;
+       unsigned int flags = pid ? NLM_F_MULTI : 0, event;
 
-       event |= NFNL_SUBSYS_CTNETLINK << 8;
+       event = (NFNL_SUBSYS_CTNETLINK << 8 | IPCTNL_MSG_CT_NEW);
        nlh = nlmsg_put(skb, pid, seq, event, sizeof(*nfmsg), flags);
        if (nlh == NULL)
                goto nlmsg_failure;
@@ -430,8 +444,8 @@ ctnetlink_fill_info(struct sk_buff *skb, u32 pid, u32 seq,
 
        if (ctnetlink_dump_status(skb, ct) < 0 ||
            ctnetlink_dump_timeout(skb, ct) < 0 ||
-           ctnetlink_dump_counters(skb, ct, IP_CT_DIR_ORIGINAL) < 0 ||
-           ctnetlink_dump_counters(skb, ct, IP_CT_DIR_REPLY) < 0 ||
+           ctnetlink_dump_counters(skb, ct, IP_CT_DIR_ORIGINAL, type) < 0 ||
+           ctnetlink_dump_counters(skb, ct, IP_CT_DIR_REPLY, type) < 0 ||
            ctnetlink_dump_timestamp(skb, ct) < 0 ||
            ctnetlink_dump_protoinfo(skb, ct) < 0 ||
            ctnetlink_dump_helpinfo(skb, ct) < 0 ||
@@ -612,8 +626,10 @@ ctnetlink_conntrack_event(unsigned int events, struct nf_ct_event *item)
                goto nla_put_failure;
 
        if (events & (1 << IPCT_DESTROY)) {
-               if (ctnetlink_dump_counters(skb, ct, IP_CT_DIR_ORIGINAL) < 0 ||
-                   ctnetlink_dump_counters(skb, ct, IP_CT_DIR_REPLY) < 0 ||
+               if (ctnetlink_dump_counters(skb, ct,
+                                           IP_CT_DIR_ORIGINAL, type) < 0 ||
+                   ctnetlink_dump_counters(skb, ct,
+                                           IP_CT_DIR_REPLY, type) < 0 ||
                    ctnetlink_dump_timestamp(skb, ct) < 0)
                        goto nla_put_failure;
        } else {
@@ -709,20 +725,13 @@ restart:
                        }
                        if (ctnetlink_fill_info(skb, NETLINK_CB(cb->skb).pid,
                                                cb->nlh->nlmsg_seq,
-                                               IPCTNL_MSG_CT_NEW, ct) < 0) {
+                                               NFNL_MSG_TYPE(
+                                                       cb->nlh->nlmsg_type),
+                                               ct) < 0) {
                                nf_conntrack_get(&ct->ct_general);
                                cb->args[1] = (unsigned long)ct;
                                goto out;
                        }
-
-                       if (NFNL_MSG_TYPE(cb->nlh->nlmsg_type) ==
-                                               IPCTNL_MSG_CT_GET_CTRZERO) {
-                               struct nf_conn_counter *acct;
-
-                               acct = nf_conn_acct_find(ct);
-                               if (acct)
-                                       memset(acct, 0, sizeof(struct nf_conn_counter[IP_CT_DIR_MAX]));
-                       }
                }
                if (cb->args[1]) {
                        cb->args[1] = 0;
@@ -1001,7 +1010,7 @@ ctnetlink_get_conntrack(struct sock *ctnl, struct sk_buff *skb,
 
        rcu_read_lock();
        err = ctnetlink_fill_info(skb2, NETLINK_CB(skb).pid, nlh->nlmsg_seq,
-                                 IPCTNL_MSG_CT_NEW, ct);
+                                 NFNL_MSG_TYPE(nlh->nlmsg_type), ct);
        rcu_read_unlock();
        nf_ct_put(ct);
        if (err <= 0)
@@ -1087,14 +1096,14 @@ ctnetlink_change_nat(struct nf_conn *ct, const struct nlattr * const cda[])
 
        if (cda[CTA_NAT_DST]) {
                ret = ctnetlink_parse_nat_setup(ct,
-                                               IP_NAT_MANIP_DST,
+                                               NF_NAT_MANIP_DST,
                                                cda[CTA_NAT_DST]);
                if (ret < 0)
                        return ret;
        }
        if (cda[CTA_NAT_SRC]) {
                ret = ctnetlink_parse_nat_setup(ct,
-                                               IP_NAT_MANIP_SRC,
+                                               NF_NAT_MANIP_SRC,
                                                cda[CTA_NAT_SRC]);
                if (ret < 0)
                        return ret;
@@ -1163,7 +1172,7 @@ ctnetlink_change_helper(struct nf_conn *ct, const struct nlattr * const cda[])
                return -EOPNOTSUPP;
        }
 
-       RCU_INIT_POINTER(help->helper, helper);
+       rcu_assign_pointer(help->helper, helper);
 
        return 0;
 }
@@ -1358,12 +1367,15 @@ ctnetlink_create_conntrack(struct net *net, u16 zone,
                                                    nf_ct_protonum(ct));
                if (helper == NULL) {
                        rcu_read_unlock();
+                       spin_unlock_bh(&nf_conntrack_lock);
 #ifdef CONFIG_MODULES
                        if (request_module("nfct-helper-%s", helpname) < 0) {
+                               spin_lock_bh(&nf_conntrack_lock);
                                err = -EOPNOTSUPP;
                                goto err1;
                        }
 
+                       spin_lock_bh(&nf_conntrack_lock);
                        rcu_read_lock();
                        helper = __nf_conntrack_helper_find(helpname,
                                                            nf_ct_l3num(ct),
@@ -1638,7 +1650,7 @@ ctnetlink_exp_dump_expect(struct sk_buff *skb,
                          const struct nf_conntrack_expect *exp)
 {
        struct nf_conn *master = exp->master;
-       long timeout = (exp->timeout.expires - jiffies) / HZ;
+       long timeout = ((long)exp->timeout.expires - (long)jiffies) / HZ;
        struct nf_conn_help *help;
 
        if (timeout < 0)
@@ -1847,7 +1859,9 @@ ctnetlink_get_expect(struct sock *ctnl, struct sk_buff *skb,
        if (err < 0)
                return err;
 
-       if (cda[CTA_EXPECT_MASTER])
+       if (cda[CTA_EXPECT_TUPLE])
+               err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_TUPLE, u3);
+       else if (cda[CTA_EXPECT_MASTER])
                err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_MASTER, u3);
        else
                return -EINVAL;
@@ -1869,25 +1883,30 @@ ctnetlink_get_expect(struct sock *ctnl, struct sk_buff *skb,
 
        err = -ENOMEM;
        skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
-       if (skb2 == NULL)
+       if (skb2 == NULL) {
+               nf_ct_expect_put(exp);
                goto out;
+       }
 
        rcu_read_lock();
        err = ctnetlink_exp_fill_info(skb2, NETLINK_CB(skb).pid,
                                      nlh->nlmsg_seq, IPCTNL_MSG_EXP_NEW, exp);
        rcu_read_unlock();
+       nf_ct_expect_put(exp);
        if (err <= 0)
                goto free;
 
-       nf_ct_expect_put(exp);
+       err = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
+       if (err < 0)
+               goto out;
 
-       return netlink_unicast(ctnl, skb2, NETLINK_CB(skb).pid, MSG_DONTWAIT);
+       return 0;
 
 free:
        kfree_skb(skb2);
 out:
-       nf_ct_expect_put(exp);
-       return err;
+       /* this avoids a loop in nfnetlink. */
+       return err == -EAGAIN ? -ENOBUFS : err;
 }
 
 static int
@@ -2023,6 +2042,10 @@ ctnetlink_create_expect(struct net *net, u16 zone,
        }
        help = nfct_help(ct);
        if (!help) {
+               err = -EOPNOTSUPP;
+               goto out;
+       }
+       if (test_bit(IPS_USERSPACE_HELPER_BIT, &ct->status)) {
                if (!cda[CTA_EXPECT_TIMEOUT]) {
                        err = -EINVAL;
                        goto out;
@@ -2247,7 +2270,6 @@ static void __exit ctnetlink_exit(void)
 {
        pr_info("ctnetlink: unregistering from nfnetlink.\n");
 
-       nf_ct_remove_userspace_expectations();
        unregister_pernet_subsys(&ctnetlink_net_ops);
        nfnetlink_subsys_unregister(&ctnl_exp_subsys);
        nfnetlink_subsys_unregister(&ctnl_subsys);