rds: tcp: use rds_destroy_pending() to synchronize netns/module teardown and rds...
authorSowmini Varadhan <sowmini.varadhan@oracle.com>
Sat, 3 Feb 2018 12:26:51 +0000 (04:26 -0800)
committerDavid S. Miller <davem@davemloft.net>
Thu, 8 Feb 2018 20:23:52 +0000 (15:23 -0500)
An rds_connection can get added during netns deletion between lines 528
and 529 of

  506 static void rds_tcp_kill_sock(struct net *net)
  :
  /* code to pull out all the rds_connections that should be destroyed */
  :
  528         spin_unlock_irq(&rds_tcp_conn_lock);
  529         list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
  530                 rds_conn_destroy(tc->t_cpath->cp_conn);

Such an rds_connection would miss out the rds_conn_destroy()
loop (that cancels all pending work) and (if it was scheduled
after netns deletion) could trigger the use-after-free.

A similar race-window exists for the module unload path
in rds_tcp_exit -> rds_tcp_destroy_conns

Concurrency with netns deletion (rds_tcp_kill_sock()) must be handled
by checking check_net() before enqueuing new work or adding new
connections.

Concurrency with module-unload is handled by maintaining a module
specific flag that is set at the start of the module exit function,
and must be checked before enqueuing new work or adding new connections.

