bpf: sk_msg, sock{map|hash} redirect through ULP
[linux-2.6-block.git] / net / ipv4 / tcp_bpf.c
index 3b45fe530f91e2e1aa697888e11a78cf7e9d211e..1bb7321a256d09da590c623817d7087e77f43575 100644 (file)
@@ -8,6 +8,7 @@
 #include <linux/wait.h>
 
 #include <net/inet_common.h>
+#include <net/tls.h>
 
 static bool tcp_bpf_stream_read(const struct sock *sk)
 {
@@ -198,7 +199,7 @@ static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                msg->sg.start = i;
                msg->sg.size -= apply_bytes;
                sk_psock_queue_msg(psock, tmp);
-               sk->sk_data_ready(sk);
+               sk_psock_data_ready(sk, psock);
        } else {
                sk_msg_free(sk, tmp);
                kfree(tmp);
@@ -218,6 +219,8 @@ static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
        u32 off;
 
        while (1) {
+               bool has_tx_ulp;
+
                sge = sk_msg_elem(msg, msg->sg.start);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
@@ -226,7 +229,15 @@ static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
 
                tcp_rate_check_app_limited(sk);
 retry:
-               ret = do_tcp_sendpages(sk, page, off, size, flags);
+               has_tx_ulp = tls_sw_has_ctx_tx(sk);
+               if (has_tx_ulp) {
+                       flags |= MSG_SENDPAGE_NOPOLICY;
+                       ret = kernel_sendpage_locked(sk,
+                                                    page, off, size, flags);
+               } else {
+                       ret = do_tcp_sendpages(sk, page, off, size, flags);
+               }
+
                if (ret <= 0)
                        return ret;
                if (apply)
@@ -289,12 +300,23 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 {
        bool cork = false, enospc = msg->sg.start == msg->sg.end;
        struct sock *sk_redir;
-       u32 tosend;
+       u32 tosend, delta = 0;
        int ret;
 
 more_data:
-       if (psock->eval == __SK_NONE)
+       if (psock->eval == __SK_NONE) {
+               /* Track delta in msg size to add/subtract it on SK_DROP from
+                * returned to user copied size. This ensures user doesn't
+                * get a positive return code with msg_cut_data and SK_DROP
+                * verdict.
+                */
+               delta = msg->sg.size;
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+               if (msg->sg.size < delta)
+                       delta -= msg->sg.size;
+               else
+                       delta = 0;
+       }
 
        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
@@ -350,7 +372,7 @@ more_data:
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
-               *copied -= tosend;
+               *copied -= (tosend + delta);
                return -EACCES;
        }