bpf: Fix verifier assumptions about socket->sk
authorAlexei Starovoitov <ast@kernel.org>
Sat, 27 Apr 2024 00:25:44 +0000 (17:25 -0700)
committerMartin KaFai Lau <martin.lau@kernel.org>
Mon, 29 Apr 2024 21:16:41 +0000 (14:16 -0700)
The verifier assumes that 'sk' field in 'struct socket' is valid
and non-NULL when 'socket' pointer itself is trusted and non-NULL.
That may not be the case when socket was just created and
passed to LSM socket_accept hook.
Fix this verifier assumption and adjust tests.

Reported-by: Liam Wisehart <liamwisehart@meta.com>
Acked-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
Fixes: 6fcd486b3a0a ("bpf: Refactor RCU enforcement in the verifier.")
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Link: https://lore.kernel.org/r/20240427002544.68803-1-alexei.starovoitov@gmail.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
kernel/bpf/verifier.c
tools/testing/selftests/bpf/progs/bench_local_storage_create.c
tools/testing/selftests/bpf/progs/local_storage.c
tools/testing/selftests/bpf/progs/lsm_cgroup.c

index 87ff414899cf37fb93b06a8d9e1e1ceaabbe578d..5d42db05315e6aca85c403bc6c90c5b9007b1eff 100644 (file)
@@ -2368,6 +2368,8 @@ static void mark_btf_ld_reg(struct bpf_verifier_env *env,
        regs[regno].type = PTR_TO_BTF_ID | flag;
        regs[regno].btf = btf;
        regs[regno].btf_id = btf_id;
+       if (type_may_be_null(flag))
+               regs[regno].id = ++env->id_gen;
 }
 
 #define DEF_NOT_SUBREG (0)
