ipv6: call dst_hold_safe() properly
[linux-block.git] / net / ipv6 / route.c
index 524a76b5206e2e5742aab554010d13b9edb4cc9e..c52c5190888186c6d7c1937583ecd4b4ce8ce2c4 100644 (file)
@@ -354,7 +354,7 @@ static struct rt6_info *__ip6_dst_alloc(struct net *net,
                                        int flags)
 {
        struct rt6_info *rt = dst_alloc(&net->ipv6.ip6_dst_ops, dev,
-                                       0, DST_OBSOLETE_FORCE_CHK, flags);
+                                       1, DST_OBSOLETE_FORCE_CHK, flags);
 
        if (rt)
                rt6_info_init(rt);
@@ -381,7 +381,9 @@ struct rt6_info *ip6_dst_alloc(struct net *net,
                                *p =  NULL;
                        }
                } else {
-                       dst_destroy((struct dst_entry *)rt);
+                       dst_release(&rt->dst);
+                       if (!(flags & DST_NOCACHE))
+                               dst_destroy((struct dst_entry *)rt);
                        return NULL;
                }
        }
@@ -932,9 +934,9 @@ struct rt6_info *rt6_lookup(struct net *net, const struct in6_addr *daddr,
 EXPORT_SYMBOL(rt6_lookup);
 
 /* ip6_ins_rt is called with FREE table->tb6_lock.
  It takes new route entry, the addition fails by any reason the
-   route is freed. In any case, if caller does not hold it, it may
  be destroyed.
* It takes new route entry, the addition fails by any reason the
+ * route is released.
* Caller must hold dst before calling it.
  */
 
 static int __ip6_ins_rt(struct rt6_info *rt, struct nl_info *info,
@@ -957,6 +959,8 @@ int ip6_ins_rt(struct rt6_info *rt)
        struct nl_info info = { .nl_net = dev_net(rt->dst.dev), };
        struct mx6_config mxc = { .mx = NULL, };
 
+       /* Hold dst to account for the reference from the fib6 tree */
+       dst_hold(&rt->dst);
        return __ip6_ins_rt(rt, &info, &mxc, NULL);
 }
 
@@ -1049,6 +1053,7 @@ static struct rt6_info *rt6_make_pcpu_route(struct rt6_info *rt)
                prev = cmpxchg(p, NULL, pcpu_rt);
                if (prev) {
                        /* If someone did it before us, return prev instead */
+                       dst_release(&pcpu_rt->dst);
                        dst_destroy(&pcpu_rt->dst);
                        pcpu_rt = prev;
                }
@@ -1059,6 +1064,7 @@ static struct rt6_info *rt6_make_pcpu_route(struct rt6_info *rt)
                 * since rt is going away anyway.  The next
                 * dst_check() will trigger a re-lookup.
                 */
+               dst_release(&pcpu_rt->dst);
                dst_destroy(&pcpu_rt->dst);
                pcpu_rt = rt;
        }
@@ -1129,12 +1135,15 @@ redo_rt6_select:
                uncached_rt = ip6_rt_cache_alloc(rt, &fl6->daddr, NULL);
                dst_release(&rt->dst);
 
-               if (uncached_rt)
+               if (uncached_rt) {
+                       /* Uncached_rt's refcnt is taken during ip6_rt_cache_alloc()
+                        * No need for another dst_hold()
+                        */
                        rt6_uncached_list_add(uncached_rt);
-               else
+               } else {
                        uncached_rt = net->ipv6.ip6_null_entry;
-
-               dst_hold(&uncached_rt->dst);
+                       dst_hold(&uncached_rt->dst);
+               }
 
                trace_fib6_table_lookup(net, uncached_rt, table->tb6_id, fl6);
                return uncached_rt;
