bpf: add BPF_CGROUP_SOCK_OPS callback that is executed on every RTT
[linux-2.6-block.git] / net / ipv4 / tcp_input.c
index 08a477e74cf3267b725294c66b46fdad12bd2b72..c21e8a22fb3bb39d06eb3ee7eb4cfae5066b6f48 100644 (file)
@@ -119,7 +119,7 @@ void clean_acked_data_enable(struct inet_connection_sock *icsk,
                             void (*cad)(struct sock *sk, u32 ack_seq))
 {
        icsk->icsk_clean_acked = cad;
-       static_branch_inc(&clean_acked_data_enabled.key);
+       static_branch_deferred_inc(&clean_acked_data_enabled);
 }
 EXPORT_SYMBOL_GPL(clean_acked_data_enable);
 
@@ -778,6 +778,8 @@ static void tcp_rtt_estimator(struct sock *sk, long mrtt_us)
                                tp->rttvar_us -= (tp->rttvar_us - tp->mdev_max_us) >> 2;
                        tp->rtt_seq = tp->snd_nxt;
                        tp->mdev_max_us = tcp_rto_min_us(sk);
+
+                       tcp_bpf_rtt(sk);
                }
        } else {
                /* no previous measure. */
@@ -786,6 +788,8 @@ static void tcp_rtt_estimator(struct sock *sk, long mrtt_us)
                tp->rttvar_us = max(tp->mdev_us, tcp_rto_min_us(sk));
                tp->mdev_max_us = tp->rttvar_us;
                tp->rtt_seq = tp->snd_nxt;
+
+               tcp_bpf_rtt(sk);
        }
        tp->srtt_us = max(1U, srtt);
 }
@@ -1302,7 +1306,7 @@ static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev,
        TCP_SKB_CB(skb)->seq += shifted;
 
        tcp_skb_pcount_add(prev, pcount);
-       BUG_ON(tcp_skb_pcount(skb) < pcount);
+       WARN_ON_ONCE(tcp_skb_pcount(skb) < pcount);
        tcp_skb_pcount_add(skb, -pcount);
 
        /* When we're adding to gso_segs == 1, gso_size will be zero,
@@ -1368,6 +1372,21 @@ static int skb_can_shift(const struct sk_buff *skb)
        return !skb_headlen(skb) && skb_is_nonlinear(skb);
 }
 
+int tcp_skb_shift(struct sk_buff *to, struct sk_buff *from,
+                 int pcount, int shiftlen)
+{
+       /* TCP min gso_size is 8 bytes (TCP_MIN_GSO_SIZE)
+        * Since TCP_SKB_CB(skb)->tcp_gso_segs is 16 bits, we need
+        * to make sure not storing more than 65535 * 8 bytes per skb,
+        * even if current MSS is bigger.
+        */
+       if (unlikely(to->len + shiftlen >= 65535 * TCP_MIN_GSO_SIZE))
+               return 0;
+       if (unlikely(tcp_skb_pcount(to) + pcount > 65535))
+               return 0;
+       return skb_shift(to, from, shiftlen);
+}
+
 /* Try collapsing SACK blocks spanning across multiple skbs to a single
  * skb.
  */
@@ -1473,7 +1492,7 @@ static struct sk_buff *tcp_shift_skb_data(struct sock *sk, struct sk_buff *skb,
        if (!after(TCP_SKB_CB(skb)->seq + len, tp->snd_una))
                goto fallback;
 
-       if (!skb_shift(prev, skb, len))
+       if (!tcp_skb_shift(prev, skb, pcount, len))
                goto fallback;
        if (!tcp_shifted_skb(sk, prev, skb, state, pcount, len, mss, dup_sack))
                goto out;
@@ -1491,11 +1510,10 @@ static struct sk_buff *tcp_shift_skb_data(struct sock *sk, struct sk_buff *skb,
                goto out;
 
        len = skb->len;
-       if (skb_shift(prev, skb, len)) {
-               pcount += tcp_skb_pcount(skb);
-               tcp_shifted_skb(sk, prev, skb, state, tcp_skb_pcount(skb),
+       pcount = tcp_skb_pcount(skb);
+       if (tcp_skb_shift(prev, skb, pcount, len))
+               tcp_shifted_skb(sk, prev, skb, state, pcount,
                                len, mss, 0);
-       }
 
 out:
        return prev;
@@ -2648,7 +2666,7 @@ static void tcp_process_loss(struct sock *sk, int flag, int num_dupack,
        struct tcp_sock *tp = tcp_sk(sk);
        bool recovered = !before(tp->snd_una, tp->high_seq);
 
-       if ((flag & FLAG_SND_UNA_ADVANCED) &&
+       if ((flag & FLAG_SND_UNA_ADVANCED || tp->fastopen_rsk) &&
            tcp_try_undo_loss(sk, false))
                return;