TCP will soon attach TIME_WAIT sockets to some ACK and RST.
Make sure sk_to_full_sk() detects this and does not return
a non full socket.
v3: also changed sk_const_to_full_sk()
Signed-off-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Reviewed-by: Martin KaFai Lau <martin.lau@kernel.org>
Reviewed-by: Brian Vazquez <brianvv@google.com>
Link: https://patch.msgid.link/20241010174817.1543642-2-edumazet@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
int __ret = 0; \
if (cgroup_bpf_enabled(CGROUP_INET_EGRESS) && sk) { \
typeof(sk) __sk = sk_to_full_sk(sk); \
- if (sk_fullsock(__sk) && __sk == skb_to_full_sk(skb) && \
+ if (__sk && __sk == skb_to_full_sk(skb) && \
cgroup_bpf_sock_enabled(__sk, CGROUP_INET_EGRESS)) \
__ret = __cgroup_bpf_run_filter_skb(__sk, skb, \
CGROUP_INET_EGRESS); \
static inline struct sock *sk_to_full_sk(struct sock *sk)
{
#ifdef CONFIG_INET
- if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+ if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
sk = inet_reqsk(sk)->rsk_listener;
+ if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+ sk = NULL;
#endif
return sk;
}
static inline const struct sock *sk_const_to_full_sk(const struct sock *sk)
{
#ifdef CONFIG_INET
- if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
+ if (sk && READ_ONCE(sk->sk_state) == TCP_NEW_SYN_RECV)
sk = ((const struct request_sock *)sk)->rsk_listener;
+ if (sk && READ_ONCE(sk->sk_state) == TCP_TIME_WAIT)
+ sk = NULL;
#endif
return sk;
}
/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
* sock refcnt is decremented to prevent a request_sock leak.
*/
- if (!sk_fullsock(sk2))
- sk2 = NULL;
if (sk2 != sk) {
sock_gen_put(sk);
/* Ensure there is no need to bump sk2 refcnt */
/* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk
* sock refcnt is decremented to prevent a request_sock leak.
*/
- if (!sk_fullsock(sk2))
- sk2 = NULL;
if (sk2 != sk) {
sock_gen_put(sk);
/* Ensure there is no need to bump sk2 refcnt */
{
sk = sk_to_full_sk(sk);
- if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+ if (sk && sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
return (unsigned long)sk;
return (unsigned long)NULL;