bpf: add access to sock fields and pkt data from sk_skb programs
[linux-block.git] / net / core / filter.c
index 8e136578488c24248fcd41b0e9f41781d4bcd09e..e9f8dcef6c57cf237ab5798da96f1230684d6021 100644 (file)
@@ -3278,8 +3278,16 @@ static const struct bpf_func_proto *
 static const struct bpf_func_proto *sk_skb_func_proto(enum bpf_func_id func_id)
 {
        switch (func_id) {
+       case BPF_FUNC_skb_store_bytes:
+               return &bpf_skb_store_bytes_proto;
        case BPF_FUNC_skb_load_bytes:
                return &bpf_skb_load_bytes_proto;
+       case BPF_FUNC_skb_pull_data:
+               return &bpf_skb_pull_data_proto;
+       case BPF_FUNC_skb_change_tail:
+               return &bpf_skb_change_tail_proto;
+       case BPF_FUNC_skb_change_head:
+               return &bpf_skb_change_head_proto;
        case BPF_FUNC_get_socket_cookie:
                return &bpf_get_socket_cookie_proto;
        case BPF_FUNC_get_socket_uid:
@@ -3343,6 +3351,10 @@ static bool bpf_skb_is_valid_access(int off, int size, enum bpf_access_type type
                if (off + size > offsetofend(struct __sk_buff, cb[4]))
                        return false;
                break;
+       case bpf_ctx_range_till(struct __sk_buff, remote_ip6[0], remote_ip6[3]):
+       case bpf_ctx_range_till(struct __sk_buff, local_ip6[0], local_ip6[3]):
+       case bpf_ctx_range_till(struct __sk_buff, remote_ip4, remote_ip4):
+       case bpf_ctx_range_till(struct __sk_buff, local_ip4, local_ip4):
        case bpf_ctx_range(struct __sk_buff, data):
        case bpf_ctx_range(struct __sk_buff, data_end):
                if (size != size_default)
@@ -3371,6 +3383,7 @@ static bool sk_filter_is_valid_access(int off, int size,
        case bpf_ctx_range(struct __sk_buff, tc_classid):
        case bpf_ctx_range(struct __sk_buff, data):
        case bpf_ctx_range(struct __sk_buff, data_end):
+       case bpf_ctx_range_till(struct __sk_buff, family, local_port):
                return false;
        }
 
@@ -3392,6 +3405,7 @@ static bool lwt_is_valid_access(int off, int size,
 {
        switch (off) {
        case bpf_ctx_range(struct __sk_buff, tc_classid):
+       case bpf_ctx_range_till(struct __sk_buff, family, local_port):
                return false;
        }
 
@@ -3505,6 +3519,8 @@ static bool tc_cls_act_is_valid_access(int off, int size,
        case bpf_ctx_range(struct __sk_buff, data_end):
                info->reg_type = PTR_TO_PACKET_END;
                break;
+       case bpf_ctx_range_till(struct __sk_buff, family, local_port):
+               return false;
        }
 
        return bpf_skb_is_valid_access(off, size, type, info);
@@ -3582,11 +3598,63 @@ static bool sock_ops_is_valid_access(int off, int size,
        return __is_valid_sock_ops_access(off, size);
 }
 
+static int sk_skb_prologue(struct bpf_insn *insn_buf, bool direct_write,
+                          const struct bpf_prog *prog)
+{
+       struct bpf_insn *insn = insn_buf;
+       /* Build the sk_skb program prologue in insn_buf and return the
+        * number of instructions written. When direct_write is false no
+        * prologue is generated and 0 is returned. The sequence ends with
+        * a copy of the program's first instruction (prog->insnsi[0]).
+        */
+       if (!direct_write)
+               return 0;
+
+       /* if (!skb->cloned)
+        *       goto start;
+        *
+        * Fast path: the cloned bit is only an approximation of whether
+        * the skb is really a clone, so when it is set the decision is
+        * left to the helper called below.
+        *
+        * R6 is scratch for the byte loaded from R1 + CLONED_OFFSET().
+        * The jump offset 7 skips every instruction up to "start"; it
+        * must be adjusted if instructions are added or removed between
+        * here and there.
+        */
+       *insn++ = BPF_LDX_MEM(BPF_B, BPF_REG_6, BPF_REG_1, CLONED_OFFSET());
+       *insn++ = BPF_ALU32_IMM(BPF_AND, BPF_REG_6, CLONED_MASK);
+       *insn++ = BPF_JMP_IMM(BPF_JEQ, BPF_REG_6, 0, 7);
+
+       /* ret = bpf_skb_pull_data(skb, 0);
+        *
+        * R1 is copied to R6 first so that it can be moved back at
+        * "restore"; R2 ^= R2 sets the second argument to 0.
+        */
+       *insn++ = BPF_MOV64_REG(BPF_REG_6, BPF_REG_1);
+       *insn++ = BPF_ALU64_REG(BPF_XOR, BPF_REG_2, BPF_REG_2);
+       *insn++ = BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
+                              BPF_FUNC_skb_pull_data);
+       /* if (!ret)
+        *      goto restore;
+        * return SK_DROP;
+        *
+        * Offset 2 skips the two instructions that load SK_DROP into R0
+        * and exit.
+        */
+       *insn++ = BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 2);
+       *insn++ = BPF_ALU32_IMM(BPF_MOV, BPF_REG_0, SK_DROP);
+       *insn++ = BPF_EXIT_INSN();
+
+       /* restore: put the value saved in R6 back into R1 */
+       *insn++ = BPF_MOV64_REG(BPF_REG_1, BPF_REG_6);
+       /* start: re-emit the program's original first instruction.
+        * NOTE(review): presumably the caller replaces that instruction
+        * with this buffer - confirm against the gen_prologue caller.
+        */
+       *insn++ = prog->insnsi[0];
+
+       return insn - insn_buf;
+}
+
 static bool sk_skb_is_valid_access(int off, int size,
                                   enum bpf_access_type type,
                                   struct bpf_insn_access_aux *info)
 {
+       if (type == BPF_WRITE) {
+               switch (off) {
+               case bpf_ctx_range(struct __sk_buff, mark):
+               case bpf_ctx_range(struct __sk_buff, tc_index):
+               case bpf_ctx_range(struct __sk_buff, priority):
+                       break;
+               default:
+                       return false;
+               }
+       }
+
        switch (off) {
+       case bpf_ctx_range(struct __sk_buff, tc_classid):
+               return false;
        case bpf_ctx_range(struct __sk_buff, data):
                info->reg_type = PTR_TO_PACKET;
                break;
@@ -3783,6 +3851,106 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                *insn++ = BPF_MOV64_IMM(si->dst_reg, 0);
 #endif
                break;
+       case offsetof(struct __sk_buff, family):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_family) != 2);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
+                                     bpf_target_off(struct sock_common,
+                                                    skc_family,
+                                                    2, target_size));
+               break;
+       case offsetof(struct __sk_buff, remote_ip4):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_daddr) != 4);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     bpf_target_off(struct sock_common,
+                                                    skc_daddr,
+                                                    4, target_size));
+               break;
+       case offsetof(struct __sk_buff, local_ip4):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common,
+                                         skc_rcv_saddr) != 4);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     bpf_target_off(struct sock_common,
+                                                    skc_rcv_saddr,
+                                                    4, target_size));
+               break;
+       case offsetof(struct __sk_buff, remote_ip6[0]) ...
+            offsetof(struct __sk_buff, remote_ip6[3]):
+#if IS_ENABLED(CONFIG_IPV6)
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common,
+                                         skc_v6_daddr.s6_addr32[0]) != 4);
+
+               off = si->off;
+               off -= offsetof(struct __sk_buff, remote_ip6[0]);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     offsetof(struct sock_common,
+                                              skc_v6_daddr.s6_addr32[0]) +
+                                     off);
+#else
+               *insn++ = BPF_MOV32_IMM(si->dst_reg, 0);
+#endif
+               break;
+       case offsetof(struct __sk_buff, local_ip6[0]) ...
+            offsetof(struct __sk_buff, local_ip6[3]):
+#if IS_ENABLED(CONFIG_IPV6)
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common,
+                                         skc_v6_rcv_saddr.s6_addr32[0]) != 4);
+
+               off = si->off;
+               off -= offsetof(struct __sk_buff, local_ip6[0]);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg,
+                                     offsetof(struct sock_common,
+                                              skc_v6_rcv_saddr.s6_addr32[0]) +
+                                     off);
+#else
+               *insn++ = BPF_MOV32_IMM(si->dst_reg, 0);
+#endif
+               break;
+
+       case offsetof(struct __sk_buff, remote_port):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_dport) != 2);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
+                                     bpf_target_off(struct sock_common,
+                                                    skc_dport,
+                                                    2, target_size));
+#ifndef __BIG_ENDIAN_BITFIELD
+               *insn++ = BPF_ALU32_IMM(BPF_LSH, si->dst_reg, 16);
+#endif
+               break;
+
+       case offsetof(struct __sk_buff, local_port):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct sock_common, skc_num) != 2);
+
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg,
+                                     bpf_target_off(struct sock_common,
+                                                    skc_num, 2, target_size));
+               break;
        }
 
        return insn - insn_buf;
@@ -4071,6 +4239,7 @@ const struct bpf_verifier_ops sk_skb_prog_ops = {
        .get_func_proto         = sk_skb_func_proto,
        .is_valid_access        = sk_skb_is_valid_access,
        .convert_ctx_access     = bpf_convert_ctx_access,
+       .gen_prologue           = sk_skb_prologue,
 };
 
 int sk_detach_filter(struct sock *sk)