virtio/vsock: replace virtio_vsock_pkt with sk_buff
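
The conversion keeps the virtio_vsock header in front of the payload, but now in the skb's reserved headroom, with per-packet state (reply, tap_delivered) moved into skb->cb. A minimal sketch of the accessors this file depends on, assuming they are introduced in include/linux/virtio_vsock.h by the same series; names and layout here are illustrative, not the authoritative definitions:

#include <linux/skbuff.h>
#include <uapi/linux/virtio_vsock.h>	/* struct virtio_vsock_hdr */

/* Assumed layout: header kept in the reserved headroom, flags in skb->cb. */
#define VIRTIO_VSOCK_SKB_HEADROOM	(sizeof(struct virtio_vsock_hdr))

struct virtio_vsock_skb_cb {
	bool reply;		/* host must account this as a queued reply */
	bool tap_delivered;	/* already handed to monitoring (tap) devices */
};

#define VIRTIO_VSOCK_SKB_CB(skb)	((struct virtio_vsock_skb_cb *)((skb)->cb))

static inline struct virtio_vsock_hdr *virtio_vsock_hdr(struct sk_buff *skb)
{
	/* Header sits at the start of the buffer, in front of skb->data. */
	return (struct virtio_vsock_hdr *)skb->head;
}

static inline bool virtio_vsock_skb_reply(struct sk_buff *skb)
{
	return VIRTIO_VSOCK_SKB_CB(skb)->reply;
}

static inline void virtio_vsock_skb_clear_tap_delivered(struct sk_buff *skb)
{
	VIRTIO_VSOCK_SKB_CB(skb)->tap_delivered = false;
}
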
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index a2b3743723639cbbe8088c7c9ab420cdfb46a363..1f3b89c885cca9af78c8bc1031cbd87f25228373 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -51,8 +51,7 @@ struct vhost_vsock {
        struct hlist_node hash;
 
        struct vhost_work send_pkt_work;
-       spinlock_t send_pkt_list_lock;
-       struct list_head send_pkt_list; /* host->guest pending packets */
+       struct sk_buff_head send_pkt_queue; /* host->guest pending packets */
 
        atomic_t queued_replies;
 
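
The explicit spinlock/list pair becomes a single sk_buff_head, whose built-in lock takes over from send_pkt_list_lock; skb_queue_head_init() in vhost_vsock_dev_open() sets up both. A rough sketch of the queueing helpers used below, assuming they wrap the standard __skb_queue_*() primitives with the same spin_lock_bh() discipline the old code used:

static inline void virtio_vsock_skb_queue_tail(struct sk_buff_head *list,
					       struct sk_buff *skb)
{
	/* sk_buff_head bundles the lock with the list, so no extra spinlock. */
	spin_lock_bh(&list->lock);
	__skb_queue_tail(list, skb);
	spin_unlock_bh(&list->lock);
}

static inline void virtio_vsock_skb_queue_head(struct sk_buff_head *list,
					       struct sk_buff *skb)
{
	spin_lock_bh(&list->lock);
	__skb_queue_head(list, skb);
	spin_unlock_bh(&list->lock);
}

static inline struct sk_buff *virtio_vsock_skb_dequeue(struct sk_buff_head *list)
{
	struct sk_buff *skb;

	spin_lock_bh(&list->lock);
	skb = __skb_dequeue(list);
	spin_unlock_bh(&list->lock);

	return skb;
}
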
@@ -108,40 +107,31 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
        vhost_disable_notify(&vsock->dev, vq);
 
        do {
-               struct virtio_vsock_pkt *pkt;
+               struct virtio_vsock_hdr *hdr;
+               size_t iov_len, payload_len;
                struct iov_iter iov_iter;
+               u32 flags_to_restore = 0;
+               struct sk_buff *skb;
                unsigned out, in;
                size_t nbytes;
-               size_t iov_len, payload_len;
                int head;
-               u32 flags_to_restore = 0;
 
-               spin_lock_bh(&vsock->send_pkt_list_lock);
-               if (list_empty(&vsock->send_pkt_list)) {
-                       spin_unlock_bh(&vsock->send_pkt_list_lock);
+               skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue);
+
+               if (!skb) {
                        vhost_enable_notify(&vsock->dev, vq);
                        break;
                }
 
-               pkt = list_first_entry(&vsock->send_pkt_list,
-                                      struct virtio_vsock_pkt, list);
-               list_del_init(&pkt->list);
-               spin_unlock_bh(&vsock->send_pkt_list_lock);
-
                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0) {
-                       spin_lock_bh(&vsock->send_pkt_list_lock);
-                       list_add(&pkt->list, &vsock->send_pkt_list);
-                       spin_unlock_bh(&vsock->send_pkt_list_lock);
+                       virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
                        break;
                }
 
                if (head == vq->num) {
-                       spin_lock_bh(&vsock->send_pkt_list_lock);
-                       list_add(&pkt->list, &vsock->send_pkt_list);
-                       spin_unlock_bh(&vsock->send_pkt_list_lock);
-
+                       virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
                        /* We cannot finish yet if more buffers snuck in while
                         * re-enabling notify.
                         */
@@ -153,26 +143,27 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                }
 
                if (out) {
-                       virtio_transport_free_pkt(pkt);
+                       kfree_skb(skb);
                        vq_err(vq, "Expected 0 output buffers, got %u\n", out);
                        break;
                }
 
                iov_len = iov_length(&vq->iov[out], in);
-               if (iov_len < sizeof(pkt->hdr)) {
-                       virtio_transport_free_pkt(pkt);
+               if (iov_len < sizeof(*hdr)) {
+                       kfree_skb(skb);
                        vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
                        break;
                }
 
                iov_iter_init(&iov_iter, ITER_DEST, &vq->iov[out], in, iov_len);
-               payload_len = pkt->len - pkt->off;
+               payload_len = skb->len;
+               hdr = virtio_vsock_hdr(skb);
 
                /* If the packet is greater than the space available in the
                 * buffer, we split it using multiple buffers.
                 */
-               if (payload_len > iov_len - sizeof(pkt->hdr)) {
-                       payload_len = iov_len - sizeof(pkt->hdr);
+               if (payload_len > iov_len - sizeof(*hdr)) {
+                       payload_len = iov_len - sizeof(*hdr);
 
                        /* As we are copying pieces of large packet's buffer to
                         * small rx buffers, headers of packets in rx queue are
@@ -185,31 +176,30 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                         * bits set. After initialized header will be copied to
                         * rx buffer, these required bits will be restored.
                         */
-                       if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
-                               pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
+                       if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
+                               hdr->flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
                                flags_to_restore |= VIRTIO_VSOCK_SEQ_EOM;
 
-                               if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
-                                       pkt->hdr.flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
+                               if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) {
+                                       hdr->flags &= ~cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
                                        flags_to_restore |= VIRTIO_VSOCK_SEQ_EOR;
                                }
                        }
                }
 
                /* Set the correct length in the header */
-               pkt->hdr.len = cpu_to_le32(payload_len);
+               hdr->len = cpu_to_le32(payload_len);
 
-               nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
-               if (nbytes != sizeof(pkt->hdr)) {
-                       virtio_transport_free_pkt(pkt);
+               nbytes = copy_to_iter(hdr, sizeof(*hdr), &iov_iter);
+               if (nbytes != sizeof(*hdr)) {
+                       kfree_skb(skb);
                        vq_err(vq, "Faulted on copying pkt hdr\n");
                        break;
                }
 
-               nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
-                                     &iov_iter);
+               nbytes = copy_to_iter(skb->data, payload_len, &iov_iter);
                if (nbytes != payload_len) {
-                       virtio_transport_free_pkt(pkt);
+                       kfree_skb(skb);
                        vq_err(vq, "Faulted on copying pkt buf\n");
                        break;
                }
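
To make the EOM/EOR juggling above concrete, here is a hypothetical walk-through (all sizes invented, no other header flags assumed) of a SOCK_SEQPACKET message that is larger than the guest's rx buffers:

/* skb->len = 10240, hdr->flags = EOM | EOR, each rx buffer = hdr + 4096
 *
 * pass 1: payload_len capped to 4096; EOM and EOR cleared in hdr,
 *         flags_to_restore = EOM | EOR; header + 4096 bytes copied out;
 *         skb_pull() leaves 6144 bytes, so hdr->flags |= EOM | EOR,
 *         tap_delivered is cleared and the skb is requeued at the head.
 * pass 2: same as pass 1 with the next rx buffer; 2048 bytes remain.
 * pass 3: payload_len = 2048 fits, so no masking; the guest sees
 *         EOM | EOR only on this final fragment and consume_skb() runs.
 */
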
@@ -217,31 +207,28 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                /* Deliver to monitoring devices all packets that we
                 * will transmit.
                 */
-               virtio_transport_deliver_tap_pkt(pkt);
+               virtio_transport_deliver_tap_pkt(skb);
 
-               vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
+               vhost_add_used(vq, head, sizeof(*hdr) + payload_len);
                added = true;
 
-               pkt->off += payload_len;
+               skb_pull(skb, payload_len);
                total_len += payload_len;
 
                /* If we didn't send all the payload we can requeue the packet
                 * to send it with the next available buffer.
                 */
-               if (pkt->off < pkt->len) {
-                       pkt->hdr.flags |= cpu_to_le32(flags_to_restore);
+               if (skb->len > 0) {
+                       hdr->flags |= cpu_to_le32(flags_to_restore);
 
-                       /* We are queueing the same virtio_vsock_pkt to handle
+                       /* We are queueing the same skb to handle
                         * the remaining bytes, and we want to deliver it
                         * to monitoring devices in the next iteration.
                         */
-                       pkt->tap_delivered = false;
-
-                       spin_lock_bh(&vsock->send_pkt_list_lock);
-                       list_add(&pkt->list, &vsock->send_pkt_list);
-                       spin_unlock_bh(&vsock->send_pkt_list_lock);
+                       virtio_vsock_skb_clear_tap_delivered(skb);
+                       virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
                } else {
-                       if (pkt->reply) {
+                       if (virtio_vsock_skb_reply(skb)) {
                                int val;
 
                                val = atomic_dec_return(&vsock->queued_replies);
@@ -253,7 +240,7 @@ vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                                        restart_tx = true;
                        }
 
-                       virtio_transport_free_pkt(pkt);
+                       consume_skb(skb);
                }
        } while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
        if (added)
@@ -278,28 +265,26 @@ static void vhost_transport_send_pkt_work(struct vhost_work *work)
 }
 
 static int
-vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
+vhost_transport_send_pkt(struct sk_buff *skb)
 {
+       struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
        struct vhost_vsock *vsock;
-       int len = pkt->len;
+       int len = skb->len;
 
        rcu_read_lock();
 
        /* Find the vhost_vsock according to guest context id  */
-       vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
+       vsock = vhost_vsock_get(le64_to_cpu(hdr->dst_cid));
        if (!vsock) {
                rcu_read_unlock();
-               virtio_transport_free_pkt(pkt);
+               kfree_skb(skb);
                return -ENODEV;
        }
 
-       if (pkt->reply)
+       if (virtio_vsock_skb_reply(skb))
                atomic_inc(&vsock->queued_replies);
 
-       spin_lock_bh(&vsock->send_pkt_list_lock);
-       list_add_tail(&pkt->list, &vsock->send_pkt_list);
-       spin_unlock_bh(&vsock->send_pkt_list_lock);
-
+       virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb);
        vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
 
        rcu_read_unlock();
@@ -310,10 +295,8 @@ static int
 vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 {
        struct vhost_vsock *vsock;
-       struct virtio_vsock_pkt *pkt, *n;
        int cnt = 0;
        int ret = -ENODEV;
-       LIST_HEAD(freeme);
 
        rcu_read_lock();
 
@@ -322,20 +305,7 @@ vhost_transport_cancel_pkt(struct vsock_sock *vsk)
        if (!vsock)
                goto out;
 
-       spin_lock_bh(&vsock->send_pkt_list_lock);
-       list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
-               if (pkt->vsk != vsk)
-                       continue;
-               list_move(&pkt->list, &freeme);
-       }
-       spin_unlock_bh(&vsock->send_pkt_list_lock);
-
-       list_for_each_entry_safe(pkt, n, &freeme, list) {
-               if (pkt->reply)
-                       cnt++;
-               list_del(&pkt->list);
-               virtio_transport_free_pkt(pkt);
-       }
+       cnt = virtio_transport_purge_skbs(vsk, &vsock->send_pkt_queue);
 
        if (cnt) {
                struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
@@ -352,12 +322,14 @@ out:
        return ret;
 }
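
vhost_transport_cancel_pkt() now leans on virtio_transport_purge_skbs(), added elsewhere in this series, to drop every queued skb that belongs to the cancelled socket and report how many of them were replies so queued_replies can be rebalanced. A plausible shape for it, under the assumption (mine, not stated in this hunk) that each queued skb records its owning socket in skb->sk:

int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
{
	struct sk_buff_head freeme;
	struct sk_buff *skb, *tmp;
	int cnt = 0;

	skb_queue_head_init(&freeme);

	/* Unlink matching skbs under the queue lock, free them outside it. */
	spin_lock_bh(&queue->lock);
	skb_queue_walk_safe(queue, skb, tmp) {
		if (vsock_sk(skb->sk) != vsk)	/* vsock_sk() from af_vsock.h */
			continue;

		__skb_unlink(skb, queue);
		__skb_queue_tail(&freeme, skb);
	}
	spin_unlock_bh(&queue->lock);

	while ((skb = __skb_dequeue(&freeme)) != NULL) {
		if (virtio_vsock_skb_reply(skb))
			cnt++;
		kfree_skb(skb);
	}

	return cnt;
}
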
 
-static struct virtio_vsock_pkt *
-vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
+static struct sk_buff *
+vhost_vsock_alloc_skb(struct vhost_virtqueue *vq,
                      unsigned int out, unsigned int in)
 {
-       struct virtio_vsock_pkt *pkt;
+       struct virtio_vsock_hdr *hdr;
        struct iov_iter iov_iter;
+       struct sk_buff *skb;
+       size_t payload_len;
        size_t nbytes;
        size_t len;
 
@@ -366,50 +338,48 @@ vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
                return NULL;
        }
 
-       pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
-       if (!pkt)
+       len = iov_length(vq->iov, out);
+
+       /* len contains both payload and hdr */
+       skb = virtio_vsock_alloc_skb(len, GFP_KERNEL);
+       if (!skb)
                return NULL;
 
-       len = iov_length(vq->iov, out);
        iov_iter_init(&iov_iter, ITER_SOURCE, vq->iov, out, len);
 
-       nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
-       if (nbytes != sizeof(pkt->hdr)) {
+       hdr = virtio_vsock_hdr(skb);
+       nbytes = copy_from_iter(hdr, sizeof(*hdr), &iov_iter);
+       if (nbytes != sizeof(*hdr)) {
                vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
-                      sizeof(pkt->hdr), nbytes);
-               kfree(pkt);
+                      sizeof(*hdr), nbytes);
+               kfree_skb(skb);
                return NULL;
        }
 
-       pkt->len = le32_to_cpu(pkt->hdr.len);
+       payload_len = le32_to_cpu(hdr->len);
 
        /* No payload */
-       if (!pkt->len)
-               return pkt;
+       if (!payload_len)
+               return skb;
 
-       /* The pkt is too big */
-       if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
-               kfree(pkt);
+       /* The pkt is too big or the length in the header is invalid */
+       if (payload_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE ||
+           payload_len + sizeof(*hdr) > len) {
+               kfree_skb(skb);
                return NULL;
        }
 
-       pkt->buf = kvmalloc(pkt->len, GFP_KERNEL);
-       if (!pkt->buf) {
-               kfree(pkt);
-               return NULL;
-       }
+       virtio_vsock_skb_rx_put(skb);
 
-       pkt->buf_len = pkt->len;
-
-       nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
-       if (nbytes != pkt->len) {
-               vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
-                      pkt->len, nbytes);
-               virtio_transport_free_pkt(pkt);
+       nbytes = copy_from_iter(skb->data, payload_len, &iov_iter);
+       if (nbytes != payload_len) {
+               vq_err(vq, "Expected %zu byte payload, got %zu bytes\n",
+                      payload_len, nbytes);
+               kfree_skb(skb);
                return NULL;
        }
 
-       return pkt;
+       return skb;
 }
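
The allocation path now reserves header-sized headroom first and only exposes the payload once the header has been validated, so skb->len covers payload bytes only (what pkt->len used to be). Illustrative-only sketches of virtio_vsock_alloc_skb() and virtio_vsock_skb_rx_put() under that assumption:

static inline struct sk_buff *virtio_vsock_alloc_skb(unsigned int size, gfp_t mask)
{
	struct sk_buff *skb;

	/* size covers header + payload, as noted in the caller above. */
	if (size < VIRTIO_VSOCK_SKB_HEADROOM)
		return NULL;

	skb = alloc_skb(size, mask);
	if (!skb)
		return NULL;

	/* Leave room for the header; skb->data then points at the payload. */
	skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
	return skb;
}

static inline void virtio_vsock_skb_rx_put(struct sk_buff *skb)
{
	/* Expose the payload length announced by the (already copied) header. */
	u32 len = le32_to_cpu(virtio_vsock_hdr(skb)->len);

	if (len > 0)
		skb_put(skb, len);
}
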
 
 /* Is there space left for replies to rx packets? */
@@ -496,9 +466,9 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);
-       struct virtio_vsock_pkt *pkt;
        int head, pkts = 0, total_len = 0;
        unsigned int out, in;
+       struct sk_buff *skb;
        bool added = false;
 
        mutex_lock(&vq->mutex);
@@ -511,6 +481,8 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 
        vhost_disable_notify(&vsock->dev, vq);
        do {
+               struct virtio_vsock_hdr *hdr;
+
                if (!vhost_vsock_more_replies(vsock)) {
                        /* Stop tx until the device processes already
                         * pending replies.  Leave tx virtqueue
@@ -532,24 +504,26 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
                        break;
                }
 
-               pkt = vhost_vsock_alloc_pkt(vq, out, in);
-               if (!pkt) {
+               skb = vhost_vsock_alloc_skb(vq, out, in);
+               if (!skb) {
                        vq_err(vq, "Faulted on pkt\n");
                        continue;
                }
 
-               total_len += sizeof(pkt->hdr) + pkt->len;
+               total_len += sizeof(*hdr) + skb->len;
 
                /* Deliver to monitoring devices all received packets */
-               virtio_transport_deliver_tap_pkt(pkt);
+               virtio_transport_deliver_tap_pkt(skb);
+
+               hdr = virtio_vsock_hdr(skb);
 
                /* Only accept correctly addressed packets */
-               if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
-                   le64_to_cpu(pkt->hdr.dst_cid) ==
+               if (le64_to_cpu(hdr->src_cid) == vsock->guest_cid &&
+                   le64_to_cpu(hdr->dst_cid) ==
                    vhost_transport_get_local_cid())
-                       virtio_transport_recv_pkt(&vhost_transport, pkt);
+                       virtio_transport_recv_pkt(&vhost_transport, skb);
                else
-                       virtio_transport_free_pkt(pkt);
+                       kfree_skb(skb);
 
                vhost_add_used(vq, head, 0);
                added = true;
@@ -693,8 +667,7 @@ static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
                       VHOST_VSOCK_WEIGHT, true, NULL);
 
        file->private_data = vsock;
-       spin_lock_init(&vsock->send_pkt_list_lock);
-       INIT_LIST_HEAD(&vsock->send_pkt_list);
+       skb_queue_head_init(&vsock->send_pkt_queue);
        vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
        return 0;
 
@@ -760,16 +733,7 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
        vhost_vsock_flush(vsock);
        vhost_dev_stop(&vsock->dev);
 
-       spin_lock_bh(&vsock->send_pkt_list_lock);
-       while (!list_empty(&vsock->send_pkt_list)) {
-               struct virtio_vsock_pkt *pkt;
-
-               pkt = list_first_entry(&vsock->send_pkt_list,
-                               struct virtio_vsock_pkt, list);
-               list_del_init(&pkt->list);
-               virtio_transport_free_pkt(pkt);
-       }
-       spin_unlock_bh(&vsock->send_pkt_list_lock);
+       virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue);
 
        vhost_dev_cleanup(&vsock->dev);
        kfree(vsock->dev.vqs);
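
The release path's drain loop collapses into virtio_vsock_skb_queue_purge(); a minimal sketch, assuming it simply dequeues under the queue's own lock and frees whatever is still pending for the guest:

static inline void virtio_vsock_skb_queue_purge(struct sk_buff_head *list)
{
	struct sk_buff *skb;

	while ((skb = virtio_vsock_skb_dequeue(list)) != NULL)
		kfree_skb(skb);
}
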