@@ -1245,9 +1254,12 @@ EXPORT_SYMBOL_GPL(ip6_route_output_flags);
 struct dst_entry *ip6_blackhole_route(struct net *net, struct dst_entry *dst_orig)
 {
        struct rt6_info *rt, *ort = (struct rt6_info *) dst_orig;
+       struct net_device *loopback_dev = net->loopback_dev;
        struct dst_entry *new = NULL;
 
-       rt = dst_alloc(&ip6_dst_blackhole_ops, ort->dst.dev, 1, DST_OBSOLETE_NONE, 0);
+
+       rt = dst_alloc(&ip6_dst_blackhole_ops, loopback_dev, 1,
+                      DST_OBSOLETE_NONE, 0);
        if (rt) {
                rt6_info_init(rt);
 
@@ -1257,10 +1269,8 @@ struct dst_entry *ip6_blackhole_route(struct net *net, struct dst_entry *dst_ori
                new->output = dst_discard_out;
 
                dst_copy_metrics(new, &ort->dst);
-               rt->rt6i_idev = ort->rt6i_idev;
-               if (rt->rt6i_idev)
-                       in6_dev_hold(rt->rt6i_idev);
 
+               rt->rt6i_idev = in6_dev_get(loopback_dev);
                rt->rt6i_gateway = ort->rt6i_gateway;
                rt->rt6i_flags = ort->rt6i_flags & ~RTF_PCPU;
                rt->rt6i_metric = 0;
@@ -1356,8 +1366,8 @@ static void ip6_link_failure(struct sk_buff *skb)
        rt = (struct rt6_info *) skb_dst(skb);
        if (rt) {
                if (rt->rt6i_flags & RTF_CACHE) {
-                       dst_hold(&rt->dst);
-                       ip6_del_rt(rt);
+                       if (dst_hold_safe(&rt->dst))
+                               ip6_del_rt(rt);
                } else if (rt->rt6i_node && (rt->rt6i_flags & RTF_DEFAULT)) {
                        rt->rt6i_node->fn_sernum = -1;
                }
@@ -1421,6 +1431,10 @@ static void __ip6_rt_update_pmtu(struct dst_entry *dst, const struct sock *sk,
                         * invalidate the sk->sk_dst_cache.
                         */
                        ip6_ins_rt(nrt6);
+                       /* Release the reference taken in
+                        * ip6_rt_cache_alloc()
+                        */
+                       dst_release(&nrt6->dst);
                }
        }
 }
@@ -1672,7 +1686,6 @@ struct dst_entry *icmp6_dst_alloc(struct net_device *dev,
 
        rt->dst.flags |= DST_HOST;
        rt->dst.output  = ip6_output;
-       atomic_set(&rt->dst.__refcnt, 1);
        rt->rt6i_gateway  = fl6->daddr;
        rt->rt6i_dst.addr = fl6->daddr;
        rt->rt6i_dst.plen = 128;
@@ -1939,7 +1952,7 @@ static struct rt6_info *ip6_route_info_create(struct fib6_config *cfg,
 
                err = lwtunnel_build_state(cfg->fc_encap_type,
                                           cfg->fc_encap, AF_INET6, cfg,
-                                          &lwtstate);
+                                          &lwtstate, extack);
                if (err)
                        goto out;
                rt->dst.lwtstate = lwtstate_get(lwtstate);
@@ -2129,8 +2142,10 @@ out:
                dev_put(dev);
        if (idev)
                in6_dev_put(idev);
-       if (rt)
+       if (rt) {
+               dst_release(&rt->dst);
                dst_free(&rt->dst);
+       }
 
        return ERR_PTR(err);
 }
@@ -2159,8 +2174,10 @@ int ip6_route_add(struct fib6_config *cfg,
 
        return err;
 out:
-       if (rt)
+       if (rt) {
+               dst_release(&rt->dst);
                dst_free(&rt->dst);
+       }
 
        return err;
 }
@@ -2397,7 +2414,7 @@ static void rt6_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_bu
        nrt->rt6i_gateway = *(struct in6_addr *)neigh->primary_key;
 
        if (ip6_ins_rt(nrt))
-               goto out;
+               goto out_release;
 
        netevent.old = &rt->dst;
        netevent.new = &nrt->dst;
@@ -2410,6 +2427,12 @@ static void rt6_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_bu
                ip6_del_rt(rt);
        }
 
+out_release:
+       /* Release the reference taken in
+        * ip6_rt_cache_alloc()
+        */
+       dst_release(&nrt->dst);
+
 out:
        neigh_release(neigh);
 }
@@ -2759,8 +2782,6 @@ struct rt6_info *addrconf_dst_alloc(struct inet6_dev *idev,
        rt->rt6i_table = fib6_get_table(net, tb_id);
        rt->dst.flags |= DST_NOCACHE;
 
-       atomic_set(&rt->dst.__refcnt, 1);
-
        return rt;
 }
 
@@ -2832,6 +2853,7 @@ static int fib6_ifdown(struct rt6_info *rt, void *arg)
        if ((rt->dst.dev == dev || !dev) &&
            rt != adn->net->ipv6.ip6_null_entry &&
            (rt->rt6i_nsiblings == 0 ||
+            (dev && netdev_unregistering(dev)) ||
             !rt->rt6i_idev->cnf.ignore_routes_with_linkdown))
                return -1;
 
@@ -3184,6 +3206,7 @@ static int ip6_route_multipath_add(struct fib6_config *cfg,
 
                err = ip6_route_info_append(&rt6_nh_list, rt, &r_cfg);
                if (err) {
+                       dst_release(&rt->dst);
                        dst_free(&rt->dst);
                        goto cleanup;
                }
@@ -3247,8 +3270,10 @@ add_errout:
 
 cleanup:
        list_for_each_entry_safe(nh, nh_safe, &rt6_nh_list, next) {
-               if (nh->rt6_info)
+               if (nh->rt6_info) {
+                       dst_release(&nh->rt6_info->dst);
                        dst_free(&nh->rt6_info->dst);
+               }
                kfree(nh->mxc.mx);
                list_del(&nh->next);
                kfree(nh);