This commit refactors existing RDS_DESTROY_PENDING checks added by
commit 3db6e0d172c9 ("rds: use RCU to synchronize work-enqueue with
connection teardown") and consolidates all the concurrency checks
listed above into the function rds_destroy_pending().

Signed-off-by: Sowmini Varadhan <sowmini.varadhan@oracle.com>
Acked-by: Santosh Shilimkar <santosh.shilimkar@oracle.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/rds/cong.c
net/rds/connection.c
net/rds/ib.c
net/rds/ib_cm.c
net/rds/rds.h
net/rds/send.c
net/rds/tcp.c
net/rds/tcp_connect.c
net/rds/tcp_recv.c
net/rds/tcp_send.c
net/rds/threads.c

index 8d19fd25dce36db3eb7c5b4252da325cc0b74913..63da9d2f142d19016e1094f7f4136781da8d6037 100644 (file)
@@ -223,7 +223,7 @@ void rds_cong_queue_updates(struct rds_cong_map *map)
 
                rcu_read_lock();
                if (!test_and_set_bit(0, &conn->c_map_queued) &&
-                   !test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+                   !rds_destroy_pending(cp->cp_conn)) {
                        rds_stats_inc(s_cong_update_queued);
                        /* We cannot inline the call to rds_send_xmit() here
                         * for two reasons (both pertaining to a TCP transport):
index b10c0ef36d8d458d808d054670dd2833b0e6a322..94e190febfddd0670c0a16b35501acda16df548d 100644 (file)
@@ -220,8 +220,13 @@ static struct rds_connection *__rds_conn_create(struct net *net,
                                     is_outgoing);
                conn->c_path[i].cp_index = i;
        }
-       ret = trans->conn_alloc(conn, gfp);
+       rcu_read_lock();
+       if (rds_destroy_pending(conn))
+               ret = -ENETDOWN;
+       else
+               ret = trans->conn_alloc(conn, gfp);
        if (ret) {
+               rcu_read_unlock();
                kfree(conn->c_path);
                kmem_cache_free(rds_conn_slab, conn);
                conn = ERR_PTR(ret);
@@ -283,6 +288,7 @@ static struct rds_connection *__rds_conn_create(struct net *net,
                }
        }
        spin_unlock_irqrestore(&rds_conn_lock, flags);
+       rcu_read_unlock();
 
 out:
        return conn;
@@ -382,13 +388,10 @@ static void rds_conn_path_destroy(struct rds_conn_path *cp)
 {
        struct rds_message *rm, *rtmp;
 
-       set_bit(RDS_DESTROY_PENDING, &cp->cp_flags);
-
        if (!cp->cp_transport_data)
                return;
 
        /* make sure lingering queued work won't try to ref the conn */
-       synchronize_rcu();
        cancel_delayed_work_sync(&cp->cp_send_w);
        cancel_delayed_work_sync(&cp->cp_recv_w);
 
@@ -691,7 +694,7 @@ void rds_conn_path_drop(struct rds_conn_path *cp, bool destroy)
        atomic_set(&cp->cp_state, RDS_CONN_ERROR);
 
        rcu_read_lock();
-       if (!destroy && test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+       if (!destroy && rds_destroy_pending(cp->cp_conn)) {
                rcu_read_unlock();
                return;
        }
@@ -714,7 +717,7 @@ EXPORT_SYMBOL_GPL(rds_conn_drop);
 void rds_conn_path_connect_if_down(struct rds_conn_path *cp)
 {
        rcu_read_lock();
-       if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+       if (rds_destroy_pending(cp->cp_conn)) {
                rcu_read_unlock();
                return;
        }
index ff0c98096af1c62817bb0f5c2aa757d9288f93f6..50a88f3e7e393401db0b143982571f6ad50cd999 100644 (file)
@@ -48,6 +48,7 @@
 static unsigned int rds_ib_mr_1m_pool_size = RDS_MR_1M_POOL_SIZE;
 static unsigned int rds_ib_mr_8k_pool_size = RDS_MR_8K_POOL_SIZE;
 unsigned int rds_ib_retry_count = RDS_IB_DEFAULT_RETRY_COUNT;
+static atomic_t rds_ib_unloading;
 
 module_param(rds_ib_mr_1m_pool_size, int, 0444);
 MODULE_PARM_DESC(rds_ib_mr_1m_pool_size, " Max number of 1M mr per HCA");
@@ -378,8 +379,23 @@ static void rds_ib_unregister_client(void)
        flush_workqueue(rds_wq);
 }
 
+static void rds_ib_set_unloading(void)
+{
+       atomic_set(&rds_ib_unloading, 1);
+}
+
+static bool rds_ib_is_unloading(struct rds_connection *conn)
+{
+       struct rds_conn_path *cp = &conn->c_path[0];
+
+       return (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags) ||
+               atomic_read(&rds_ib_unloading) != 0);
+}
+
 void rds_ib_exit(void)
 {
+       rds_ib_set_unloading();
+       synchronize_rcu();
        rds_info_deregister_func(RDS_INFO_IB_CONNECTIONS, rds_ib_ic_info);
        rds_ib_unregister_client();
        rds_ib_destroy_nodev_conns();
@@ -413,6 +429,7 @@ struct rds_transport rds_ib_transport = {
        .flush_mrs              = rds_ib_flush_mrs,
        .t_owner                = THIS_MODULE,
        .t_name                 = "infiniband",
+       .t_unloading            = rds_ib_is_unloading,
        .t_type                 = RDS_TRANS_IB
 };
 
index 80fb6f63e768d3461c47533615c875526bb8bab9..eea1d8611b205d771c04cdb12c7c35dc2db403ff 100644 (file)
@@ -117,6 +117,7 @@ void rds_ib_cm_connect_complete(struct rds_connection *conn, struct rdma_cm_even
                          &conn->c_laddr, &conn->c_faddr,
                          RDS_PROTOCOL_MAJOR(conn->c_version),
                          RDS_PROTOCOL_MINOR(conn->c_version));
+               set_bit(RDS_DESTROY_PENDING, &conn->c_path[0].cp_flags);
                rds_conn_destroy(conn);
                return;
        } else {
index 374ae83b60d48e5b7974aa6c5990daa6e353cd57..7301b9b01890ed170114550c28e8b0433cf3da7b 100644 (file)
@@ -518,6 +518,7 @@ struct rds_transport {
        void (*sync_mr)(void *trans_private, int direction);
        void (*free_mr)(void *trans_private, int invalidate);
        void (*flush_mrs)(void);
+       bool (*t_unloading)(struct rds_connection *conn);
 };
 
 struct rds_sock {
@@ -862,6 +863,12 @@ static inline void rds_mr_put(struct rds_mr *mr)
                __rds_put_mr_final(mr);
 }
 
+static inline bool rds_destroy_pending(struct rds_connection *conn)
+{
+       return !check_net(rds_conn_net(conn)) ||
+              (conn->c_trans->t_unloading && conn->c_trans->t_unloading(conn));
+}
+
 /* stats.c */
 DECLARE_PER_CPU_SHARED_ALIGNED(struct rds_statistics, rds_stats);
 #define rds_stats_inc_which(which, member) do {                \
index d3e32d1f3c7d61fe910e2dfa59016e7456d4252c..b1b0022b8370bf497cf4255412c33a5b033467c9 100644 (file)
@@ -162,7 +162,7 @@ restart:
                goto out;
        }
 
-       if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+       if (rds_destroy_pending(cp->cp_conn)) {
                release_in_xmit(cp);
                ret = -ENETUNREACH; /* dont requeue send work */
                goto out;
@@ -444,7 +444,7 @@ over_batch:
                        if (batch_count < send_batch_count)
                                goto restart;
                        rcu_read_lock();
-                       if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+                       if (rds_destroy_pending(cp->cp_conn))
                                ret = -ENETUNREACH;
                        else
                                queue_delayed_work(rds_wq, &cp->cp_send_w, 1);
@@ -1162,7 +1162,7 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len)
        else
                cpath = &conn->c_path[0];
 
-       if (test_bit(RDS_DESTROY_PENDING, &cpath->cp_flags)) {
+       if (rds_destroy_pending(conn)) {
                ret = -EAGAIN;
                goto out;
        }
@@ -1209,7 +1209,7 @@ int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len)
        if (ret == -ENOMEM || ret == -EAGAIN) {
                ret = 0;
                rcu_read_lock();
-               if (test_bit(RDS_DESTROY_PENDING, &cpath->cp_flags))
+               if (rds_destroy_pending(cpath->cp_conn))
                        ret = -ENETUNREACH;
                else
                        queue_delayed_work(rds_wq, &cpath->cp_send_w, 1);
@@ -1295,7 +1295,7 @@ rds_send_probe(struct rds_conn_path *cp, __be16 sport,
 
        /* schedule the send work on rds_wq */
        rcu_read_lock();
-       if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+       if (!rds_destroy_pending(cp->cp_conn))
                queue_delayed_work(rds_wq, &cp->cp_send_w, 1);
        rcu_read_unlock();
 
index 9920d2f84eff8e358702ec69f010f0394b0f63fa..44c4652721af23a0f2d6a3486495310755515350 100644 (file)
@@ -49,6 +49,7 @@ static unsigned int rds_tcp_tc_count;
 /* Track rds_tcp_connection structs so they can be cleaned up */
 static DEFINE_SPINLOCK(rds_tcp_conn_lock);
 static LIST_HEAD(rds_tcp_conn_list);
+static atomic_t rds_tcp_unloading = ATOMIC_INIT(0);
 
 static struct kmem_cache *rds_tcp_conn_slab;
 
@@ -274,14 +275,13 @@ static int rds_tcp_laddr_check(struct net *net, __be32 addr)
 static void rds_tcp_conn_free(void *arg)
 {
        struct rds_tcp_connection *tc = arg;
-       unsigned long flags;
 
        rdsdebug("freeing tc %p\n", tc);
 
-       spin_lock_irqsave(&rds_tcp_conn_lock, flags);
+       spin_lock_bh(&rds_tcp_conn_lock);
        if (!tc->t_tcp_node_detached)
                list_del(&tc->t_tcp_node);
-       spin_unlock_irqrestore(&rds_tcp_conn_lock, flags);
+       spin_unlock_bh(&rds_tcp_conn_lock);
 
        kmem_cache_free(rds_tcp_conn_slab, tc);
 }
@@ -296,7 +296,7 @@ static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp)
                tc = kmem_cache_alloc(rds_tcp_conn_slab, gfp);
                if (!tc) {
                        ret = -ENOMEM;
-                       break;
+                       goto fail;
                }
                mutex_init(&tc->t_conn_path_lock);
                tc->t_sock = NULL;
@@ -306,14 +306,19 @@ static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp)
 
                conn->c_path[i].cp_transport_data = tc;
                tc->t_cpath = &conn->c_path[i];
+               tc->t_tcp_node_detached = true;
 
-               spin_lock_irq(&rds_tcp_conn_lock);
-               tc->t_tcp_node_detached = false;
-               list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list);
-               spin_unlock_irq(&rds_tcp_conn_lock);
                rdsdebug("rds_conn_path [%d] tc %p\n", i,
                         conn->c_path[i].cp_transport_data);
        }
