d02c9b41a768e6a9a968e1b685a8ed2f390212e6
[linux-2.6-block.git] / net / vmw_vsock / virtio_transport_common.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * common code for virtio vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 #include <linux/spinlock.h>
10 #include <linux/module.h>
11 #include <linux/sched/signal.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18 #include <uapi/linux/vsockmon.h>
19
20 #include <net/sock.h>
21 #include <net/af_vsock.h>
22
23 #define CREATE_TRACE_POINTS
24 #include <trace/events/vsock_virtio_transport_common.h>
25
26 /* How long to wait for graceful shutdown of a connection */
27 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
28
29 /* Threshold for detecting small packets to copy */
30 #define GOOD_COPY_LEN  128
31
/* Return the ops of the registered virtio transport.
 *
 * NOTE(review): vsock_core_get_transport() may return NULL before a
 * transport registers; container_of() on NULL yields NULL only if
 * 'transport' is the first member of struct virtio_transport.  The NULL
 * check in virtio_transport_reset_no_sock() relies on that layout —
 * TODO confirm member ordering in the struct definition.
 */
static const struct virtio_transport *virtio_transport_get_ops(void)
{
	const struct vsock_transport *t = vsock_core_get_transport();

	return container_of(t, struct virtio_transport, transport);
}
38
39 static struct virtio_vsock_pkt *
40 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
41                            size_t len,
42                            u32 src_cid,
43                            u32 src_port,
44                            u32 dst_cid,
45                            u32 dst_port)
46 {
47         struct virtio_vsock_pkt *pkt;
48         int err;
49
50         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
51         if (!pkt)
52                 return NULL;
53
54         pkt->hdr.type           = cpu_to_le16(info->type);
55         pkt->hdr.op             = cpu_to_le16(info->op);
56         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
57         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
58         pkt->hdr.src_port       = cpu_to_le32(src_port);
59         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
60         pkt->hdr.flags          = cpu_to_le32(info->flags);
61         pkt->len                = len;
62         pkt->hdr.len            = cpu_to_le32(len);
63         pkt->reply              = info->reply;
64         pkt->vsk                = info->vsk;
65
66         if (info->msg && len > 0) {
67                 pkt->buf = kmalloc(len, GFP_KERNEL);
68                 if (!pkt->buf)
69                         goto out_pkt;
70
71                 pkt->buf_len = len;
72
73                 err = memcpy_from_msg(pkt->buf, info->msg, len);
74                 if (err)
75                         goto out;
76         }
77
78         trace_virtio_transport_alloc_pkt(src_cid, src_port,
79                                          dst_cid, dst_port,
80                                          len,
81                                          info->type,
82                                          info->op,
83                                          info->flags);
84
85         return pkt;
86
87 out:
88         kfree(pkt->buf);
89 out_pkt:
90         kfree(pkt);
91         return NULL;
92 }
93
/* Packet capture */
/* Build an sk_buff replica of @opaque (a struct virtio_vsock_pkt) for
 * vsockmon taps: an af_vsockmon_hdr, followed by the raw virtio vsock
 * header, followed by the (possibly partial) payload.  Allocated with
 * GFP_ATOMIC; returns NULL if the skb cannot be allocated.
 */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_pkt *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	payload_len = le32_to_cpu(pkt->hdr.len);
	payload_buf = pkt->buf + pkt->off;

	skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt->hdr.src_cid;
	hdr->src_port = pkt->hdr.src_port;
	hdr->dst_cid = pkt->hdr.dst_cid;
	hdr->dst_port = pkt->hdr.dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(pkt->hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	/* Collapse the virtio vsock ops into vsockmon's coarser categories. */
	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

	if (payload_len) {
		skb_put_data(skb, payload_buf, payload_len);
	}

	return skb;
}
156
/* Deliver a copy of @pkt to any attached vsockmon taps; the skb is built
 * lazily by virtio_transport_build_skb() only if a tap is present.
 */
