net: add TIME_WAIT logic to sk_to_full_sk()
authorEric Dumazet <edumazet@google.com>
Thu, 10 Oct 2024 17:48:13 +0000 (17:48 +0000)
committerJakub Kicinski <kuba@kernel.org>
Tue, 15 Oct 2024 00:39:36 +0000 (17:39 -0700)
TCP will soon attach TIME_WAIT sockets to some ACK and RST.

Make sure sk_to_full_sk() detects this and does not return
a non full socket.

v3: also changed sk_const_to_full_sk()

Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Martin KaFai Lau <martin.lau@kernel.org>
Reviewed-by: Brian Vazquez <brianvv@google.com>
Link: https://patch.msgid.link/20241010174817.1543642-2-edumazet@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/linux/bpf-cgroup.h
include/net/inet_sock.h
net/core/filter.c

index ce91d9b2acb9f8991150ceead4475b130bead438..f0f219271daf4afea2666c4d09fd4d1a8091f844 100644 (file)
@@ -209,7 +209,7 @@ static inline bool cgroup_bpf_sock_enabled(struct sock *sk,
        int __ret = 0;                                                         \
        if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) {                    \
                typeof(sk) __sk = sk_to_full_sk(sk);                           \
-               if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) &&        \
+               if (__sk && __sk == skb_to_full_sk(skb) &&             \
                    cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS))         \
                        __ret = __cgroup_bpf_run_filter_skb(__sk, skb,         \
                                                      CGROUP_INET_EGRESS); \
index f01dd273bea69d2eaf7a1d28274d7f980942b78a..56d8bc5593d3dfffd5f94cf7c6383948881917df 100644 (file)
@@ -321,8 +321,10 @@ static inline unsigned long inet_cmsg_flags(const struct inet_sock *inet)
 static inline struct sock *sk_to_full_sk(struct sock *sk)
 {
 #ifdef CONFIG_INET
-       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
                sk = inet_reqsk(sk)->rsk_listener;
+       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+               sk = NULL;
 #endif
        return sk;
 }
@@ -331,8 +333,10 @@ static inline struct sock *sk_to_full_sk(struct sock *sk)
 static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
 {
 #ifdef CONFIG_INET
-       if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+       if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
                sk = ((const struct request_sock *)sk)->rsk_listener;
+       if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+               sk = NULL;
 #endif
        return sk;
 }
index bd0d08bf76bb8de39ca2ca89cda99a97c9b0a034..202c1d386e19599e9fc6e0a0d4a95986ba6d0ea8 100644 (file)
@@ -6778,8 +6778,6 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
                 * sock refcnt is decremented to prevent a request_sock leak.
                 */
-               if (!sk_fullsock(sk2))
-                       sk2 = NULL;
                if (sk2 != sk) {
                        sock_gen_put(sk);
                        /* Ensure there is no need to bump sk2 refcnt */
@@ -6826,8 +6824,6 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
                 * sock refcnt is decremented to prevent a request_sock leak.
                 */
-               if (!sk_fullsock(sk2))
-                       sk2 = NULL;
                if (sk2 != sk) {
                        sock_gen_put(sk);
                        /* Ensure there is no need to bump sk2 refcnt */
@@ -7276,7 +7272,7 @@ BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
 {
        sk = sk_to_full_sk(sk);
 
-       if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+       if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
                return (unsigned long)sk;
 
        return (unsigned long)NULL;