@@ -5400,8 +5402,6 @@ static int check_map_kptr_access(struct bpf_verifier_env *env, u32 regno,
                 */
                mark_btf_ld_reg(env, cur_regs(env), value_regno, PTR_TO_BTF_ID, kptr_field->kptr.btf,
                                kptr_field->kptr.btf_id, btf_ld_kptr_type(env, kptr_field));
-               /* For mark_ptr_or_null_reg */
-               val_reg->id = ++env->id_gen;
        } else if (class == BPF_STX) {
                val_reg = reg_state(env, value_regno);
                if (!register_is_null(val_reg) &&
@@ -5719,7 +5719,8 @@ static bool is_trusted_reg(const struct bpf_reg_state *reg)
                return true;
 
        /* Types listed in the reg2btf_ids are always trusted */
-       if (reg2btf_ids[base_type(reg->type)])
+       if (reg2btf_ids[base_type(reg->type)] &&
+           !bpf_type_has_unsafe_modifiers(reg->type))
                return true;
 
        /* If a register is not referenced, it is trusted if it has the
@@ -6339,6 +6340,7 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
 #define BTF_TYPE_SAFE_RCU(__type)  __PASTE(__type, __safe_rcu)
 #define BTF_TYPE_SAFE_RCU_OR_NULL(__type)  __PASTE(__type, __safe_rcu_or_null)
 #define BTF_TYPE_SAFE_TRUSTED(__type)  __PASTE(__type, __safe_trusted)
+#define BTF_TYPE_SAFE_TRUSTED_OR_NULL(__type)  __PASTE(__type, __safe_trusted_or_null)
 
 /*
  * Allow list few fields as RCU trusted or full trusted.
@@ -6402,7 +6404,7 @@ BTF_TYPE_SAFE_TRUSTED(struct dentry) {
        struct inode *d_inode;
 };
 
-BTF_TYPE_SAFE_TRUSTED(struct socket) {
+BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct socket) {
        struct sock *sk;
 };
 
@@ -6437,11 +6439,20 @@ static bool type_is_trusted(struct bpf_verifier_env *env,
        BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED(struct linux_binprm));
        BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED(struct file));
        BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED(struct dentry));
-       BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED(struct socket));
 
        return btf_nested_type_is_trusted(&env->log, reg, field_name, btf_id, "__safe_trusted");
 }
 
+static bool type_is_trusted_or_null(struct bpf_verifier_env *env,
+                                   struct bpf_reg_state *reg,
+                                   const char *field_name, u32 btf_id)
+{
+       BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct socket));
+
+       return btf_nested_type_is_trusted(&env->log, reg, field_name, btf_id,
+                                         "__safe_trusted_or_null");
+}
+
 static int check_ptr_to_btf_access(struct bpf_verifier_env *env,
                                   struct bpf_reg_state *regs,
                                   int regno, int off, int size,
@@ -6550,6 +6561,8 @@ static int check_ptr_to_btf_access(struct bpf_verifier_env *env,
                 */
                if (type_is_trusted(env, reg, field_name, btf_id)) {
                        flag |= PTR_TRUSTED;
+               } else if (type_is_trusted_or_null(env, reg, field_name, btf_id)) {
+                       flag |= PTR_TRUSTED | PTR_MAYBE_NULL;
                } else if (in_rcu_cs(env) && !type_may_be_null(reg->type)) {
                        if (type_is_rcu(env, reg, field_name, btf_id)) {
                                /* ignore __rcu tag and mark it MEM_RCU */
index e4bfbba6c19360b9e8f5462e65cfd2e56454f800..c8ec0d0368e4a15a72c2b166a73ab90615424c97 100644 (file)
@@ -61,14 +61,15 @@ SEC("lsm.s/socket_post_create")
 int BPF_PROG(socket_post_create, struct socket *sock, int family, int type,
             int protocol, int kern)
 {
+       struct sock *sk = sock->sk;
        struct storage *stg;
        __u32 pid;
 
        pid = bpf_get_current_pid_tgid() >> 32;
-       if (pid != bench_pid)
+       if (pid != bench_pid || !sk)
                return 0;
 
-       stg = bpf_sk_storage_get(&sk_storage_map, sock->sk, NULL,
+       stg = bpf_sk_storage_get(&sk_storage_map, sk, NULL,
                                 BPF_LOCAL_STORAGE_GET_F_CREATE);
 
        if (stg)
index e5e3a8b8dd075845e3c62bf8bcd6ec3027647595..637e75df2e1463e2e403ec5f8ffe7f6570ef5efd 100644 (file)
@@ -140,11 +140,12 @@ int BPF_PROG(socket_bind, struct socket *sock, struct sockaddr *address,
 {
        __u32 pid = bpf_get_current_pid_tgid() >> 32;
        struct local_storage *storage;
+       struct sock *sk = sock->sk;
 
-       if (pid != monitored_pid)
+       if (pid != monitored_pid || !sk)
                return 0;
 
-       storage = bpf_sk_storage_get(&sk_storage_map, sock->sk, 0, 0);
+       storage = bpf_sk_storage_get(&sk_storage_map, sk, 0, 0);
        if (!storage)
                return 0;
 
@@ -155,24 +156,24 @@ int BPF_PROG(socket_bind, struct socket *sock, struct sockaddr *address,
        /* This tests that we can associate multiple elements
         * with the local storage.
         */
-       storage = bpf_sk_storage_get(&sk_storage_map2, sock->sk, 0,
+       storage = bpf_sk_storage_get(&sk_storage_map2, sk, 0,
                                     BPF_LOCAL_STORAGE_GET_F_CREATE);
        if (!storage)
                return 0;
 
-       if (bpf_sk_storage_delete(&sk_storage_map2, sock->sk))
+       if (bpf_sk_storage_delete(&sk_storage_map2, sk))
                return 0;
 
-       storage = bpf_sk_storage_get(&sk_storage_map2, sock->sk, 0,
+       storage = bpf_sk_storage_get(&sk_storage_map2, sk, 0,
                                     BPF_LOCAL_STORAGE_GET_F_CREATE);
        if (!storage)
                return 0;
 
-       if (bpf_sk_storage_delete(&sk_storage_map, sock->sk))
+       if (bpf_sk_storage_delete(&sk_storage_map, sk))
                return 0;
 
        /* Ensure that the sk_storage_map is disconnected from the storage. */
-       if (!sock->sk->sk_bpf_storage || sock->sk->sk_bpf_storage->smap)
+       if (!sk->sk_bpf_storage || sk->sk_bpf_storage->smap)
                return 0;
 
        sk_storage_result = 0;
@@ -185,11 +186,12 @@ int BPF_PROG(socket_post_create, struct socket *sock, int family, int type,
 {
        __u32 pid = bpf_get_current_pid_tgid() >> 32;
        struct local_storage *storage;
+       struct sock *sk = sock->sk;
 
-       if (pid != monitored_pid)
+       if (pid != monitored_pid || !sk)
                return 0;
 
-       storage = bpf_sk_storage_get(&sk_storage_map, sock->sk, 0,
+       storage = bpf_sk_storage_get(&sk_storage_map, sk, 0,
                                     BPF_LOCAL_STORAGE_GET_F_CREATE);
        if (!storage)
                return 0;
index 02c11d16b692aba1486b678498e97b7c390a4a8a..d7598538aa2dad1565727ffcfe4ed66ec6f8fc63 100644 (file)
@@ -103,11 +103,15 @@ static __always_inline int real_bind(struct socket *sock,
                                     int addrlen)
 {
        struct sockaddr_ll sa = {};
+       struct sock *sk = sock->sk;
 
-       if (sock->sk->__sk_common.skc_family != AF_PACKET)
+       if (!sk)
+               return 1;
+
+       if (sk->__sk_common.skc_family != AF_PACKET)
                return 1;
 
-       if (sock->sk->sk_kern_sock)
+       if (sk->sk_kern_sock)
                return 1;
 
        bpf_probe_read_kernel(&sa, sizeof(sa), address);