void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
	vsock_deliver_tap(virtio_transport_build_skb, pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
162
/* Common send path: resolve source/destination addressing, clamp the
 * payload length to the peer's available credit, allocate the packet and
 * hand it to the registered transport's send_pkt().  Returns 0 for a
 * suppressed zero-length OP_RW, a negative errno on failure, or the
 * transport's send_pkt() result otherwise.
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	src_cid = vm_sockets_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	/* remote_cid == 0 means "address the socket's connected peer" */
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		/* Return the credit reserved by get_credit() above. */
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return virtio_transport_get_ops()->send_pkt(pkt);
}
206
207 static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
208                                         struct virtio_vsock_pkt *pkt)
209 {
210         if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
211                 return false;
212
213         vvs->rx_bytes += pkt->len;
214         return true;
215 }
216
/* Account for @pkt's payload leaving the RX buffer: shrink rx_bytes and
 * advance fwd_cnt, which is later advertised to the peer for credit
 * accounting.  Called with vvs->rx_lock held (see
 * virtio_transport_stream_do_dequeue()).
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}
223
/* Stamp the current credit state (fwd_cnt, buf_alloc) into the header of
 * an outgoing packet.  rx_lock protects fwd_cnt/buf_alloc; last_fwd_cnt
 * records what was last advertised so the dequeue path can throttle
 * explicit credit-update messages.
 */
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->rx_lock);
	vvs->last_fwd_cnt = vvs->fwd_cnt;
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
233
234 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
235 {
236         u32 ret;
237
238         spin_lock_bh(&vvs->tx_lock);
239         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
240         if (ret > credit)
241                 ret = credit;
242         vvs->tx_cnt += ret;
243         spin_unlock_bh(&vvs->tx_lock);
244
245         return ret;
246 }
247 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
248
249 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
250 {
251         spin_lock_bh(&vvs->tx_lock);
252         vvs->tx_cnt -= credit;
253         spin_unlock_bh(&vvs->tx_lock);
254 }
255 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
256
/* Send a CREDIT_UPDATE control packet to the peer; the actual
 * fwd_cnt/buf_alloc values are filled in on the send path by
 * virtio_transport_inc_tx_pkt().
 *
 * NOTE(review): the @hdr argument is never used and every caller in
 * this file passes NULL — candidate for removal along with call sites.
 */
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
269
/* MSG_PEEK path: copy up to @len bytes of queued RX data into @msg
 * without consuming it (rx_queue, per-packet offsets and rx_bytes stay
 * untouched, so no credit update is triggered).  Returns the number of
 * bytes copied, or a negative errno if nothing was copied.
 *
 * NOTE(review): while rx_lock is dropped for memcpy_to_msg(), the RX
 * side may still append to rx_queue (and recv_enqueue may grow the tail
 * packet's buffer); the comment below only argues dequeue cannot race.
 * Confirm that continuing the list walk after re-locking is safe against
 * those concurrent tail updates.
 */
static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0, off;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);

	list_for_each_entry(pkt, &vvs->rx_queue, list) {
		off = pkt->off;

		if (total == len)
			break;

		while (total < len && off < pkt->len) {
			bytes = len - total;
			if (bytes > pkt->len - off)
				bytes = pkt->len - off;

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since memcpy_to_msg() may sleep.
			 */
			spin_unlock_bh(&vvs->rx_lock);

			err = memcpy_to_msg(msg, pkt->buf + off, bytes);
			if (err)
				goto out;

			spin_lock_bh(&vvs->rx_lock);

			total += bytes;
			off += bytes;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	/* A partial copy wins over the copy error. */
	if (total)
		err = total;
	return err;
}
318
/* Consume up to @len bytes of queued RX data into @msg.  Fully-read
 * packets are unlinked and freed, updating rx_bytes/fwd_cnt via
 * virtio_transport_dec_rx_pkt().  A credit update is sent only when the
 * unadvertised free space drops below one max packet, to limit control
 * traffic.  Returns bytes copied, or a negative errno if none were.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	u32 free_space;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}

	/* Bytes freed since the last advertised fwd_cnt. */
	free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values.
	 */
	if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		virtio_transport_send_credit_update(vsk,
						    VIRTIO_VSOCK_TYPE_STREAM,
						    NULL);
	}

	return total;

