Merge tag 'apparmor-pr-2018-09-06' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-block.git] / kernel / bpf / sockmap.c
index 98e621a29e8e6953ec9dec5b4cb6f8559dd750d3..488ef9663c01f3b4d2cc44b9ca88863ced2860e7 100644 (file)
@@ -236,7 +236,7 @@ static int bpf_tcp_init(struct sock *sk)
 }
 
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
 
 static void bpf_tcp_release(struct sock *sk)
 {
@@ -248,7 +248,7 @@ static void bpf_tcp_release(struct sock *sk)
                goto out;
 
        if (psock->cork) {
-               free_start_sg(psock->sock, psock->cork);
+               free_start_sg(psock->sock, psock->cork, true);
                kfree(psock->cork);
                psock->cork = NULL;
        }
@@ -330,14 +330,14 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
        close_fun = psock->save_close;
 
        if (psock->cork) {
-               free_start_sg(psock->sock, psock->cork);
+               free_start_sg(psock->sock, psock->cork, true);
                kfree(psock->cork);
                psock->cork = NULL;
        }
 
        list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
                list_del(&md->list);
-               free_start_sg(psock->sock, md);
+               free_start_sg(psock->sock, md, true);
                kfree(md);
        }
 
@@ -369,7 +369,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
                        /* If another thread deleted this object skip deletion.
                         * The refcnt on psock may or may not be zero.
                         */
-                       if (l) {
+                       if (l && l == link) {
                                hlist_del_rcu(&link->hash_node);
                                smap_release_sock(psock, link->sk);
                                free_htab_elem(htab, link);
@@ -570,14 +570,16 @@ static void free_bytes_sg(struct sock *sk, int bytes,
        md->sg_start = i;
 }
 
-static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
+static int free_sg(struct sock *sk, int start,
+                  struct sk_msg_buff *md, bool charge)
 {
        struct scatterlist *sg = md->sg_data;
        int i = start, free = 0;
 
        while (sg[i].length) {
                free += sg[i].length;
-               sk_mem_uncharge(sk, sg[i].length);
+               if (charge)
+                       sk_mem_uncharge(sk, sg[i].length);
                if (!md->skb)
                        put_page(sg_page(&sg[i]));
                sg[i].length = 0;
@@ -594,9 +596,9 @@ static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
        return free;
 }
 
-static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
 {
-       int free = free_sg(sk, md->sg_start, md);
+       int free = free_sg(sk, md->sg_start, md, charge);
 
        md->sg_start = md->sg_end;
        return free;
@@ -604,7 +606,7 @@ static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
 
 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
 {
-       return free_sg(sk, md->sg_curr, md);
+       return free_sg(sk, md->sg_curr, md, true);
 }
 
 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
@@ -718,7 +720,7 @@ static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
                list_add_tail(&r->list, &psock->ingress);
                sk->sk_data_ready(sk);
        } else {
-               free_start_sg(sk, r);
+               free_start_sg(sk, r, true);
                kfree(r);
        }
 
@@ -752,14 +754,10 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
                release_sock(sk);
        }
        smap_release_sock(psock, sk);
-       if (unlikely(err))
-               goto out;
-       return 0;
+       return err;
 out_rcu:
        rcu_read_unlock();
-out:
-       free_bytes_sg(NULL, send, md, false);
-       return err;
+       return 0;
 }
 
 static inline void bpf_md_init(struct smap_psock *psock)
@@ -822,7 +820,7 @@ more_data:
        case __SK_PASS:
                err = bpf_tcp_push(sk, send, m, flags, true);
                if (unlikely(err)) {
-                       *copied -= free_start_sg(sk, m);
+                       *copied -= free_start_sg(sk, m, true);
                        break;
                }
 
@@ -845,16 +843,17 @@ more_data:
                lock_sock(sk);
 
                if (unlikely(err < 0)) {
-                       free_start_sg(sk, m);
+                       int free = free_start_sg(sk, m, false);
+
                        psock->sg_size = 0;
                        if (!cork)
-                               *copied -= send;
+                               *copied -= free;
                } else {
                        psock->sg_size -= send;
                }
 
                if (cork) {
-                       free_start_sg(sk, m);
+                       free_start_sg(sk, m, true);
                        psock->sg_size = 0;
                        kfree(m);
                        m = NULL;
@@ -912,6 +911,8 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 
        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);
+       if (!skb_queue_empty(&sk->sk_receive_queue))
+               return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 
        rcu_read_lock();
        psock = smap_psock_sk(sk);
@@ -922,9 +923,6 @@ static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                goto out;
        rcu_read_unlock();
 
-       if (!skb_queue_empty(&sk->sk_receive_queue))
-               return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
-
        lock_sock(sk);
 bytes_ready:
        while (copied != len) {
@@ -1122,7 +1120,7 @@ wait_for_memory:
                err = sk_stream_wait_memory(sk, &timeo);
                if (err) {
                        if (m && m != psock->cork)
-                               free_start_sg(sk, m);
+                               free_start_sg(sk, m, true);
                        goto out_err;
                }
        }
@@ -1427,12 +1425,15 @@ out:
 static void smap_write_space(struct sock *sk)
 {
        struct smap_psock *psock;
+       void (*write_space)(struct sock *sk);
 
        rcu_read_lock();
        psock = smap_psock_sk(sk);
        if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
                schedule_work(&psock->tx_work);
+       write_space = psock->save_write_space;
        rcu_read_unlock();
+       write_space(sk);
 }
 
 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
@@ -1461,10 +1462,16 @@ static void smap_destroy_psock(struct rcu_head *rcu)
        schedule_work(&psock->gc_work);
 }
 
+static bool psock_is_smap_sk(struct sock *sk)
+{
+       return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
+}
+
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
 {
        if (refcount_dec_and_test(&psock->refcnt)) {
-               tcp_cleanup_ulp(sock);
+               if (psock_is_smap_sk(sock))
+                       tcp_cleanup_ulp(sock);
                write_lock_bh(&sock->sk_callback_lock);
                smap_stop_sock(psock, sock);
                write_unlock_bh(&sock->sk_callback_lock);
@@ -1578,13 +1585,13 @@ static void smap_gc_work(struct work_struct *w)
                bpf_prog_put(psock->bpf_tx_msg);
 
        if (psock->cork) {
-               free_start_sg(psock->sock, psock->cork);
+               free_start_sg(psock->sock, psock->cork, true);
                kfree(psock->cork);
        }
 
        list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
                list_del(&md->list);
-               free_start_sg(psock->sock, md);
+               free_start_sg(psock->sock, md, true);
                kfree(md);
        }
 
@@ -1891,6 +1898,10 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
         * doesn't update user data.
         */
        if (psock) {
+               if (!psock_is_smap_sk(sock)) {
+                       err = -EBUSY;
+                       goto out_progs;
+               }
                if (READ_ONCE(psock->bpf_parse) && parse) {
                        err = -EBUSY;
                        goto out_progs;
@@ -2140,7 +2151,9 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
                return ERR_PTR(-EPERM);
 
        /* check sanity of attributes */
-       if (attr->max_entries == 0 || attr->value_size != 4 ||
+       if (attr->max_entries == 0 ||
+           attr->key_size == 0 ||
+           attr->value_size != 4 ||
            attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
                return ERR_PTR(-EINVAL);
 
@@ -2267,8 +2280,10 @@ static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
        }
        l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
                             htab->map.numa_node);
-       if (!l_new)
+       if (!l_new) {
+               atomic_dec(&htab->count);
                return ERR_PTR(-ENOMEM);
+       }
 
        memcpy(l_new->key, key, key_size);
        l_new->sk = sk;