mptcp: Avoid acquiring PM lock for subflow priority changes
[linux-block.git] / net / mptcp / subflow.c
index 8841e8cd9ad8dbde4f9853e0dd16d6ace244f09e..63e8892ec807d33a2dd0070e87902fd6a2289837 100644 (file)
@@ -843,7 +843,8 @@ enum mapping_status {
        MAPPING_INVALID,
        MAPPING_EMPTY,
        MAPPING_DATA_FIN,
-       MAPPING_DUMMY
+       MAPPING_DUMMY,
+       MAPPING_BAD_CSUM
 };
 
 static void dbg_bad_map(struct mptcp_subflow_context *subflow, u32 ssn)
@@ -958,11 +959,7 @@ static enum mapping_status validate_data_csum(struct sock *ssk, struct sk_buff *
                                 subflow->map_data_csum);
        if (unlikely(csum)) {
                MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_DATACSUMERR);
-               if (subflow->mp_join || subflow->valid_csum_seen) {
-                       subflow->send_mp_fail = 1;
-                       MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPFAILTX);
-               }
-               return subflow->mp_join ? MAPPING_INVALID : MAPPING_DUMMY;
+               return MAPPING_BAD_CSUM;
        }
 
        subflow->valid_csum_seen = 1;
@@ -974,7 +971,6 @@ static enum mapping_status get_mapping_status(struct sock *ssk,
 {
        struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
        bool csum_reqd = READ_ONCE(msk->csum_enabled);
-       struct sock *sk = (struct sock *)msk;
        struct mptcp_ext *mpext;
        struct sk_buff *skb;
        u16 data_len;
@@ -1016,9 +1012,6 @@ static enum mapping_status get_mapping_status(struct sock *ssk,
                pr_debug("infinite mapping received");
                MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_INFINITEMAPRX);
                subflow->map_data_len = 0;
-               if (!sock_flag(ssk, SOCK_DEAD))
-                       sk_stop_timer(sk, &sk->sk_timer);
-
                return MAPPING_INVALID;
        }
 
@@ -1165,6 +1158,33 @@ static bool subflow_can_fallback(struct mptcp_subflow_context *subflow)
                return !subflow->fully_established;
 }
 