out:
	/* A partial copy wins over the copy error. */
	if (total)
		err = total;
	return err;
}
384
385 ssize_t
386 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
387                                 struct msghdr *msg,
388                                 size_t len, int flags)
389 {
390         if (flags & MSG_PEEK)
391                 return virtio_transport_stream_do_peek(vsk, msg, len);
392         else
393                 return virtio_transport_stream_do_dequeue(vsk, msg, len);
394 }
395 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
396
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
405
406 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
407 {
408         struct virtio_vsock_sock *vvs = vsk->trans;
409         s64 bytes;
410
411         spin_lock_bh(&vvs->rx_lock);
412         bytes = vvs->rx_bytes;
413         spin_unlock_bh(&vvs->rx_lock);
414
415         return bytes;
416 }
417 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
418
/* Bytes of the peer's receive buffer still available to us
 * (advertised buf_alloc minus bytes in flight).  Called with
 * vvs->tx_lock held (see virtio_transport_stream_has_space()).
 *
 * NOTE(review): the subtraction is computed in u32 and only then
 * assigned to the s64 'bytes', so the result is never negative and the
 * clamp below looks unreachable; if peer_buf_alloc shrank, the value
 * would wrap to a huge positive number instead — verify intent.
 */
static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}
430
/* Locked wrapper around virtio_transport_has_space(): snapshot of the
 * writable space on @vsk under tx_lock.
 */
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
443
444 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
445                                     struct vsock_sock *psk)
446 {
447         struct virtio_vsock_sock *vvs;
448
449         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
450         if (!vvs)
451                 return -ENOMEM;
452
453         vsk->trans = vvs;
454         vvs->vsk = vsk;
455         if (psk) {
456                 struct virtio_vsock_sock *ptrans = psk->trans;
457
458                 vvs->buf_size   = ptrans->buf_size;
459                 vvs->buf_size_min = ptrans->buf_size_min;
460                 vvs->buf_size_max = ptrans->buf_size_max;
461                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
462         } else {
463                 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
464                 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
465                 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
466         }
467
468         vvs->buf_alloc = vvs->buf_size;
469
470         spin_lock_init(&vvs->rx_lock);
471         spin_lock_init(&vvs->tx_lock);
472         INIT_LIST_HEAD(&vvs->rx_queue);
473
474         return 0;
475 }
476 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
477
/* Current receive buffer size for @vsk. */
u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

/* Lower bound the buffer size may be set to. */
u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

/* Upper bound the buffer size may be set to. */
u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
501
/* Set the receive buffer size, clamped to VIRTIO_VSOCK_MAX_BUF_SIZE.
 * The min/max bounds are widened to include the new value, buf_alloc is
 * updated, and the new allocation is advertised to the peer via a
 * credit update.
 */
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	/* Widen the bounds rather than reject out-of-range values. */
	if (val < vvs->buf_size_min)
		vvs->buf_size_min = val;
	if (val > vvs->buf_size_max)
		vvs->buf_size_max = val;
	vvs->buf_size = val;
	vvs->buf_alloc = val;

	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

/* Set the minimum buffer size; raises buf_size if it falls below the
 * new minimum.  Clamped to VIRTIO_VSOCK_MAX_BUF_SIZE.
 */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

/* Set the maximum buffer size; lowers buf_size if it exceeds the new
 * maximum.  Clamped to VIRTIO_VSOCK_MAX_BUF_SIZE.
 */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
543
544 int
545 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
546                                 size_t target,
547                                 bool *data_ready_now)
548 {
549         if (vsock_stream_has_data(vsk))
550                 *data_ready_now = true;
551         else
552                 *data_ready_now = false;
553
554         return 0;
555 }
556 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
557
/* Poll-writable notification: writable iff the peer advertises free
 * space.  @target is ignored by this transport.
 *
 * NOTE(review): when vsock_stream_has_space() returns a negative value,
 * *space_avail_now is deliberately left untouched (caller's previous
 * value stands) — confirm this is the intended contract.
 */
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
574
/* The notify_* callbacks below are required by the vsock core but the
 * virtio transport needs no per-operation bookkeeping, so they are all
 * no-ops returning success.
 */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
631
/* Receive high-water mark: the configured buffer size. */
u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