+       spin_lock_bh(&rds_tcp_conn_lock);
+       for (i = 0; i < RDS_MPATH_WORKERS; i++) {
+               tc = conn->c_path[i].cp_transport_data;
+               tc->t_tcp_node_detached = false;
+               list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list);
+       }
+       spin_unlock_bh(&rds_tcp_conn_lock);
+fail:
        if (ret) {
                for (j = 0; j < i; j++)
                        rds_tcp_conn_free(conn->c_path[j].cp_transport_data);
@@ -332,6 +337,16 @@ static bool list_has_conn(struct list_head *list, struct rds_connection *conn)
        return false;
 }
 
+static void rds_tcp_set_unloading(void)
+{
+       atomic_set(&rds_tcp_unloading, 1);
+}
+
+static bool rds_tcp_is_unloading(struct rds_connection *conn)
+{
+       return atomic_read(&rds_tcp_unloading) != 0;
+}
+
 static void rds_tcp_destroy_conns(void)
 {
        struct rds_tcp_connection *tc, *_tc;
@@ -370,6 +385,7 @@ struct rds_transport rds_tcp_transport = {
        .t_type                 = RDS_TRANS_TCP,
        .t_prefer_loopback      = 1,
        .t_mp_capable           = 1,
+       .t_unloading            = rds_tcp_is_unloading,
 };
 
 static unsigned int rds_tcp_netid;
@@ -513,7 +529,7 @@ static void rds_tcp_kill_sock(struct net *net)
 
        rtn->rds_tcp_listen_sock = NULL;
        rds_tcp_listen_stop(lsock, &rtn->rds_tcp_accept_w);
-       spin_lock_irq(&rds_tcp_conn_lock);
+       spin_lock_bh(&rds_tcp_conn_lock);
        list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
                struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net);
 
