tcp: annotate sk->sk_rcvbuf lockless reads
[linux-2.6-block.git] / net / core / sock.c
index 6d08553f885cdb44ed3e53c2231abb7f48bf309c..8c8f61e70141583afe52420b58fea4bcce3a74f0 100644 (file)
@@ -522,7 +522,7 @@ int __sk_receive_skb(struct sock *sk, struct sk_buff *skb,
                rc = sk_backlog_rcv(sk, skb);
 
                mutex_release(&sk->sk_lock.dep_map, 1, _RET_IP_);
-       } else if (sk_add_backlog(sk, skb, sk->sk_rcvbuf)) {
+       } else if (sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf))) {
                bh_unlock_sock(sk);
                atomic_inc(&sk->sk_drops);
                goto discard_and_relse;
@@ -831,7 +831,8 @@ set_rcvbuf:
                 * returning the value we actually used in getsockopt
                 * is the most desirable behavior.
                 */
-               sk->sk_rcvbuf = max_t(int, val * 2, SOCK_MIN_RCVBUF);
+               WRITE_ONCE(sk->sk_rcvbuf,
+                          max_t(int, val * 2, SOCK_MIN_RCVBUF));
                break;
 
        case SO_RCVBUFFORCE:
@@ -974,7 +975,7 @@ set_rcvbuf:
                if (sock->ops->set_rcvlowat)
                        ret = sock->ops->set_rcvlowat(sk, val);
                else
-                       sk->sk_rcvlowat = val ? : 1;
+                       WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
                break;
 
        case SO_RCVTIMEO_OLD:
@@ -1700,8 +1701,6 @@ static void __sk_destruct(struct rcu_head *head)
                sk_filter_uncharge(sk, filter);
                RCU_INIT_POINTER(sk->sk_filter, NULL);
        }
-       if (rcu_access_pointer(sk->sk_reuseport_cb))
-               reuseport_detach_sock(sk);
 
        sock_disable_timestamp(sk, SK_FLAGS_TIMESTAMP);
 
@@ -1728,7 +1727,14 @@ static void __sk_destruct(struct rcu_head *head)
 
 void sk_destruct(struct sock *sk)
 {
-       if (sock_flag(sk, SOCK_RCU_FREE))
+       bool use_call_rcu = sock_flag(sk, SOCK_RCU_FREE);
+
+       if (rcu_access_pointer(sk->sk_reuseport_cb)) {
+               reuseport_detach_sock(sk);
+               use_call_rcu = true;
+       }
+
+       if (use_call_rcu)
                call_rcu(&sk->sk_rcu, __sk_destruct);
        else
                __sk_destruct(&sk->sk_rcu);
@@ -1851,9 +1857,12 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
                        goto out;
                }
                RCU_INIT_POINTER(newsk->sk_reuseport_cb, NULL);
-#ifdef CONFIG_BPF_SYSCALL
-               RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);
-#endif
+
+               if (bpf_sk_storage_clone(sk, newsk)) {
+                       sk_free_unlock_clone(newsk);
+                       newsk = NULL;
+                       goto out;
+               }
 
                newsk->sk_err      = 0;
                newsk->sk_err_soft = 0;
@@ -2326,8 +2335,8 @@ static void sk_leave_memory_pressure(struct sock *sk)
        } else {
                unsigned long *memory_pressure = sk->sk_prot->memory_pressure;
 
-               if (memory_pressure && *memory_pressure)
-                       *memory_pressure = 0;
+               if (memory_pressure && READ_ONCE(*memory_pressure))
+                       WRITE_ONCE(*memory_pressure, 0);
        }
 }
 
@@ -3196,13 +3205,13 @@ void sk_get_meminfo(const struct sock *sk, u32 *mem)
        memset(mem, 0, sizeof(*mem) * SK_MEMINFO_VARS);
 
        mem[SK_MEMINFO_RMEM_ALLOC] = sk_rmem_alloc_get(sk);