/* Virtio stream sockets are always considered active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

/* No per-address restrictions: any CID/port may use stream sockets. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

/* Datagram sockets are not supported by the virtio transport. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

/* Datagram sockets are not supported by the virtio transport. */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
664
/* Initiate a stream connection by sending an OP_REQUEST to the peer;
 * the handshake completes in virtio_transport_recv_connecting().
 */
int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);
676
/* Send an OP_SHUTDOWN to the peer, translating the socket-layer @mode
 * (RCV_SHUTDOWN/SEND_SHUTDOWN bits) into the corresponding virtio vsock
 * shutdown flags.
 */
int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
692
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
702
/* Queue up to @len bytes from @msg for transmission as an OP_RW packet.
 * The amount actually sent may be limited by the peer's credit; see
 * virtio_transport_send_pkt_info().
 */
ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
719
720 void virtio_transport_destruct(struct vsock_sock *vsk)
721 {
722         struct virtio_vsock_sock *vvs = vsk->trans;
723
724         kfree(vvs);
725 }
726 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
727
/* Send an OP_RST for @vsk.  @pkt is the packet being responded to (or
 * NULL for a locally-initiated reset); it marks the RST as a reply and
 * suppresses RST-for-RST loops.
 */
static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}
744
/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
	const struct virtio_transport *t;
	struct virtio_vsock_pkt *reply;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	/* Swap src and dst: the RST goes back to the sender. */
	reply = virtio_transport_alloc_pkt(&info, 0,
					   le64_to_cpu(pkt->hdr.dst_cid),
					   le32_to_cpu(pkt->hdr.dst_port),
					   le64_to_cpu(pkt->hdr.src_cid),
					   le32_to_cpu(pkt->hdr.src_port));
	if (!reply)
		return -ENOMEM;

	/* NOTE(review): this NULL check is only meaningful if
	 * container_of() in virtio_transport_get_ops() maps a NULL
	 * vsock transport back to NULL (i.e. 'transport' is the first
	 * member); also the reply is allocated before the ops check —
	 * could be checked first to avoid the allocation. TODO confirm.
	 */
	t = virtio_transport_get_ops();
	if (!t) {
		virtio_transport_free_pkt(reply);
		return -ENOTCONN;
	}

	return t->send_pkt(reply);
}
778
/* Block (interruptibly) for up to @timeout jiffies waiting for the
 * connection to finish closing, i.e. for SOCK_DONE to be set.  Returns
 * early on signal or timeout.  Called from virtio_transport_close()
 * with the socket lock held (sk_wait_event drops/retakes it).
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}
795
/* Mark the connection as fully shut down and wake waiters.  If a close
 * timeout was scheduled, optionally cancel it and drop the reference it
 * held; the socket is removed from the bound/connected tables here in
 * that case.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	/* Keep the socket readable until queued data is drained. */
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
817
/* Delayed-work handler: the peer did not complete a graceful shutdown
 * within VSOCK_CLOSE_TIMEOUT, so reset the connection and force the
 * close.  The sock_hold/sock_put pair here is local to the handler; the
 * reference taken when the work was scheduled is dropped by
 * virtio_transport_do_close().
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}
838
/* User context, vsk->sk is locked */
/* Begin closing a stream socket.  Returns true when the socket can be
 * removed immediately; returns false when a graceful shutdown is in
 * flight — in that case a timeout worker is scheduled (holding a sock
 * reference) to force the close if the peer never answers.
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	/* SO_LINGER: wait (unless the process is exiting) for the peer
	 * to acknowledge the shutdown.
	 */
	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}
871
/* Release callback: start closing (stream sockets may defer removal to
 * the close timeout), free all packets still queued for receive, and
 * remove the socket unless the deferred close path owns that step.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt, *tmp;
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);

	/* Drop any undelivered RX data so it cannot leak. */
	list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
