sctp: fix an use-after-free issue in sctp_sock_dump
[linux-block.git] / net / sctp / socket.c
index 1b00a1e09b93e4106a38b4f6d45df7175e27598b..d4730ada7f3233367be7a0e3bb10e286a25602c8 100644 (file)
@@ -4658,29 +4658,39 @@ int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *),
 EXPORT_SYMBOL_GPL(sctp_transport_lookup_process);
 
 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *),
-                           struct net *net, int pos, void *p) {
+                           int (*cb_done)(struct sctp_transport *, void *),
+                           struct net *net, int *pos, void *p) {
        struct rhashtable_iter hti;
-       void *obj;
-       int err;
-
-       err = sctp_transport_walk_start(&hti);
-       if (err)
-               return err;
+       struct sctp_transport *tsp;
+       int ret;
 
-       obj = sctp_transport_get_idx(net, &hti, pos + 1);
-       for (; !IS_ERR_OR_NULL(obj); obj = sctp_transport_get_next(net, &hti)) {
-               struct sctp_transport *transport = obj;
+again:
+       ret = sctp_transport_walk_start(&hti);
+       if (ret)
+               return ret;
 
-               if (!sctp_transport_hold(transport))
+       tsp = sctp_transport_get_idx(net, &hti, *pos + 1);
+       for (; !IS_ERR_OR_NULL(tsp); tsp = sctp_transport_get_next(net, &hti)) {
+               if (!sctp_transport_hold(tsp))
                        continue;
-               err = cb(transport, p);
-               sctp_transport_put(transport);
-               if (err)
+               ret = cb(tsp, p);
+               if (ret)
                        break;
+               (*pos)++;
+               sctp_transport_put(tsp);
        }
        sctp_transport_walk_stop(&hti);
 
-       return err;
+       if (ret) {
+               if (cb_done && !cb_done(tsp, p)) {
+                       (*pos)++;
+                       sctp_transport_put(tsp);
+                       goto again;
+               }
+               sctp_transport_put(tsp);
+       }
+
+       return ret;
 }
 EXPORT_SYMBOL_GPL(sctp_for_each_transport);