@@ -526,7 +542,7 @@ static void rds_tcp_kill_sock(struct net *net)
                        tc->t_tcp_node_detached = true;
                }
        }
-       spin_unlock_irq(&rds_tcp_conn_lock);
+       spin_unlock_bh(&rds_tcp_conn_lock);
        list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node)
                rds_conn_destroy(tc->t_cpath->cp_conn);
 }
@@ -574,7 +590,7 @@ static void rds_tcp_sysctl_reset(struct net *net)
 {
        struct rds_tcp_connection *tc, *_tc;
 
-       spin_lock_irq(&rds_tcp_conn_lock);
+       spin_lock_bh(&rds_tcp_conn_lock);
        list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) {
                struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net);
 
@@ -584,7 +600,7 @@ static void rds_tcp_sysctl_reset(struct net *net)
                /* reconnect with new parameters */
                rds_conn_path_drop(tc->t_cpath, false);
        }
-       spin_unlock_irq(&rds_tcp_conn_lock);
+       spin_unlock_bh(&rds_tcp_conn_lock);
 }
 
 static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write,
@@ -607,6 +623,8 @@ static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write,
 
 static void rds_tcp_exit(void)
 {
+       rds_tcp_set_unloading();
+       synchronize_rcu();
        rds_info_deregister_func(RDS_INFO_TCP_SOCKETS, rds_tcp_tc_info);
        unregister_pernet_subsys(&rds_tcp_net_ops);
        if (unregister_netdevice_notifier(&rds_tcp_dev_notifier))
index 534c67aeb20f8a9babf3e133e18db5af14be4867..d999e707564579f0b81a8667946b348566b2695b 100644 (file)
@@ -170,7 +170,7 @@ void rds_tcp_conn_path_shutdown(struct rds_conn_path *cp)
                 cp->cp_conn, tc, sock);
 
        if (sock) {
-               if (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+               if (rds_destroy_pending(cp->cp_conn))
                        rds_tcp_set_linger(sock);
                sock->ops->shutdown(sock, RCV_SHUTDOWN | SEND_SHUTDOWN);
                lock_sock(sock->sk);
index dd707b9e73e57dae52e8568b60d3e3f7000765af..b9fbd2ee74efe1c4f75cb499f00ce92f8be5a331 100644 (file)
@@ -323,7 +323,7 @@ void rds_tcp_data_ready(struct sock *sk)
 
        if (rds_tcp_read_sock(cp, GFP_ATOMIC) == -ENOMEM) {
                rcu_read_lock();
-               if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+               if (!rds_destroy_pending(cp->cp_conn))
                        queue_delayed_work(rds_wq, &cp->cp_recv_w, 0);
                rcu_read_unlock();
        }
index 16f65744d9844b995ffe6323bd1bf9fe3d2932d8..7df869d37afd4c27e519b227b57bb306cf30ef35 100644 (file)
@@ -204,7 +204,7 @@ void rds_tcp_write_space(struct sock *sk)
 
        rcu_read_lock();
        if ((refcount_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf &&
-           !test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+           !rds_destroy_pending(cp->cp_conn))
                queue_delayed_work(rds_wq, &cp->cp_send_w, 0);
        rcu_read_unlock();
 
index eb76db1360b00b3fe3e7ff96f588c3b1a9cac46d..c52861d77a596ca49ad40c3c7c2efe902bc135f4 100644 (file)
@@ -88,7 +88,7 @@ void rds_connect_path_complete(struct rds_conn_path *cp, int curr)
        cp->cp_reconnect_jiffies = 0;
        set_bit(0, &cp->cp_conn->c_map_queued);
        rcu_read_lock();
-       if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags)) {
+       if (!rds_destroy_pending(cp->cp_conn)) {
                queue_delayed_work(rds_wq, &cp->cp_send_w, 0);
                queue_delayed_work(rds_wq, &cp->cp_recv_w, 0);
        }
@@ -138,7 +138,7 @@ void rds_queue_reconnect(struct rds_conn_path *cp)
        if (cp->cp_reconnect_jiffies == 0) {
                cp->cp_reconnect_jiffies = rds_sysctl_reconnect_min_jiffies;
                rcu_read_lock();
-               if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+               if (!rds_destroy_pending(cp->cp_conn))
                        queue_delayed_work(rds_wq, &cp->cp_conn_w, 0);
                rcu_read_unlock();
                return;
@@ -149,7 +149,7 @@ void rds_queue_reconnect(struct rds_conn_path *cp)
                 rand % cp->cp_reconnect_jiffies, cp->cp_reconnect_jiffies,
                 conn, &conn->c_laddr, &conn->c_faddr);
        rcu_read_lock();
-       if (!test_bit(RDS_DESTROY_PENDING, &cp->cp_flags))
+       if (!rds_destroy_pending(cp->cp_conn))
                queue_delayed_work(rds_wq, &cp->cp_conn_w,
                                   rand % cp->cp_reconnect_jiffies);
        rcu_read_unlock();