893
/* Handle a packet received while the socket is in the connecting state:
 * RESPONSE completes the handshake, RST resets the connection (no error
 * returned to the caller), anything else is a protocol violation that
 * resets the connection and reports -EINVAL.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
929
/* Queue a received RW packet on the socket's receive queue.
 *
 * The payload is first accounted against the receive buffer credit
 * (virtio_transport_inc_rx_pkt()); if no credit is available the packet
 * is dropped and freed.  As a memory optimization, small payloads
 * (<= GOOD_COPY_LEN) are appended into the buffer of the last queued
 * packet when it has room, instead of keeping a whole pkt/buffer pair
 * alive for a few bytes.  On every exit path the packet is either owned
 * by the rx queue or freed here.
 */
static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
			      struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool can_enqueue, free_pkt = false;

	pkt->len = le32_to_cpu(pkt->hdr.len);
	pkt->off = 0;

	spin_lock_bh(&vvs->rx_lock);

	can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
	if (!can_enqueue) {
		/* No receive credit left: drop the packet. */
		free_pkt = true;
		goto out;
	}

	/* Try to copy small packets into the buffer of last packet queued,
	 * to avoid wasting memory queueing the entire buffer with a small
	 * payload.
	 */
	if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
		struct virtio_vsock_pkt *last_pkt;

		last_pkt = list_last_entry(&vvs->rx_queue,
					   struct virtio_vsock_pkt, list);

		/* If there is space in the last packet queued, we copy the
		 * new packet in its buffer.
		 */
		if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
			memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
			       pkt->len);
			last_pkt->len += pkt->len;
			free_pkt = true;
			goto out;
		}
	}

	list_add_tail(&pkt->list, &vvs->rx_queue);

