Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next
[linux-2.6-block.git] / net / ipv6 / udp.c
index 56030d45823aa0949ab51ff56c28e60849e07f12..e2ecfb137297b931f05eb4d511f8e85cc7633336 100644 (file)
@@ -129,7 +129,7 @@ static void udp_v6_rehash(struct sock *sk)
 static int compute_score(struct sock *sk, struct net *net,
                         const struct in6_addr *saddr, __be16 sport,
                         const struct in6_addr *daddr, unsigned short hnum,
-                        int dif, bool exact_dif)
+                        int dif, int sdif, bool exact_dif)
 {
        int score;
        struct inet_sock *inet;
@@ -161,9 +161,13 @@ static int compute_score(struct sock *sk, struct net *net,
        }
 
        if (sk->sk_bound_dev_if || exact_dif) {
-               if (sk->sk_bound_dev_if != dif)
+               bool dev_match = (sk->sk_bound_dev_if == dif ||
+                                 sk->sk_bound_dev_if == sdif);
+
+               if (exact_dif && !dev_match)
                        return -1;
-               score++;
+               if (sk->sk_bound_dev_if && dev_match)
+                       score++;
        }
 
        if (sk->sk_incoming_cpu == raw_smp_processor_id())
@@ -175,9 +179,9 @@ static int compute_score(struct sock *sk, struct net *net,
 /* called with rcu_read_lock() */
 static struct sock *udp6_lib_lookup2(struct net *net,
                const struct in6_addr *saddr, __be16 sport,
-               const struct in6_addr *daddr, unsigned int hnum, int dif,
-               bool exact_dif, struct udp_hslot *hslot2,
-               struct sk_buff *skb)
+               const struct in6_addr *daddr, unsigned int hnum,
+               int dif, int sdif, bool exact_dif,
+               struct udp_hslot *hslot2, struct sk_buff *skb)
 {
        struct sock *sk, *result;
        int score, badness, matches = 0, reuseport = 0;
@@ -187,7 +191,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
        badness = -1;
        udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
                score = compute_score(sk, net, saddr, sport,
-                                     daddr, hnum, dif, exact_dif);
+                                     daddr, hnum, dif, sdif, exact_dif);
                if (score > badness) {
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
@@ -214,10 +218,10 @@ static struct sock *udp6_lib_lookup2(struct net *net,
 
 /* rcu_read_lock() must be held */
 struct sock *__udp6_lib_lookup(struct net *net,
-                                     const struct in6_addr *saddr, __be16 sport,
-                                     const struct in6_addr *daddr, __be16 dport,
-                                     int dif, struct udp_table *udptable,
-                                     struct sk_buff *skb)
+                              const struct in6_addr *saddr, __be16 sport,
+                              const struct in6_addr *daddr, __be16 dport,
+                              int dif, int sdif, struct udp_table *udptable,
+                              struct sk_buff *skb)
 {
        struct sock *sk, *result;
        unsigned short hnum = ntohs(dport);
@@ -235,7 +239,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
                        goto begin;
 
                result = udp6_lib_lookup2(net, saddr, sport,
-                                         daddr, hnum, dif, exact_dif,
+                                         daddr, hnum, dif, sdif, exact_dif,
                                          hslot2, skb);
                if (!result) {
                        unsigned int old_slot2 = slot2;
@@ -250,7 +254,7 @@ struct sock *__udp6_lib_lookup(struct net *net,
                                goto begin;
 
                        result = udp6_lib_lookup2(net, saddr, sport,
-                                                 daddr, hnum, dif,
+                                                 daddr, hnum, dif, sdif,
                                                  exact_dif, hslot2,
                                                  skb);
                }
@@ -261,7 +265,7 @@ begin:
        badness = -1;
        sk_for_each_rcu(sk, &hslot->head) {
                score = compute_score(sk, net, saddr, sport, daddr, hnum, dif,
-                                     exact_dif);
+                                     sdif, exact_dif);
                if (score > badness) {
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
@@ -294,7 +298,7 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
 
        return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
                                 &iph->daddr, dport, inet6_iif(skb),
-                                udptable, skb);
+                                inet6_sdif(skb), udptable, skb);
 }
 
 struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
@@ -304,7 +308,7 @@ struct sock *udp6_lib_lookup_skb(struct sk_buff *skb,
 
        return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
                                 &iph->daddr, dport, inet6_iif(skb),
-                                &udp_table, skb);
+                                inet6_sdif(skb), &udp_table, skb);
 }
 EXPORT_SYMBOL_GPL(udp6_lib_lookup_skb);
 
@@ -320,7 +324,7 @@ struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be
        struct sock *sk;
 
        sk =  __udp6_lib_lookup(net, saddr, sport, daddr, dport,
-                               dif, &udp_table, NULL);
+                               dif, 0, &udp_table, NULL);
        if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
                sk = NULL;
        return sk;
@@ -502,7 +506,7 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
        struct net *net = dev_net(skb->dev);
 
        sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
-                              inet6_iif(skb), udptable, skb);
+                              inet6_iif(skb), 0, udptable, skb);
        if (!sk) {
                __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
                                  ICMP6_MIB_INERRORS);
@@ -902,7 +906,7 @@ discard:
 static struct sock *__udp6_lib_demux_lookup(struct net *net,
                        __be16 loc_port, const struct in6_addr *loc_addr,
                        __be16 rmt_port, const struct in6_addr *rmt_addr,
-                       int dif)
+                       int dif, int sdif)
 {
        unsigned short hnum = ntohs(loc_port);
        unsigned int hash2 = udp6_portaddr_hash(net, loc_addr, hnum);
@@ -913,7 +917,7 @@ static struct sock *__udp6_lib_demux_lookup(struct net *net,
 
        udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
                if (sk->sk_state == TCP_ESTABLISHED &&
-                   INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif))
+                   INET6_MATCH(sk, net, rmt_addr, loc_addr, ports, dif, sdif))
                        return sk;
                /* Only check first socket in chain */
                break;
@@ -928,6 +932,7 @@ static void udp_v6_early_demux(struct sk_buff *skb)
        struct sock *sk;
        struct dst_entry *dst;
        int dif = skb->dev->ifindex;
+       int sdif = inet6_sdif(skb);
 
        if (!pskb_may_pull(skb, skb_transport_offset(skb) +
            sizeof(struct udphdr)))
@@ -939,7 +944,7 @@ static void udp_v6_early_demux(struct sk_buff *skb)
                sk = __udp6_lib_demux_lookup(net, uh->dest,
                                             &ipv6_hdr(skb)->daddr,
                                             uh->source, &ipv6_hdr(skb)->saddr,
-                                            dif);
+                                            dif, sdif);
        else
                return;
 
@@ -1475,6 +1480,9 @@ int compat_udpv6_getsockopt(struct sock *sk, int level, int optname,
 }
 #endif
 
+/* thinking of making this const? Don't.
+ * early_demux can change based on sysctl.
+ */
 static struct inet6_protocol udpv6_protocol = {
        .early_demux    =       udp_v6_early_demux,
        .early_demux_handler =  udp_v6_early_demux,