+static void mptcp_subflow_fail(struct mptcp_sock *msk, struct sock *ssk)
+{
+       struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
+       unsigned long fail_tout;
+
+       /* greceful failure can happen only on the MPC subflow */
+       if (WARN_ON_ONCE(ssk != READ_ONCE(msk->first)))
+               return;
+
+       /* since the close timeout take precedence on the fail one,
+        * no need to start the latter when the first is already set
+        */
+       if (sock_flag((struct sock *)msk, SOCK_DEAD))
+               return;
+
+       /* we don't need extreme accuracy here, use a zero fail_tout as special
+        * value meaning no fail timeout at all;
+        */
+       fail_tout = jiffies + TCP_RTO_MAX;
+       if (!fail_tout)
+               fail_tout = 1;
+       WRITE_ONCE(subflow->fail_tout, fail_tout);
+       tcp_send_ack(ssk);
+
+       mptcp_reset_timeout(msk, subflow->fail_tout);
+}
+
 static bool subflow_check_data_avail(struct sock *ssk)
 {
        struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
@@ -1184,10 +1204,8 @@ static bool subflow_check_data_avail(struct sock *ssk)
 
                status = get_mapping_status(ssk, msk);
                trace_subflow_check_data_avail(status, skb_peek(&ssk->sk_receive_queue));
-               if (unlikely(status == MAPPING_INVALID))
-                       goto fallback;
-
-               if (unlikely(status == MAPPING_DUMMY))
+               if (unlikely(status == MAPPING_INVALID || status == MAPPING_DUMMY ||
+                            status == MAPPING_BAD_CSUM))
                        goto fallback;
 
                if (status != MAPPING_OK)
@@ -1229,22 +1247,17 @@ no_data:
 fallback:
        if (!__mptcp_check_fallback(msk)) {
                /* RFC 8684 section 3.7. */
-               if (subflow->send_mp_fail) {
+               if (status == MAPPING_BAD_CSUM &&
+                   (subflow->mp_join || subflow->valid_csum_seen)) {
+                       subflow->send_mp_fail = 1;
+
                        if (!READ_ONCE(msk->allow_infinite_fallback)) {
-                               ssk->sk_err = EBADMSG;
-                               tcp_set_state(ssk, TCP_CLOSE);
                                subflow->reset_transient = 0;
                                subflow->reset_reason = MPTCP_RST_EMIDDLEBOX;
-                               tcp_send_active_reset(ssk, GFP_ATOMIC);
-                               while ((skb = skb_peek(&ssk->sk_receive_queue)))
-                                       sk_eat_skb(ssk, skb);
-                       } else if (!sock_flag(ssk, SOCK_DEAD)) {
-                               WRITE_ONCE(subflow->mp_fail_response_expect, true);
-                               sk_reset_timer((struct sock *)msk,
-                                              &((struct sock *)msk)->sk_timer,
-                                              jiffies + TCP_RTO_MAX);
+                               goto reset;
                        }
-                       WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_NODATA);
+                       mptcp_subflow_fail(msk, ssk);
+                       WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_DATA_AVAIL);
                        return true;
                }
 
@@ -1252,16 +1265,20 @@ fallback:
                        /* fatal protocol error, close the socket.
                         * subflow_error_report() will introduce the appropriate barriers
                         */
-                       ssk->sk_err = EBADMSG;
-                       tcp_set_state(ssk, TCP_CLOSE);
                        subflow->reset_transient = 0;
                        subflow->reset_reason = MPTCP_RST_EMPTCP;
+
+reset:
+                       ssk->sk_err = EBADMSG;
+                       tcp_set_state(ssk, TCP_CLOSE);
+                       while ((skb = skb_peek(&ssk->sk_receive_queue)))
+                               sk_eat_skb(ssk, skb);
                        tcp_send_active_reset(ssk, GFP_ATOMIC);
                        WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_NODATA);
                        return false;
                }
 
-               __mptcp_do_fallback(msk);
+               mptcp_do_fallback(ssk);
        }
 
        skb = skb_peek(&ssk->sk_receive_queue);
@@ -1706,6 +1723,58 @@ static void subflow_state_change(struct sock *sk)
        }
 }
 
+void mptcp_subflow_queue_clean(struct sock *listener_ssk)
+{
+       struct request_sock_queue *queue = &inet_csk(listener_ssk)->icsk_accept_queue;
+       struct mptcp_sock *msk, *next, *head = NULL;
+       struct request_sock *req;
+
+       /* build a list of all unaccepted mptcp sockets */
+       spin_lock_bh(&queue->rskq_lock);
+       for (req = queue->rskq_accept_head; req; req = req->dl_next) {
+               struct mptcp_subflow_context *subflow;
+               struct sock *ssk = req->sk;
+               struct mptcp_sock *msk;
+
+               if (!sk_is_mptcp(ssk))
+                       continue;
+
+               subflow = mptcp_subflow_ctx(ssk);
+               if (!subflow || !subflow->conn)
+                       continue;
+
+               /* skip if already in list */
+               msk = mptcp_sk(subflow->conn);
+               if (msk->dl_next || msk == head)
+                       continue;
+
+               msk->dl_next = head;
+               head = msk;
+       }
+       spin_unlock_bh(&queue->rskq_lock);
+       if (!head)
+               return;
+
+       /* can't acquire the msk socket lock under the subflow one,
+        * or will cause ABBA deadlock
+        */
+       release_sock(listener_ssk);
+
+       for (msk = head; msk; msk = next) {
+               struct sock *sk = (struct sock *)msk;
+               bool slow;
+
+               slow = lock_sock_fast_nested(sk);
+               next = msk->dl_next;
+               msk->first = NULL;
+               msk->dl_next = NULL;
+               unlock_sock_fast(sk, slow);
+       }
+
+       /* we are still under the listener msk socket lock */
+       lock_sock_nested(listener_ssk, SINGLE_DEPTH_NESTING);
+}
+
 static int subflow_ulp_init(struct sock *sk)
 {
        struct inet_connection_sock *icsk = inet_csk(sk);