out:
	spin_unlock_bh(&vvs->rx_lock);
	if (free_pkt)
		virtio_transport_free_pkt(pkt);
}
977
/* Handle a packet received on an ESTABLISHED socket.
 *
 * RW packets are handed to the rx queue (ownership of @pkt transfers,
 * so we return early without freeing it); every other op consumes the
 * packet, which is freed before returning.
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		/* pkt is now owned by the rx queue (or freed there). */
		virtio_transport_recv_enqueue(vsk, pkt);
		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Peer granted more credit: wake up blocked writers. */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		/* Both directions shut down and nothing left to read:
		 * the connection is done.
		 */
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0) {
			sock_set_flag(sk, SOCK_DONE);
			sk->sk_state = TCP_CLOSING;
		}
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}
1017
1018 static void
1019 virtio_transport_recv_disconnecting(struct sock *sk,
1020                                     struct virtio_vsock_pkt *pkt)
1021 {
1022         struct vsock_sock *vsk = vsock_sk(sk);
1023
1024         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
1025                 virtio_transport_do_close(vsk, true);
1026 }
1027
1028 static int
1029 virtio_transport_send_response(struct vsock_sock *vsk,
1030                                struct virtio_vsock_pkt *pkt)
1031 {
1032         struct virtio_vsock_pkt_info info = {
1033                 .op = VIRTIO_VSOCK_OP_RESPONSE,
1034                 .type = VIRTIO_VSOCK_TYPE_STREAM,
1035                 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
1036                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
1037                 .reply = true,
1038                 .vsk = vsk,
1039         };
1040
1041         return virtio_transport_send_pkt_info(vsk, &info);
1042 }
1043
1044 /* Handle server socket */
1045 static int
1046 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
1047 {
1048         struct vsock_sock *vsk = vsock_sk(sk);
1049         struct vsock_sock *vchild;
1050         struct sock *child;
1051
1052         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
1053                 virtio_transport_reset(vsk, pkt);
1054                 return -EINVAL;
1055         }
1056
1057         if (sk_acceptq_is_full(sk)) {
1058                 virtio_transport_reset(vsk, pkt);
1059                 return -ENOMEM;
1060         }
1061
1062         child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
1063                                sk->sk_type, 0);
1064         if (!child) {
1065                 virtio_transport_reset(vsk, pkt);
1066                 return -ENOMEM;
1067         }
1068
1069         sk->sk_ack_backlog++;
1070
1071         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
1072
1073         child->sk_state = TCP_ESTABLISHED;
1074
1075         vchild = vsock_sk(child);
1076         vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
1077                         le32_to_cpu(pkt->hdr.dst_port));
1078         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
1079                         le32_to_cpu(pkt->hdr.src_port));
1080
1081         vsock_insert_connected(vchild);
1082         vsock_enqueue_accept(sk, child);
1083         virtio_transport_send_response(vchild, pkt);
1084
1085         release_sock(child);
1086
1087         sk->sk_data_ready(sk);
1088         return 0;
1089 }
1090
1091 static bool virtio_transport_space_update(struct sock *sk,
1092                                           struct virtio_vsock_pkt *pkt)
1093 {
1094         struct vsock_sock *vsk = vsock_sk(sk);
1095         struct virtio_vsock_sock *vvs = vsk->trans;
1096         bool space_available;
1097
1098         /* buf_alloc and fwd_cnt is always included in the hdr */
1099         spin_lock_bh(&vvs->tx_lock);
1100         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
1101         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
1102         space_available = virtio_transport_has_space(vsk);
1103         spin_unlock_bh(&vvs->tx_lock);
1104         return space_available;
1105 }
1106
1107 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
1108  * lock.
1109  */
1110 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
1111 {
1112         struct sockaddr_vm src, dst;
1113         struct vsock_sock *vsk;
1114         struct sock *sk;
1115         bool space_available;
1116
1117         vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
1118                         le32_to_cpu(pkt->hdr.src_port));
1119         vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
1120                         le32_to_cpu(pkt->hdr.dst_port));
1121
1122         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1123                                         dst.svm_cid, dst.svm_port,
1124                                         le32_to_cpu(pkt->hdr.len),
1125                                         le16_to_cpu(pkt->hdr.type),
1126                                         le16_to_cpu(pkt->hdr.op),
1127                                         le32_to_cpu(pkt->hdr.flags),
1128                                         le32_to_cpu(pkt->hdr.buf_alloc),
1129                                         le32_to_cpu(pkt->hdr.fwd_cnt));
1130
1131         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
1132                 (void)virtio_transport_reset_no_sock(pkt);
1133                 goto free_pkt;
1134         }
1135
1136         /* The socket must be in connected or bound table
1137          * otherwise send reset back
1138          */
1139         sk = vsock_find_connected_socket(&src, &dst);
1140         if (!sk) {
1141                 sk = vsock_find_bound_socket(&dst);
1142                 if (!sk) {
1143                         (void)virtio_transport_reset_no_sock(pkt);
1144                         goto free_pkt;
1145                 }
1146         }
1147
1148         vsk = vsock_sk(sk);
1149
1150         space_available = virtio_transport_space_update(sk, pkt);
1151
1152         lock_sock(sk);
1153
1154         /* Update CID in case it has changed after a transport reset event */
1155         vsk->local_addr.svm_cid = dst.svm_cid;
1156
1157         if (space_available)
1158                 sk->sk_write_space(sk);
1159
1160         switch (sk->sk_state) {
1161         case TCP_LISTEN:
1162                 virtio_transport_recv_listen(sk, pkt);
1163                 virtio_transport_free_pkt(pkt);
1164                 break;
1165         case TCP_SYN_SENT:
1166                 virtio_transport_recv_connecting(sk, pkt);
1167                 virtio_transport_free_pkt(pkt);
1168                 break;
1169         case TCP_ESTABLISHED:
1170                 virtio_transport_recv_connected(sk, pkt);
1171                 break;
1172         case TCP_CLOSING:
1173                 virtio_transport_recv_disconnecting(sk, pkt);
1174                 virtio_transport_free_pkt(pkt);
1175                 break;
1176         default:
1177                 virtio_transport_free_pkt(pkt);
1178                 break;
1179         }
1180         release_sock(sk);
1181
1182         /* Release refcnt obtained when we fetched this socket out of the
1183          * bound or connected list.
1184          */
1185         sock_put(sk);
1186         return;
1187
1188 free_pkt:
1189         virtio_transport_free_pkt(pkt);
1190 }
1191 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1192
/* Free a packet and its payload buffer.  The buffer must be released
 * first since it is reached through the packet itself.
 */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1199
1200 MODULE_LICENSE("GPL v2");
1201 MODULE_AUTHOR("Asias He");
1202 MODULE_DESCRIPTION("common code for virtio vsock");