net: make sock diag per-namespace
[linux-2.6-block.git] / net / ipv4 / inet_diag.c
index 38064a285cca9dabaad6164ecb96b880c72241ab..570e61f9611fe9f62bf3513afc9a5abb365544b7 100644 (file)
@@ -272,16 +272,17 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_s
        int err;
        struct sock *sk;
        struct sk_buff *rep;
+       struct net *net = sock_net(in_skb->sk);
 
        err = -EINVAL;
        if (req->sdiag_family == AF_INET) {
-               sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0],
+               sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],
                                 req->id.idiag_dport, req->id.idiag_src[0],
                                 req->id.idiag_sport, req->id.idiag_if);
        }
 #if IS_ENABLED(CONFIG_IPV6)
        else if (req->sdiag_family == AF_INET6) {
-               sk = inet6_lookup(&init_net, hashinfo,
+               sk = inet6_lookup(net, hashinfo,
                                  (struct in6_addr *)req->id.idiag_dst,
                                  req->id.idiag_dport,
                                  (struct in6_addr *)req->id.idiag_src,
@@ -317,7 +318,7 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_s
                nlmsg_free(rep);
                goto out;
        }
-       err = netlink_unicast(sock_diag_nlsk, rep, NETLINK_CB(in_skb).pid,
+       err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).pid,
                              MSG_DONTWAIT);
        if (err > 0)
                err = 0;
@@ -724,6 +725,7 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 {
        int i, num;
        int s_i, s_num;
+       struct net *net = sock_net(skb->sk);
 
        s_i = cb->args[1];
        s_num = num = cb->args[2];
@@ -743,6 +745,9 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
                        sk_nulls_for_each(sk, node, &ilb->head) {
                                struct inet_sock *inet = inet_sk(sk);
 
+                               if (!net_eq(sock_net(sk), net))
+                                       continue;
+
                                if (num < s_num) {
                                        num++;
                                        continue;
@@ -813,6 +818,8 @@ skip_listen_ht:
                sk_nulls_for_each(sk, node, &head->chain) {
                        struct inet_sock *inet = inet_sk(sk);
 
+                       if (!net_eq(sock_net(sk), net))
+                               continue;
                        if (num < s_num)
                                goto next_normal;
                        if (!(r->idiag_states & (1 << sk->sk_state)))
@@ -839,6 +846,8 @@ next_normal:
 
                        inet_twsk_for_each(tw, node,
                                    &head->twchain) {
+                               if (!net_eq(twsk_net(tw), net))
+                                       continue;
 
                                if (num < s_num)
                                        goto next_dying;
@@ -943,6 +952,7 @@ static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
 static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
 {
        int hdrlen = sizeof(struct inet_diag_req);
+       struct net *net = sock_net(skb->sk);
 
        if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
            nlmsg_len(nlh) < hdrlen)
@@ -963,7 +973,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
                        struct netlink_dump_control c = {
                                .dump = inet_diag_dump_compat,
                        };
-                       return netlink_dump_start(sock_diag_nlsk, skb, nlh, &c);
+                       return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
                }
        }
 
@@ -973,6 +983,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
 static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
 {
        int hdrlen = sizeof(struct inet_diag_req_v2);
+       struct net *net = sock_net(skb->sk);
 
        if (nlmsg_len(h) < hdrlen)
                return -EINVAL;
@@ -991,7 +1002,7 @@ static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)
                        struct netlink_dump_control c = {
                                .dump = inet_diag_dump,
                        };
-                       return netlink_dump_start(sock_diag_nlsk, skb, h, &c);
+                       return netlink_dump_start(net->diag_nlsk, skb, h, &c);
                }
        }