-       mem[SK_MEMINFO_RCVBUF] = sk->sk_rcvbuf;
+       mem[SK_MEMINFO_RCVBUF] = READ_ONCE(sk->sk_rcvbuf);
        mem[SK_MEMINFO_WMEM_ALLOC] = sk_wmem_alloc_get(sk);
        mem[SK_MEMINFO_SNDBUF] = sk->sk_sndbuf;
        mem[SK_MEMINFO_FWD_ALLOC] = sk->sk_forward_alloc;
        mem[SK_MEMINFO_WMEM_QUEUED] = sk->sk_wmem_queued;
        mem[SK_MEMINFO_OPTMEM] = atomic_read(&sk->sk_omem_alloc);
-       mem[SK_MEMINFO_BACKLOG] = sk->sk_backlog.len;
+       mem[SK_MEMINFO_BACKLOG] = READ_ONCE(sk->sk_backlog.len);
        mem[SK_MEMINFO_DROPS] = atomic_read(&sk->sk_drops);
 }
 
@@ -3287,16 +3296,17 @@ static __init int net_inuse_init(void)
 
 core_initcall(net_inuse_init);
 
-static void assign_proto_idx(struct proto *prot)
+static int assign_proto_idx(struct proto *prot)
 {
        prot->inuse_idx = find_first_zero_bit(proto_inuse_idx, PROTO_INUSE_NR);
 
        if (unlikely(prot->inuse_idx == PROTO_INUSE_NR - 1)) {
                pr_err("PROTO_INUSE_NR exhausted\n");
-               return;
+               return -ENOSPC;
        }
 
        set_bit(prot->inuse_idx, proto_inuse_idx);
+       return 0;
 }
 
 static void release_proto_idx(struct proto *prot)
@@ -3305,8 +3315,9 @@ static void release_proto_idx(struct proto *prot)
                clear_bit(prot->inuse_idx, proto_inuse_idx);
 }
 #else
-static inline void assign_proto_idx(struct proto *prot)
+static inline int assign_proto_idx(struct proto *prot)
 {
+       return 0;
 }
 
 static inline void release_proto_idx(struct proto *prot)
@@ -3355,6 +3366,8 @@ static int req_prot_init(const struct proto *prot)
 
 int proto_register(struct proto *prot, int alloc_slab)
 {
+       int ret = -ENOBUFS;
+
        if (alloc_slab) {
                prot->slab = kmem_cache_create_usercopy(prot->name,
                                        prot->obj_size, 0,
@@ -3391,20 +3404,27 @@ int proto_register(struct proto *prot, int alloc_slab)
        }
 
        mutex_lock(&proto_list_mutex);
+       ret = assign_proto_idx(prot);
+       if (ret) {
+               mutex_unlock(&proto_list_mutex);
+               goto out_free_timewait_sock_slab_name;
+       }
        list_add(&prot->node, &proto_list);
-       assign_proto_idx(prot);
        mutex_unlock(&proto_list_mutex);
-       return 0;
+       return ret;
 
 out_free_timewait_sock_slab_name:
-       kfree(prot->twsk_prot->twsk_slab_name);
+       if (alloc_slab && prot->twsk_prot)
+               kfree(prot->twsk_prot->twsk_slab_name);
 out_free_request_sock_slab:
-       req_prot_cleanup(prot->rsk_prot);
+       if (alloc_slab) {
+               req_prot_cleanup(prot->rsk_prot);
 
-       kmem_cache_destroy(prot->slab);
-       prot->slab = NULL;
+               kmem_cache_destroy(prot->slab);
+               prot->slab = NULL;
+       }
 out:
-       return -ENOBUFS;
+       return ret;
 }
 EXPORT_SYMBOL(proto_register);
 
@@ -3478,7 +3498,7 @@ static long sock_prot_memory_allocated(struct proto *proto)
        return proto->memory_allocated != NULL ? proto_memory_allocated(proto) : -1L;
 }
 
-static char *sock_prot_memory_pressure(struct proto *proto)
+static const char *sock_prot_memory_pressure(struct proto *proto)
 {
        return proto->memory_pressure != NULL ?
        proto_memory_pressure(proto) ? "yes" : "no" : "NI";