drivers/vhost/vhost.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (C) 2009 Red Hat, Inc.
3  * Copyright (C) 2006 Rusty Russell IBM Corporation
4  *
5  * Author: Michael S. Tsirkin <mst@redhat.com>
6  *
7  * Inspiration, some code, and most witty comments come from
8  * Documentation/virtual/lguest/lguest.c, by Rusty Russell
9  *
10  * Generic code for virtio server in host kernel.
11  */
12
13 #include <linux/eventfd.h>
14 #include <linux/vhost.h>
15 #include <linux/uio.h>
16 #include <linux/mm.h>
17 #include <linux/miscdevice.h>
18 #include <linux/mutex.h>
19 #include <linux/poll.h>
20 #include <linux/file.h>
21 #include <linux/highmem.h>
22 #include <linux/slab.h>
23 #include <linux/vmalloc.h>
24 #include <linux/kthread.h>
25 #include <linux/module.h>
26 #include <linux/sort.h>
27 #include <linux/sched/mm.h>
28 #include <linux/sched/signal.h>
29 #include <linux/sched/vhost_task.h>
30 #include <linux/interval_tree_generic.h>
31 #include <linux/nospec.h>
32 #include <linux/kcov.h>
33
34 #include "vhost.h"
35
36 static ushort max_mem_regions = 64;
37 module_param(max_mem_regions, ushort, 0444);
38 MODULE_PARM_DESC(max_mem_regions,
39         "Maximum number of memory regions in memory map. (default: 64)");
40 static int max_iotlb_entries = 2048;
41 module_param(max_iotlb_entries, int, 0444);
42 MODULE_PARM_DESC(max_iotlb_entries,
43         "Maximum number of iotlb entries. (default: 2048)");
44
45 enum {
46         VHOST_MEMORY_F_LOG = 0x1,
47 };
48
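/* When VIRTIO_RING_F_EVENT_IDX is negotiated, used_event lives in the slot
 * just past the avail ring entries and avail_event in the slot just past the
 * used ring entries; these macros compute those userspace addresses.
 */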
49 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
50 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
51
52 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
53 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
54 {
55         vq->user_be = !virtio_legacy_is_little_endian();
56 }
57
58 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
59 {
60         vq->user_be = true;
61 }
62
63 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
64 {
65         vq->user_be = false;
66 }
67
68 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
69 {
70         struct vhost_vring_state s;
71
72         if (vq->private_data)
73                 return -EBUSY;
74
75         if (copy_from_user(&s, argp, sizeof(s)))
76                 return -EFAULT;
77
78         if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
79             s.num != VHOST_VRING_BIG_ENDIAN)
80                 return -EINVAL;
81
82         if (s.num == VHOST_VRING_BIG_ENDIAN)
83                 vhost_enable_cross_endian_big(vq);
84         else
85                 vhost_enable_cross_endian_little(vq);
86
87         return 0;
88 }
89
90 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
91                                    int __user *argp)
92 {
93         struct vhost_vring_state s = {
94                 .index = idx,
95                 .num = vq->user_be
96         };
97
98         if (copy_to_user(argp, &s, sizeof(s)))
99                 return -EFAULT;
100
101         return 0;
102 }
103
104 static void vhost_init_is_le(struct vhost_virtqueue *vq)
105 {
106         /* Note for legacy virtio: user_be is initialized at reset time
107          * according to the host endianness. If userspace does not set an
108          * explicit endianness, the default behavior is native endian, as
109          * expected by legacy virtio.
110          */
111         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
112 }
113 #else
114 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
115 {
116 }
117
118 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
119 {
120         return -ENOIOCTLCMD;
121 }
122
123 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
124                                    int __user *argp)
125 {
126         return -ENOIOCTLCMD;
127 }
128
129 static void vhost_init_is_le(struct vhost_virtqueue *vq)
130 {
131         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
132                 || virtio_legacy_is_little_endian();
133 }
134 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
135
136 static void vhost_reset_is_le(struct vhost_virtqueue *vq)
137 {
138         vhost_init_is_le(vq);
139 }
140
141 struct vhost_flush_struct {
142         struct vhost_work work;
143         struct completion wait_event;
144 };
145
146 static void vhost_flush_work(struct vhost_work *work)
147 {
148         struct vhost_flush_struct *s;
149
150         s = container_of(work, struct vhost_flush_struct, work);
151         complete(&s->wait_event);
152 }
153
154 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
155                             poll_table *pt)
156 {
157         struct vhost_poll *poll;
158
159         poll = container_of(pt, struct vhost_poll, table);
160         poll->wqh = wqh;
161         add_wait_queue(wqh, &poll->wait);
162 }
163
164 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
165                              void *key)
166 {
167         struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
168         struct vhost_work *work = &poll->work;
169
170         if (!(key_to_poll(key) & poll->mask))
171                 return 0;
172
173         if (!poll->dev->use_worker)
174                 work->fn(work);
175         else
176                 vhost_poll_queue(poll);
177
178         return 0;
179 }
180
181 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
182 {
183         clear_bit(VHOST_WORK_QUEUED, &work->flags);
184         work->fn = fn;
185 }
186 EXPORT_SYMBOL_GPL(vhost_work_init);
187
188 /* Init poll structure */
189 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
190                      __poll_t mask, struct vhost_dev *dev,
191                      struct vhost_virtqueue *vq)
192 {
193         init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
194         init_poll_funcptr(&poll->table, vhost_poll_func);
195         poll->mask = mask;
196         poll->dev = dev;
197         poll->wqh = NULL;
198         poll->vq = vq;
199
200         vhost_work_init(&poll->work, fn);
201 }
202 EXPORT_SYMBOL_GPL(vhost_poll_init);
203
204 /* Start polling a file. We add ourselves to the file's wait queue. The caller
205  * must keep a reference to the file until after vhost_poll_stop is called. */
206 int vhost_poll_start(struct vhost_poll *poll, struct file *file)
207 {
208         __poll_t mask;
209
210         if (poll->wqh)
211                 return 0;
212
213         mask = vfs_poll(file, &poll->table);
214         if (mask)
215                 vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
216         if (mask & EPOLLERR) {
217                 vhost_poll_stop(poll);
218                 return -EINVAL;
219         }
220
221         return 0;
222 }
223 EXPORT_SYMBOL_GPL(vhost_poll_start);
224
225 /* Stop polling a file. After this function returns, it becomes safe to drop the
226  * file reference. You must also flush afterwards. */
227 void vhost_poll_stop(struct vhost_poll *poll)
228 {
229         if (poll->wqh) {
230                 remove_wait_queue(poll->wqh, &poll->wait);
231                 poll->wqh = NULL;
232         }
233 }
234 EXPORT_SYMBOL_GPL(vhost_poll_stop);
235
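/* Queue work on a specific worker. The VHOST_WORK_QUEUED bit keeps an item
 * from being queued twice; the worker's vhost task is woken to run the list.
 */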
236 static void vhost_worker_queue(struct vhost_worker *worker,
237                                struct vhost_work *work)
238 {
239         if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
240                 /* We can only add the work to the list after we're
241                  * sure it was not in the list.
242                  * test_and_set_bit() implies a memory barrier.
243                  */
244                 llist_add(&work->node, &worker->work_list);
245                 vhost_task_wake(worker->vtsk);
246         }
247 }
248
249 bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
250 {
251         struct vhost_worker *worker;
252         bool queued = false;
253
254         rcu_read_lock();
255         worker = rcu_dereference(vq->worker);
256         if (worker) {
257                 queued = true;
258                 vhost_worker_queue(worker, work);
259         }
260         rcu_read_unlock();
261
262         return queued;
263 }
264 EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
265
266 void vhost_vq_flush(struct vhost_virtqueue *vq)
267 {
268         struct vhost_flush_struct flush;
269
270         init_completion(&flush.wait_event);
271         vhost_work_init(&flush.work, vhost_flush_work);
272
273         if (vhost_vq_work_queue(vq, &flush.work))
274                 wait_for_completion(&flush.wait_event);
275 }
276 EXPORT_SYMBOL_GPL(vhost_vq_flush);
277
278 /**
279  * vhost_worker_flush - flush a worker
280  * @worker: worker to flush
281  *
282  * This does not use RCU to protect the worker, so the device or worker
283  * mutex must be held.
284  */
285 static void vhost_worker_flush(struct vhost_worker *worker)
286 {
287         struct vhost_flush_struct flush;
288
289         init_completion(&flush.wait_event);
290         vhost_work_init(&flush.work, vhost_flush_work);
291
292         vhost_worker_queue(worker, &flush.work);
293         wait_for_completion(&flush.wait_event);
294 }
295
296 void vhost_dev_flush(struct vhost_dev *dev)
297 {
298         struct vhost_worker *worker;
299         unsigned long i;
300
301         xa_for_each(&dev->worker_xa, i, worker) {
302                 mutex_lock(&worker->mutex);
303                 if (!worker->attachment_cnt) {
304                         mutex_unlock(&worker->mutex);
305                         continue;
306                 }
307                 vhost_worker_flush(worker);
308                 mutex_unlock(&worker->mutex);
309         }
310 }
311 EXPORT_SYMBOL_GPL(vhost_dev_flush);
312
313 /* A lockless hint for busy polling code to exit the loop */
314 bool vhost_vq_has_work(struct vhost_virtqueue *vq)
315 {
316         struct vhost_worker *worker;
317         bool has_work = false;
318
319         rcu_read_lock();
320         worker = rcu_dereference(vq->worker);
321         if (worker && !llist_empty(&worker->work_list))
322                 has_work = true;
323         rcu_read_unlock();
324
325         return has_work;
326 }
327 EXPORT_SYMBOL_GPL(vhost_vq_has_work);
328
329 void vhost_poll_queue(struct vhost_poll *poll)
330 {
331         vhost_vq_work_queue(poll->vq, &poll->work);
332 }
333 EXPORT_SYMBOL_GPL(vhost_poll_queue);
334
335 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
336 {
337         int j;
338
339         for (j = 0; j < VHOST_NUM_ADDRS; j++)
340                 vq->meta_iotlb[j] = NULL;
341 }
342
343 static void vhost_vq_meta_reset(struct vhost_dev *d)
344 {
345         int i;
346
347         for (i = 0; i < d->nvqs; ++i)
348                 __vhost_vq_meta_reset(d->vqs[i]);
349 }
350
351 static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
352 {
353         call_ctx->ctx = NULL;
354         memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
355 }
356
357 bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
358 {
359         return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
360 }
361 EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
362
363 static void vhost_vq_reset(struct vhost_dev *dev,
364                            struct vhost_virtqueue *vq)
365 {
366         vq->num = 1;
367         vq->desc = NULL;
368         vq->avail = NULL;
369         vq->used = NULL;
370         vq->last_avail_idx = 0;
371         vq->avail_idx = 0;
372         vq->last_used_idx = 0;
373         vq->signalled_used = 0;
374         vq->signalled_used_valid = false;
375         vq->used_flags = 0;
376         vq->log_used = false;
377         vq->log_addr = -1ull;
378         vq->private_data = NULL;
379         vq->acked_features = 0;
380         vq->acked_backend_features = 0;
381         vq->log_base = NULL;
382         vq->error_ctx = NULL;
383         vq->kick = NULL;
384         vq->log_ctx = NULL;
385         vhost_disable_cross_endian(vq);
386         vhost_reset_is_le(vq);
387         vq->busyloop_timeout = 0;
388         vq->umem = NULL;
389         vq->iotlb = NULL;
390         rcu_assign_pointer(vq->worker, NULL);
391         vhost_vring_call_reset(&vq->call_ctx);
392         __vhost_vq_meta_reset(vq);
393 }
394
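/* One pass of a worker's vhost task: drain the work list and run each item
 * under the worker's kcov remote handle. Returns whether any work was found,
 * which lets the vhost task layer sleep when the worker is idle.
 */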
395 static bool vhost_worker(void *data)
396 {
397         struct vhost_worker *worker = data;
398         struct vhost_work *work, *work_next;
399         struct llist_node *node;
400
401         node = llist_del_all(&worker->work_list);
402         if (node) {
403                 __set_current_state(TASK_RUNNING);
404
405                 node = llist_reverse_order(node);
406                 /* make sure flag is seen after deletion */
407                 smp_wmb();
408                 llist_for_each_entry_safe(work, work_next, node, node) {
409                         clear_bit(VHOST_WORK_QUEUED, &work->flags);
410                         kcov_remote_start_common(worker->kcov_handle);
411                         work->fn(work);
412                         kcov_remote_stop();
413                         cond_resched();
414                 }
415         }
416
417         return !!node;
418 }
419
420 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
421 {
422         kfree(vq->indirect);
423         vq->indirect = NULL;
424         kfree(vq->log);
425         vq->log = NULL;
426         kfree(vq->heads);
427         vq->heads = NULL;
428 }
429
430 /* Helper to allocate iovec buffers for all vqs. */
431 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
432 {
433         struct vhost_virtqueue *vq;
434         int i;
435
436         for (i = 0; i < dev->nvqs; ++i) {
437                 vq = dev->vqs[i];
438                 vq->indirect = kmalloc_array(UIO_MAXIOV,
439                                              sizeof(*vq->indirect),
440                                              GFP_KERNEL);
441                 vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
442                                         GFP_KERNEL);
443                 vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
444                                           GFP_KERNEL);
445                 if (!vq->indirect || !vq->log || !vq->heads)
446                         goto err_nomem;
447         }
448         return 0;
449
450 err_nomem:
451         for (; i >= 0; --i)
452                 vhost_vq_free_iovecs(dev->vqs[i]);
453         return -ENOMEM;
454 }
455
456 static void vhost_dev_free_iovecs(struct vhost_dev *dev)
457 {
458         int i;
459
460         for (i = 0; i < dev->nvqs; ++i)
461                 vhost_vq_free_iovecs(dev->vqs[i]);
462 }
463
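/* Enforce the per-invocation packet and byte budgets: once either limit is
 * hit, requeue the handler via vhost_poll_queue() and return true so the
 * caller stops processing and yields.
 */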
464 bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
465                           int pkts, int total_len)
466 {
467         struct vhost_dev *dev = vq->dev;
468
469         if ((dev->byte_weight && total_len >= dev->byte_weight) ||
470             pkts >= dev->weight) {
471                 vhost_poll_queue(&vq->poll);
472                 return true;
473         }
474
475         return false;
476 }
477 EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
478
479 static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
480                                    unsigned int num)
481 {
482         size_t event __maybe_unused =
483                vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
484
485         return size_add(struct_size(vq->avail, ring, num), event);
486 }
487
488 static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
489                                   unsigned int num)
490 {
491         size_t event __maybe_unused =
492                vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
493
494         return size_add(struct_size(vq->used, ring, num), event);
495 }
496
497 static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
498                                   unsigned int num)
499 {
500         return sizeof(*vq->desc) * num;
501 }
502
503 void vhost_dev_init(struct vhost_dev *dev,
504                     struct vhost_virtqueue **vqs, int nvqs,
505                     int iov_limit, int weight, int byte_weight,
506                     bool use_worker,
507                     int (*msg_handler)(struct vhost_dev *dev, u32 asid,
508                                        struct vhost_iotlb_msg *msg))
509 {
510         struct vhost_virtqueue *vq;
511         int i;
512
513         dev->vqs = vqs;
514         dev->nvqs = nvqs;
515         mutex_init(&dev->mutex);
516         dev->log_ctx = NULL;
517         dev->umem = NULL;
518         dev->iotlb = NULL;
519         dev->mm = NULL;
520         dev->iov_limit = iov_limit;
521         dev->weight = weight;
522         dev->byte_weight = byte_weight;
523         dev->use_worker = use_worker;
524         dev->msg_handler = msg_handler;
525         init_waitqueue_head(&dev->wait);
526         INIT_LIST_HEAD(&dev->read_list);
527         INIT_LIST_HEAD(&dev->pending_list);
528         spin_lock_init(&dev->iotlb_lock);
529         xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);
530
531         for (i = 0; i < dev->nvqs; ++i) {
532                 vq = dev->vqs[i];
533                 vq->log = NULL;
534                 vq->indirect = NULL;
535                 vq->heads = NULL;
536                 vq->dev = dev;
537                 mutex_init(&vq->mutex);
538                 vhost_vq_reset(dev, vq);
539                 if (vq->handle_kick)
540                         vhost_poll_init(&vq->poll, vq->handle_kick,
541                                         EPOLLIN, dev, vq);
542         }
543 }
544 EXPORT_SYMBOL_GPL(vhost_dev_init);
545
546 /* Caller should have device mutex */
547 long vhost_dev_check_owner(struct vhost_dev *dev)
548 {
549         /* Are you the owner? If not, I don't think you mean to do that */
550         return dev->mm == current->mm ? 0 : -EPERM;
551 }
552 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
553
554 /* Caller should have device mutex */
555 bool vhost_dev_has_owner(struct vhost_dev *dev)
556 {
557         return dev->mm;
558 }
559 EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
560
561 static void vhost_attach_mm(struct vhost_dev *dev)
562 {
563         /* No owner, become one */
564         if (dev->use_worker) {
565                 dev->mm = get_task_mm(current);
566         } else {
567                 /* vDPA devices do not use a worker thread, so there is
568                  * no need to hold the address space for the mm. This helps
569                  * to avoid deadlocks in the case of mmap(), which may hold
570                  * a refcnt on the file and depend on the release method to
571                  * remove the vma.
572                  */
573                 dev->mm = current->mm;
574                 mmgrab(dev->mm);
575         }
576 }
577
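/* Drop the mm reference taken in vhost_attach_mm(): mmput() pairs with
 * get_task_mm() for devices using a worker, mmdrop() with mmgrab() otherwise.
 */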
578 static void vhost_detach_mm(struct vhost_dev *dev)
579 {
580         if (!dev->mm)
581                 return;
582
583         if (dev->use_worker)
584                 mmput(dev->mm);
585         else
586                 mmdrop(dev->mm);
587
588         dev->mm = NULL;
589 }
590
591 static void vhost_worker_destroy(struct vhost_dev *dev,
592                                  struct vhost_worker *worker)
593 {
594         if (!worker)
595                 return;
596
597         WARN_ON(!llist_empty(&worker->work_list));
598         xa_erase(&dev->worker_xa, worker->id);
599         vhost_task_stop(worker->vtsk);
600         kfree(worker);
601 }
602
603 static void vhost_workers_free(struct vhost_dev *dev)
604 {
605         struct vhost_worker *worker;
606         unsigned long i;
607
608         if (!dev->use_worker)
609                 return;
610
611         for (i = 0; i < dev->nvqs; i++)
612                 rcu_assign_pointer(dev->vqs[i]->worker, NULL);
613         /*
614          * Free the default worker we created and clean up workers that
615          * userspace created but could not clean up (it forgot or crashed).
616          */
617         xa_for_each(&dev->worker_xa, i, worker)
618                 vhost_worker_destroy(dev, worker);
619         xa_destroy(&dev->worker_xa);
620 }
621
622 static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
623 {
624         struct vhost_worker *worker;
625         struct vhost_task *vtsk;
626         char name[TASK_COMM_LEN];
627         int ret;
628         u32 id;
629
630         worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
631         if (!worker)
632                 return NULL;
633
634         snprintf(name, sizeof(name), "vhost-%d", current->pid);
635
636         vtsk = vhost_task_create(vhost_worker, worker, name);
637         if (!vtsk)
638                 goto free_worker;
639
640         mutex_init(&worker->mutex);
641         init_llist_head(&worker->work_list);
642         worker->kcov_handle = kcov_common_handle();
643         worker->vtsk = vtsk;
644
645         vhost_task_start(vtsk);
646
647         ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
648         if (ret < 0)
649                 goto stop_worker;
650         worker->id = id;
651
652         return worker;
653
654 stop_worker:
655         vhost_task_stop(vtsk);
656 free_worker:
657         kfree(worker);
658         return NULL;
659 }
660
661 /* Caller must have device mutex */
662 static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
663                                      struct vhost_worker *worker)
664 {
665         struct vhost_worker *old_worker;
666
667         old_worker = rcu_dereference_check(vq->worker,
668                                            lockdep_is_held(&vq->dev->mutex));
669
670         mutex_lock(&worker->mutex);
671         worker->attachment_cnt++;
672         mutex_unlock(&worker->mutex);
673         rcu_assign_pointer(vq->worker, worker);
674
675         if (!old_worker)
676                 return;
677         /*
678          * Take the worker mutex to make sure we see the work queued from
679          * device-wide flushes, which don't use RCU for execution.
680          */
681         mutex_lock(&old_worker->mutex);
682         old_worker->attachment_cnt--;
683         /*
684          * We don't want to call synchronize_rcu for every vq during setup
685          * because it will slow down VM startup. If we haven't done
686                  * VHOST_SET_VRING_KICK and have not done the driver-specific
687                  * SET_ENDPOINT/RUNNING call, then we can skip the sync since
688                  * there will not be any work queued for scsi and net.
689          */
690         mutex_lock(&vq->mutex);
691         if (!vhost_vq_get_backend(vq) && !vq->kick) {
692                 mutex_unlock(&vq->mutex);
693                 mutex_unlock(&old_worker->mutex);
694                 /*
695                  * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
696                  * Warn if it adds support for multiple workers but forgets to
697                  * handle the early queueing case.
698                  */
699                 WARN_ON(!old_worker->attachment_cnt &&
700                         !llist_empty(&old_worker->work_list));
701                 return;
702         }
703         mutex_unlock(&vq->mutex);
704
705         /* Make sure new vq queue/flush/poll calls see the new worker */
706         synchronize_rcu();
707         /* Make sure whatever was queued gets run */
708         vhost_worker_flush(old_worker);
709         mutex_unlock(&old_worker->mutex);
710 }
711
712  /* Caller must have device mutex */
713 static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
714                                   struct vhost_vring_worker *info)
715 {
716         unsigned long index = info->worker_id;
717         struct vhost_dev *dev = vq->dev;
718         struct vhost_worker *worker;
719
720         if (!dev->use_worker)
721                 return -EINVAL;
722
723         worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
724         if (!worker || worker->id != info->worker_id)
725                 return -ENODEV;
726
727         __vhost_vq_attach_worker(vq, worker);
728         return 0;
729 }
730
731 /* Caller must have device mutex */
732 static int vhost_new_worker(struct vhost_dev *dev,
733                             struct vhost_worker_state *info)
734 {
735         struct vhost_worker *worker;
736
737         worker = vhost_worker_create(dev);
738         if (!worker)
739                 return -ENOMEM;
740
741         info->worker_id = worker->id;
742         return 0;
743 }
744
745 /* Caller must have device mutex */
746 static int vhost_free_worker(struct vhost_dev *dev,
747                              struct vhost_worker_state *info)
748 {
749         unsigned long index = info->worker_id;
750         struct vhost_worker *worker;
751
752         worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
753         if (!worker || worker->id != info->worker_id)
754                 return -ENODEV;
755
756         mutex_lock(&worker->mutex);
757         if (worker->attachment_cnt) {
758                 mutex_unlock(&worker->mutex);
759                 return -EBUSY;
760         }
761         mutex_unlock(&worker->mutex);
762
763         vhost_worker_destroy(dev, worker);
764         return 0;
765 }
766
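/* Read a vring index from userspace, validate it against nvqs (with a
 * Spectre-safe clamp) and return the corresponding virtqueue.
 */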
767 static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
768                                   struct vhost_virtqueue **vq, u32 *id)
769 {
770         u32 __user *idxp = argp;
771         u32 idx;
772         long r;
773
774         r = get_user(idx, idxp);
775         if (r < 0)
776                 return r;
777
778         if (idx >= dev->nvqs)
779                 return -ENOBUFS;
780
781         idx = array_index_nospec(idx, dev->nvqs);
782
783         *vq = dev->vqs[idx];
784         *id = idx;
785         return 0;
786 }
787
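/* A rough sketch of the expected userspace flow for the worker ioctls
 * (illustrative only, using the uapi structs from linux/vhost_types.h;
 * vhost_fd is a hypothetical device fd that already has an owner):
 *
 *	struct vhost_worker_state state = {};
 *	struct vhost_vring_worker w = { .index = 0 };
 *
 *	ioctl(vhost_fd, VHOST_NEW_WORKER, &state);
 *	w.worker_id = state.worker_id;
 *	ioctl(vhost_fd, VHOST_ATTACH_VRING_WORKER, &w);
 *
 * which creates a new worker and runs vring 0 on it.
 */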
788 /* Caller must have device mutex */
789 long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
790                         void __user *argp)
791 {
792         struct vhost_vring_worker ring_worker;
793         struct vhost_worker_state state;
794         struct vhost_worker *worker;
795         struct vhost_virtqueue *vq;
796         long ret;
797         u32 idx;
798
799         if (!dev->use_worker)
800                 return -EINVAL;
801
802         if (!vhost_dev_has_owner(dev))
803                 return -EINVAL;
804
805         ret = vhost_dev_check_owner(dev);
806         if (ret)
807                 return ret;
808
809         switch (ioctl) {
810         /* dev worker ioctls */
811         case VHOST_NEW_WORKER:
812                 ret = vhost_new_worker(dev, &state);
813                 if (!ret && copy_to_user(argp, &state, sizeof(state)))
814                         ret = -EFAULT;
815                 return ret;
816         case VHOST_FREE_WORKER:
817                 if (copy_from_user(&state, argp, sizeof(state)))
818                         return -EFAULT;
819                 return vhost_free_worker(dev, &state);
820         /* vring worker ioctls */
821         case VHOST_ATTACH_VRING_WORKER:
822         case VHOST_GET_VRING_WORKER:
823                 break;
824         default:
825                 return -ENOIOCTLCMD;
826         }
827
828         ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
829         if (ret)
830                 return ret;
831
832         switch (ioctl) {
833         case VHOST_ATTACH_VRING_WORKER:
834                 if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
835                         ret = -EFAULT;
836                         break;
837                 }
838
839                 ret = vhost_vq_attach_worker(vq, &ring_worker);
840                 break;
841         case VHOST_GET_VRING_WORKER:
842                 worker = rcu_dereference_check(vq->worker,
843                                                lockdep_is_held(&dev->mutex));
844                 if (!worker) {
845                         ret = -EINVAL;
846                         break;
847                 }
848
849                 ring_worker.index = idx;
850                 ring_worker.worker_id = worker->id;
851
852                 if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
853                         ret = -EFAULT;
854                 break;
855         default:
856                 ret = -ENOIOCTLCMD;
857                 break;
858         }
859
860         return ret;
861 }
862 EXPORT_SYMBOL_GPL(vhost_worker_ioctl);
863
864 /* Caller should have device mutex */
865 long vhost_dev_set_owner(struct vhost_dev *dev)
866 {
867         struct vhost_worker *worker;
868         int err, i;
869
870         /* Is there an owner already? */
871         if (vhost_dev_has_owner(dev)) {
872                 err = -EBUSY;
873                 goto err_mm;
874         }
875
876         vhost_attach_mm(dev);
877
878         err = vhost_dev_alloc_iovecs(dev);
879         if (err)
880                 goto err_iovecs;
881
882         if (dev->use_worker) {
883                 /*
884                  * This should be done last, because vsock can queue work
885                  * before VHOST_SET_OWNER so it simplifies the failure path
886                  * below since we don't have to worry about vsock queueing
887                  * while we free the worker.
888                  */
889                 worker = vhost_worker_create(dev);
890                 if (!worker) {
891                         err = -ENOMEM;
892                         goto err_worker;
893                 }
894
895                 for (i = 0; i < dev->nvqs; i++)
896                         __vhost_vq_attach_worker(dev->vqs[i], worker);
897         }
898
899         return 0;
900
901 err_worker:
902         vhost_dev_free_iovecs(dev);
903 err_iovecs:
904         vhost_detach_mm(dev);
905 err_mm:
906         return err;
907 }
908 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
909
910 static struct vhost_iotlb *iotlb_alloc(void)
911 {
912         return vhost_iotlb_alloc(max_iotlb_entries,
913                                  VHOST_IOTLB_FLAG_RETIRE);
914 }
915
916 struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
917 {
918         return iotlb_alloc();
919 }
920 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
921
922 /* Caller should have device mutex */
923 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
924 {
925         int i;
926
927         vhost_dev_cleanup(dev);
928
929         dev->umem = umem;
930         /* We don't need VQ locks below since vhost_dev_cleanup makes sure
931          * VQs aren't running.
932          */
933         for (i = 0; i < dev->nvqs; ++i)
934                 dev->vqs[i]->umem = umem;
935 }
936 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
937
938 void vhost_dev_stop(struct vhost_dev *dev)
939 {
940         int i;
941
942         for (i = 0; i < dev->nvqs; ++i) {
943                 if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
944                         vhost_poll_stop(&dev->vqs[i]->poll);
945         }
946
947         vhost_dev_flush(dev);
948 }
949 EXPORT_SYMBOL_GPL(vhost_dev_stop);
950
951 void vhost_clear_msg(struct vhost_dev *dev)
952 {
953         struct vhost_msg_node *node, *n;
954
955         spin_lock(&dev->iotlb_lock);
956
957         list_for_each_entry_safe(node, n, &dev->read_list, node) {
958                 list_del(&node->node);
959                 kfree(node);
960         }
961
962         list_for_each_entry_safe(node, n, &dev->pending_list, node) {
963                 list_del(&node->node);
964                 kfree(node);
965         }
966
967         spin_unlock(&dev->iotlb_lock);
968 }
969 EXPORT_SYMBOL_GPL(vhost_clear_msg);
970
971 void vhost_dev_cleanup(struct vhost_dev *dev)
972 {
973         int i;
974
975         for (i = 0; i < dev->nvqs; ++i) {
976                 if (dev->vqs[i]->error_ctx)
977                         eventfd_ctx_put(dev->vqs[i]->error_ctx);
978                 if (dev->vqs[i]->kick)
979                         fput(dev->vqs[i]->kick);
980                 if (dev->vqs[i]->call_ctx.ctx)
981                         eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
982                 vhost_vq_reset(dev, dev->vqs[i]);
983         }
984         vhost_dev_free_iovecs(dev);
985         if (dev->log_ctx)
986                 eventfd_ctx_put(dev->log_ctx);
987         dev->log_ctx = NULL;
988         /* No one will access memory at this point */
989         vhost_iotlb_free(dev->umem);
990         dev->umem = NULL;
991         vhost_iotlb_free(dev->iotlb);
992         dev->iotlb = NULL;
993         vhost_clear_msg(dev);
994         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
995         vhost_workers_free(dev);
996         vhost_detach_mm(dev);
997 }
998 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
999
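/* The dirty log is a userspace bitmap with one bit per VHOST_PAGE_SIZE bytes
 * of guest memory: convert addr to a byte offset into the bitmap and check
 * that the bitmap bytes covering sz bytes are accessible.
 */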
1000 static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
1001 {
1002         u64 a = addr / VHOST_PAGE_SIZE / 8;
1003
1004         /* Make sure 64 bit math will not overflow. */
1005         if (a > ULONG_MAX - (unsigned long)log_base ||
1006             a + (unsigned long)log_base > ULONG_MAX)
1007                 return false;
1008
1009         return access_ok(log_base + a,
1010                          (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
1011 }
1012
1013 /* Make sure 64 bit math will not overflow. */
1014 static bool vhost_overflow(u64 uaddr, u64 size)
1015 {
1016         if (uaddr > ULONG_MAX || size > ULONG_MAX)
1017                 return true;
1018
1019         if (!size)
1020                 return false;
1021
1022         return uaddr > ULONG_MAX - size + 1;
1023 }
1024
1025 /* Caller should have vq mutex and device mutex. */
1026 static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
1027                                 int log_all)
1028 {
1029         struct vhost_iotlb_map *map;
1030
1031         if (!umem)
1032                 return false;
1033
1034         list_for_each_entry(map, &umem->list, link) {
1035                 unsigned long a = map->addr;
1036
1037                 if (vhost_overflow(map->addr, map->size))
1038                         return false;
1039
1040
1041                 if (!access_ok((void __user *)a, map->size))
1042                         return false;
1043                 else if (log_all && !log_access_ok(log_base,
1044                                                    map->start,
1045                                                    map->size))
1046                         return false;
1047         }
1048         return true;
1049 }
1050
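/* Return the cached userspace address for a vring metadata access of the
 * given type, or NULL if no translation is cached in meta_iotlb.
 */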
1051 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
1052                                                u64 addr, unsigned int size,
1053                                                int type)
1054 {
1055         const struct vhost_iotlb_map *map = vq->meta_iotlb[type];
1056
1057         if (!map)
1058                 return NULL;
1059
1060         return (void __user *)(uintptr_t)(map->addr + addr - map->start);
1061 }
1062
1063 /* Can we switch to this memory table? */
1064 /* Caller should have device mutex but not vq mutex */
1065 static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
1066                              int log_all)
1067 {
1068         int i;
1069
1070         for (i = 0; i < d->nvqs; ++i) {
1071                 bool ok;
1072                 bool log;
1073
1074                 mutex_lock(&d->vqs[i]->mutex);
1075                 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
1076                 /* If ring is inactive, will check when it's enabled. */
1077                 if (d->vqs[i]->private_data)
1078                         ok = vq_memory_access_ok(d->vqs[i]->log_base,
1079                                                  umem, log);
1080                 else
1081                         ok = true;
1082                 mutex_unlock(&d->vqs[i]->mutex);
1083                 if (!ok)
1084                         return false;
1085         }
1086         return true;
1087 }
1088
1089 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1090                           struct iovec iov[], int iov_size, int access);
1091
1092 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
1093                               const void *from, unsigned size)
1094 {
1095         int ret;
1096
1097         if (!vq->iotlb)
1098                 return __copy_to_user(to, from, size);
1099         else {
1100                 /* This function should be called after iotlb
1101                  * prefetch, which means we're sure that all vq
1102                  * memory can be accessed through the iotlb, so -EAGAIN
1103                  * should not happen in this case.
1104                  */
1105                 struct iov_iter t;
1106                 void __user *uaddr = vhost_vq_meta_fetch(vq,
1107                                      (u64)(uintptr_t)to, size,
1108                                      VHOST_ADDR_USED);
1109
1110                 if (uaddr)
1111                         return __copy_to_user(uaddr, from, size);
1112
1113                 ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
1114                                      ARRAY_SIZE(vq->iotlb_iov),
1115                                      VHOST_ACCESS_WO);
1116                 if (ret < 0)
1117                         goto out;
1118                 iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
1119                 ret = copy_to_iter(from, size, &t);
1120                 if (ret == size)
1121                         ret = 0;
1122         }
1123 out:
1124         return ret;
1125 }
1126
1127 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
1128                                 void __user *from, unsigned size)
1129 {
1130         int ret;
1131
1132         if (!vq->iotlb)
1133                 return __copy_from_user(to, from, size);
1134         else {
1135                 /* This function should be called after iotlb
1136                  * prefetch, which means we're sure that the vq
1137                  * memory can be accessed through the iotlb, so -EAGAIN
1138                  * should not happen in this case.
1139                  */
1140                 void __user *uaddr = vhost_vq_meta_fetch(vq,
1141                                      (u64)(uintptr_t)from, size,
1142                                      VHOST_ADDR_DESC);
1143                 struct iov_iter f;
1144
1145                 if (uaddr)
1146                         return __copy_from_user(to, uaddr, size);
1147
1148                 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
1149                                      ARRAY_SIZE(vq->iotlb_iov),
1150                                      VHOST_ACCESS_RO);
1151                 if (ret < 0) {
1152                         vq_err(vq, "IOTLB translation failure: uaddr "
1153                                "%p size 0x%llx\n", from,
1154                                (unsigned long long) size);
1155                         goto out;
1156                 }
1157                 iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
1158                 ret = copy_from_iter(to, size, &f);
1159                 if (ret == size)
1160                         ret = 0;
1161         }
1162
1163 out:
1164         return ret;
1165 }
1166
1167 static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
1168                                           void __user *addr, unsigned int size,
1169                                           int type)
1170 {
1171         int ret;
1172
1173         ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
1174                              ARRAY_SIZE(vq->iotlb_iov),
1175                              VHOST_ACCESS_RO);
1176         if (ret < 0) {
1177                 vq_err(vq, "IOTLB translation failure: uaddr "
1178                         "%p size 0x%llx\n", addr,
1179                         (unsigned long long) size);
1180                 return NULL;
1181         }
1182
1183         if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
1184                 vq_err(vq, "Non atomic userspace memory access: uaddr "
1185                         "%p size 0x%llx\n", addr,
1186                         (unsigned long long) size);
1187                 return NULL;
1188         }
1189
1190         return vq->iotlb_iov[0].iov_base;
1191 }
1192
1193 /* This function should be called after iotlb
1194  * prefetch, which means we're sure that the vq
1195  * memory can be accessed through the iotlb, so -EAGAIN
1196  * should not happen in this case.
1197  */
1198 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
1199                                             void __user *addr, unsigned int size,
1200                                             int type)
1201 {
1202         void __user *uaddr = vhost_vq_meta_fetch(vq,
1203                              (u64)(uintptr_t)addr, size, type);
1204         if (uaddr)
1205                 return uaddr;
1206
1207         return __vhost_get_user_slow(vq, addr, size, type);
1208 }
1209
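/* Store one vring metadata field: a plain __put_user() without an iotlb,
 * otherwise translate the target through the cached used-ring mapping (or
 * the slow path) first.
 */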
1210 #define vhost_put_user(vq, x, ptr)              \
1211 ({ \
1212         int ret; \
1213         if (!vq->iotlb) { \
1214                 ret = __put_user(x, ptr); \
1215         } else { \
1216                 __typeof__(ptr) to = \
1217                         (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
1218                                           sizeof(*ptr), VHOST_ADDR_USED); \
1219                 if (to != NULL) \
1220                         ret = __put_user(x, to); \
1221                 else \
1222                         ret = -EFAULT;  \
1223         } \
1224         ret; \
1225 })
1226
1227 static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
1228 {
1229         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1230                               vhost_avail_event(vq));
1231 }
1232
1233 static inline int vhost_put_used(struct vhost_virtqueue *vq,
1234                                  struct vring_used_elem *head, int idx,
1235                                  int count)
1236 {
1237         return vhost_copy_to_user(vq, vq->used->ring + idx, head,
1238                                   count * sizeof(*head));
1239 }
1240
1241 static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
1242
1243 {
1244         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1245                               &vq->used->flags);
1246 }
1247
1248 static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
1249
1250 {
1251         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
1252                               &vq->used->idx);
1253 }
1254
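/* Load one vring metadata field of the given type, translating the source
 * address through the iotlb (cached or slow path) when an iotlb is in use.
 */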
1255 #define vhost_get_user(vq, x, ptr, type)                \
1256 ({ \
1257         int ret; \
1258         if (!vq->iotlb) { \
1259                 ret = __get_user(x, ptr); \
1260         } else { \
1261                 __typeof__(ptr) from = \
1262                         (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
1263                                                            sizeof(*ptr), \
1264                                                            type); \
1265                 if (from != NULL) \
1266                         ret = __get_user(x, from); \
1267                 else \
1268                         ret = -EFAULT; \
1269         } \
1270         ret; \
1271 })
1272
1273 #define vhost_get_avail(vq, x, ptr) \
1274         vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)
1275
1276 #define vhost_get_used(vq, x, ptr) \
1277         vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
1278
1279 static void vhost_dev_lock_vqs(struct vhost_dev *d)
1280 {
1281         int i = 0;
1282         for (i = 0; i < d->nvqs; ++i)
1283                 mutex_lock_nested(&d->vqs[i]->mutex, i);
1284 }
1285
1286 static void vhost_dev_unlock_vqs(struct vhost_dev *d)
1287 {
1288         int i = 0;
1289         for (i = 0; i < d->nvqs; ++i)
1290                 mutex_unlock(&d->vqs[i]->mutex);
1291 }
1292
1293 static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
1294                                       __virtio16 *idx)
1295 {
1296         return vhost_get_avail(vq, *idx, &vq->avail->idx);
1297 }
1298
1299 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
1300                                        __virtio16 *head, int idx)
1301 {
1302         return vhost_get_avail(vq, *head,
1303                                &vq->avail->ring[idx & (vq->num - 1)]);
1304 }
1305
1306 static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
1307                                         __virtio16 *flags)
1308 {
1309         return vhost_get_avail(vq, *flags, &vq->avail->flags);
1310 }
1311
1312 static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
1313                                        __virtio16 *event)
1314 {
1315         return vhost_get_avail(vq, *event, vhost_used_event(vq));
1316 }
1317
1318 static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
1319                                      __virtio16 *idx)
1320 {
1321         return vhost_get_used(vq, *idx, &vq->used->idx);
1322 }
1323
1324 static inline int vhost_get_desc(struct vhost_virtqueue *vq,
1325                                  struct vring_desc *desc, int idx)
1326 {
1327         return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
1328 }
1329
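/* A new mapping was added: requeue the poll work of any virtqueue whose
 * pending IOTLB miss falls inside [iova, iova + size - 1] and drop the
 * pending message node.
 */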
1330 static void vhost_iotlb_notify_vq(struct vhost_dev *d,
1331                                   struct vhost_iotlb_msg *msg)
1332 {
1333         struct vhost_msg_node *node, *n;
1334
1335         spin_lock(&d->iotlb_lock);
1336
1337         list_for_each_entry_safe(node, n, &d->pending_list, node) {
1338                 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
1339                 if (msg->iova <= vq_msg->iova &&
1340                     msg->iova + msg->size - 1 >= vq_msg->iova &&
1341                     vq_msg->type == VHOST_IOTLB_MISS) {
1342                         vhost_poll_queue(&node->vq->poll);
1343                         list_del(&node->node);
1344                         kfree(node);
1345                 }
1346         }
1347
1348         spin_unlock(&d->iotlb_lock);
1349 }
1350
1351 static bool umem_access_ok(u64 uaddr, u64 size, int access)
1352 {
1353         unsigned long a = uaddr;
1354
1355         /* Make sure 64 bit math will not overflow. */
1356         if (vhost_overflow(uaddr, size))
1357                 return false;
1358
1359         if ((access & VHOST_ACCESS_RO) &&
1360             !access_ok((void __user *)a, size))
1361                 return false;
1362         if ((access & VHOST_ACCESS_WO) &&
1363             !access_ok((void __user *)a, size))
1364                 return false;
1365         return true;
1366 }
1367
1368 static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1369                                    struct vhost_iotlb_msg *msg)
1370 {
1371         int ret = 0;
1372
1373         if (asid != 0)
1374                 return -EINVAL;
1375
1376         mutex_lock(&dev->mutex);
1377         vhost_dev_lock_vqs(dev);
1378         switch (msg->type) {
1379         case VHOST_IOTLB_UPDATE:
1380                 if (!dev->iotlb) {
1381                         ret = -EFAULT;
1382                         break;
1383                 }
1384                 if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
1385                         ret = -EFAULT;
1386                         break;
1387                 }
1388                 vhost_vq_meta_reset(dev);
1389                 if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
1390                                           msg->iova + msg->size - 1,
1391                                           msg->uaddr, msg->perm)) {
1392                         ret = -ENOMEM;
1393                         break;
1394                 }
1395                 vhost_iotlb_notify_vq(dev, msg);
1396                 break;
1397         case VHOST_IOTLB_INVALIDATE:
1398                 if (!dev->iotlb) {
1399                         ret = -EFAULT;
1400                         break;
1401                 }
1402                 vhost_vq_meta_reset(dev);
1403                 vhost_iotlb_del_range(dev->iotlb, msg->iova,
1404                                       msg->iova + msg->size - 1);
1405                 break;
1406         default:
1407                 ret = -EINVAL;
1408                 break;
1409         }
1410
1411         vhost_dev_unlock_vqs(dev);
1412         mutex_unlock(&dev->mutex);
1413
1414         return ret;
1415 }
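/* Parse an IOTLB message written to the vhost chardev (V1 or V2 layout,
 * optionally carrying an ASID when VHOST_BACKEND_F_IOTLB_ASID is negotiated)
 * and hand it to the device's msg_handler or the generic IOTLB handler.
 */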
1416 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1417                              struct iov_iter *from)
1418 {
1419         struct vhost_iotlb_msg msg;
1420         size_t offset;
1421         int type, ret;
1422         u32 asid = 0;
1423
1424         ret = copy_from_iter(&type, sizeof(type), from);
1425         if (ret != sizeof(type)) {
1426                 ret = -EINVAL;
1427                 goto done;
1428         }
1429
1430         switch (type) {
1431         case VHOST_IOTLB_MSG:
1432                 /* There may be a hole after the type field for the V1
1433                  * message layout, so skip it here.
1434                  */
1435                 offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
1436                 break;
1437         case VHOST_IOTLB_MSG_V2:
1438                 if (vhost_backend_has_feature(dev->vqs[0],
1439                                               VHOST_BACKEND_F_IOTLB_ASID)) {
1440                         ret = copy_from_iter(&asid, sizeof(asid), from);
1441                         if (ret != sizeof(asid)) {
1442                                 ret = -EINVAL;
1443                                 goto done;
1444                         }
1445                         offset = 0;
1446                 } else
1447                         offset = sizeof(__u32);
1448                 break;
1449         default:
1450                 ret = -EINVAL;
1451                 goto done;
1452         }
1453
1454         iov_iter_advance(from, offset);
1455         ret = copy_from_iter(&msg, sizeof(msg), from);
1456         if (ret != sizeof(msg)) {
1457                 ret = -EINVAL;
1458                 goto done;
1459         }
1460
1461         if ((msg.type == VHOST_IOTLB_UPDATE ||
1462              msg.type == VHOST_IOTLB_INVALIDATE) &&
1463              msg.size == 0) {
1464                 ret = -EINVAL;
1465                 goto done;
1466         }
1467
1468         if (dev->msg_handler)
1469                 ret = dev->msg_handler(dev, asid, &msg);
1470         else
1471                 ret = vhost_process_iotlb_msg(dev, asid, &msg);
1472         if (ret) {
1473                 ret = -EFAULT;
1474                 goto done;
1475         }
1476
1477         ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
1478               sizeof(struct vhost_msg_v2);
1479 done:
1480         return ret;
1481 }
1482 EXPORT_SYMBOL(vhost_chr_write_iter);
1483
1484 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1485                             poll_table *wait)
1486 {
1487         __poll_t mask = 0;
1488
1489         poll_wait(file, &dev->wait, wait);
1490
1491         if (!list_empty(&dev->read_list))
1492                 mask |= EPOLLIN | EPOLLRDNORM;
1493
1494         return mask;
1495 }
1496 EXPORT_SYMBOL(vhost_chr_poll);
1497
1498 ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1499                             int noblock)
1500 {
1501         DEFINE_WAIT(wait);
1502         struct vhost_msg_node *node;
1503         ssize_t ret = 0;
1504         unsigned size = sizeof(struct vhost_msg);
1505
1506         if (iov_iter_count(to) < size)
1507                 return 0;
1508
1509         while (1) {
1510                 if (!noblock)
1511                         prepare_to_wait(&dev->wait, &wait,
1512                                         TASK_INTERRUPTIBLE);
1513
1514                 node = vhost_dequeue_msg(dev, &dev->read_list);
1515                 if (node)
1516                         break;
1517                 if (noblock) {
1518                         ret = -EAGAIN;
1519                         break;
1520                 }
1521                 if (signal_pending(current)) {
1522                         ret = -ERESTARTSYS;
1523                         break;
1524                 }
1525                 if (!dev->iotlb) {
1526                         ret = -EBADFD;
1527                         break;
1528                 }
1529
1530                 schedule();
1531         }
1532
1533         if (!noblock)
1534                 finish_wait(&dev->wait, &wait);
1535
1536         if (node) {
1537                 struct vhost_iotlb_msg *msg;
1538                 void *start = &node->msg;
1539
1540                 switch (node->msg.type) {
1541                 case VHOST_IOTLB_MSG:
1542                         size = sizeof(node->msg);
1543                         msg = &node->msg.iotlb;
1544                         break;
1545                 case VHOST_IOTLB_MSG_V2:
1546                         size = sizeof(node->msg_v2);
1547                         msg = &node->msg_v2.iotlb;
1548                         break;
1549                 default:
1550                         BUG();
1551                         break;
1552                 }
1553
1554                 ret = copy_to_iter(start, size, to);
1555                 if (ret != size || msg->type != VHOST_IOTLB_MISS) {
1556                         kfree(node);
1557                         return ret;
1558                 }
1559                 vhost_enqueue_msg(dev, &dev->pending_list, node);
1560         }
1561
1562         return ret;
1563 }
1564 EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1565
1566 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1567 {
1568         struct vhost_dev *dev = vq->dev;
1569         struct vhost_msg_node *node;
1570         struct vhost_iotlb_msg *msg;
1571         bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
1572
1573         node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
1574         if (!node)
1575                 return -ENOMEM;
1576
1577         if (v2) {
1578                 node->msg_v2.type = VHOST_IOTLB_MSG_V2;
1579                 msg = &node->msg_v2.iotlb;
1580         } else {
1581                 msg = &node->msg.iotlb;
1582         }
1583
1584         msg->type = VHOST_IOTLB_MISS;
1585         msg->iova = iova;
1586         msg->perm = access;
1587
1588         vhost_enqueue_msg(dev, &dev->read_list, node);
1589
1590         return 0;
1591 }
1592
1593 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1594                          vring_desc_t __user *desc,
1595                          vring_avail_t __user *avail,
1596                          vring_used_t __user *used)
1597
1598 {
1599         /* If an IOTLB device is present, the vring addresses are
1600          * GIOVAs. Access validation occurs at prefetch time. */
1601         if (vq->iotlb)
1602                 return true;
1603
1604         return access_ok(desc, vhost_get_desc_size(vq, num)) &&
1605                access_ok(avail, vhost_get_avail_size(vq, num)) &&
1606                access_ok(used, vhost_get_used_size(vq, num));
1607 }
1608
1609 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1610                                  const struct vhost_iotlb_map *map,
1611                                  int type)
1612 {
1613         int access = (type == VHOST_ADDR_USED) ?
1614                      VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1615
1616         if (likely(map->perm & access))
1617                 vq->meta_iotlb[type] = map;
1618 }
1619
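/* Check that the whole metadata range [addr, addr + len) is mapped with the
 * required access. A miss message is queued for the first unmapped address;
 * if one mapping covers the full range, it is cached in meta_iotlb.
 */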
1620 static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1621                             int access, u64 addr, u64 len, int type)
1622 {
1623         const struct vhost_iotlb_map *map;
1624         struct vhost_iotlb *umem = vq->iotlb;
1625         u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1626
1627         if (vhost_vq_meta_fetch(vq, addr, len, type))
1628                 return true;
1629
1630         while (len > s) {
1631                 map = vhost_iotlb_itree_first(umem, addr, last);
1632                 if (map == NULL || map->start > addr) {
1633                         vhost_iotlb_miss(vq, addr, access);
1634                         return false;
1635                 } else if (!(map->perm & access)) {
1636                         /* Report the possible access violation by
1637                          * requesting another translation from userspace.
1638                          */
1639                         return false;
1640                 }
1641
1642                 size = map->size - addr + map->start;
1643
1644                 if (orig_addr == addr && size >= len)
1645                         vhost_vq_meta_update(vq, map, type);
1646
1647                 s += size;
1648                 addr += size;
1649         }
1650
1651         return true;
1652 }
1653
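/*
 * Prefetch (and cache) translations for the descriptor, avail and used
 * rings.  Backends using an IOTLB typically call this before processing
 * the ring (a convention, not enforced here); a zero return means a
 * translation is still missing and a miss request has been queued.
 */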
1654 int vq_meta_prefetch(struct vhost_virtqueue *vq)
1655 {
1656         unsigned int num = vq->num;
1657
1658         if (!vq->iotlb)
1659                 return 1;
1660
1661         return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
1662                                vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
1663                iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
1664                                vhost_get_avail_size(vq, num),
1665                                VHOST_ADDR_AVAIL) &&
1666                iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
1667                                vhost_get_used_size(vq, num), VHOST_ADDR_USED);
1668 }
1669 EXPORT_SYMBOL_GPL(vq_meta_prefetch);
1670
1671 /* Can we log writes? */
1672 /* Caller should have device mutex but not vq mutex */
1673 bool vhost_log_access_ok(struct vhost_dev *dev)
1674 {
1675         return memory_access_ok(dev, dev->umem, 1);
1676 }
1677 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1678
1679 static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
1680                                   void __user *log_base,
1681                                   bool log_used,
1682                                   u64 log_addr)
1683 {
1684         /* If an IOTLB device is present, log_addr is a GIOVA that
1685          * will never be logged by log_used(). */
1686         if (vq->iotlb)
1687                 return true;
1688
1689         return !log_used || log_access_ok(log_base, log_addr,
1690                                           vhost_get_used_size(vq, vq->num));
1691 }
1692
1693 /* Verify access for write logging. */
1694 /* Caller should have vq mutex and device mutex */
1695 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1696                              void __user *log_base)
1697 {
1698         return vq_memory_access_ok(log_base, vq->umem,
1699                                    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1700                 vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
1701 }
1702
1703 /* Can we start vq? */
1704 /* Caller should have vq mutex and device mutex */
1705 bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1706 {
1707         if (!vq_log_access_ok(vq, vq->log_base))
1708                 return false;
1709
1710         return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1711 }
1712 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1713
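/*
 * VHOST_SET_MEM_TABLE: replace the device memory map.  The table supplied
 * by userspace is converted into a vhost_iotlb of RW ranges, validated for
 * dirty-log access, and swapped in under every vq mutex before the old map
 * is freed.
 *
 * Userspace sketch (illustrative values, error handling omitted):
 *
 *	struct vhost_memory *mem = calloc(1, sizeof(*mem) +
 *				sizeof(struct vhost_memory_region));
 *	mem->nregions = 1;
 *	mem->regions[0] = (struct vhost_memory_region) {
 *		.guest_phys_addr = gpa,
 *		.memory_size     = size,
 *		.userspace_addr  = (__u64)(uintptr_t)hva,
 *	};
 *	ioctl(vhost_fd, VHOST_SET_MEM_TABLE, mem);
 */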
1714 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1715 {
1716         struct vhost_memory mem, *newmem;
1717         struct vhost_memory_region *region;
1718         struct vhost_iotlb *newumem, *oldumem;
1719         unsigned long size = offsetof(struct vhost_memory, regions);
1720         int i;
1721
1722         if (copy_from_user(&mem, m, size))
1723                 return -EFAULT;
1724         if (mem.padding)
1725                 return -EOPNOTSUPP;
1726         if (mem.nregions > max_mem_regions)
1727                 return -E2BIG;
1728         newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
1729                         GFP_KERNEL);
1730         if (!newmem)
1731                 return -ENOMEM;
1732
1733         memcpy(newmem, &mem, size);
1734         if (copy_from_user(newmem->regions, m->regions,
1735                            flex_array_size(newmem, regions, mem.nregions))) {
1736                 kvfree(newmem);
1737                 return -EFAULT;
1738         }
1739
1740         newumem = iotlb_alloc();
1741         if (!newumem) {
1742                 kvfree(newmem);
1743                 return -ENOMEM;
1744         }
1745
1746         for (region = newmem->regions;
1747              region < newmem->regions + mem.nregions;
1748              region++) {
1749                 if (vhost_iotlb_add_range(newumem,
1750                                           region->guest_phys_addr,
1751                                           region->guest_phys_addr +
1752                                           region->memory_size - 1,
1753                                           region->userspace_addr,
1754                                           VHOST_MAP_RW))
1755                         goto err;
1756         }
1757
1758         if (!memory_access_ok(d, newumem, 0))
1759                 goto err;
1760
1761         oldumem = d->umem;
1762         d->umem = newumem;
1763
1764         /* All memory accesses are done under some VQ mutex. */
1765         for (i = 0; i < d->nvqs; ++i) {
1766                 mutex_lock(&d->vqs[i]->mutex);
1767                 d->vqs[i]->umem = newumem;
1768                 mutex_unlock(&d->vqs[i]->mutex);
1769         }
1770
1771         kvfree(newmem);
1772         vhost_iotlb_free(oldumem);
1773         return 0;
1774
1775 err:
1776         vhost_iotlb_free(newumem);
1777         kvfree(newmem);
1778         return -EFAULT;
1779 }
1780
1781 static long vhost_vring_set_num(struct vhost_dev *d,
1782                                 struct vhost_virtqueue *vq,
1783                                 void __user *argp)
1784 {
1785         struct vhost_vring_state s;
1786
1787         /* Resizing ring with an active backend?
1788          * You don't want to do that. */
1789         if (vq->private_data)
1790                 return -EBUSY;
1791
1792         if (copy_from_user(&s, argp, sizeof s))
1793                 return -EFAULT;
1794
1795         if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
1796                 return -EINVAL;
1797         vq->num = s.num;
1798
1799         return 0;
1800 }
1801
1802 static long vhost_vring_set_addr(struct vhost_dev *d,
1803                                  struct vhost_virtqueue *vq,
1804                                  void __user *argp)
1805 {
1806         struct vhost_vring_addr a;
1807
1808         if (copy_from_user(&a, argp, sizeof a))
1809                 return -EFAULT;
1810         if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
1811                 return -EOPNOTSUPP;
1812
1813         /* For 32-bit, verify that the top 32 bits of the user
1814            data are set to zero. */
1815         if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
1816             (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
1817             (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
1818                 return -EFAULT;
1819
1820         /* Make sure it's safe to cast pointers to vring types. */
1821         BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
1822         BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
1823         if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
1824             (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
1825             (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
1826                 return -EINVAL;
1827
1828         /* We only verify access here if a backend is configured.
1829          * If it is not, we don't, since the size might not have been set up yet;
1830          * we will verify once the backend is configured. */
1831         if (vq->private_data) {
1832                 if (!vq_access_ok(vq, vq->num,
1833                         (void __user *)(unsigned long)a.desc_user_addr,
1834                         (void __user *)(unsigned long)a.avail_user_addr,
1835                         (void __user *)(unsigned long)a.used_user_addr))
1836                         return -EINVAL;
1837
1838                 /* Also validate log access for used ring if enabled. */
1839                 if (!vq_log_used_access_ok(vq, vq->log_base,
1840                                 a.flags & (0x1 << VHOST_VRING_F_LOG),
1841                                 a.log_guest_addr))
1842                         return -EINVAL;
1843         }
1844
1845         vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
1846         vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
1847         vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
1848         vq->log_addr = a.log_guest_addr;
1849         vq->used = (void __user *)(unsigned long)a.used_user_addr;
1850
1851         return 0;
1852 }
1853
1854 static long vhost_vring_set_num_addr(struct vhost_dev *d,
1855                                      struct vhost_virtqueue *vq,
1856                                      unsigned int ioctl,
1857                                      void __user *argp)
1858 {
1859         long r;
1860
1861         mutex_lock(&vq->mutex);
1862
1863         switch (ioctl) {
1864         case VHOST_SET_VRING_NUM:
1865                 r = vhost_vring_set_num(d, vq, argp);
1866                 break;
1867         case VHOST_SET_VRING_ADDR:
1868                 r = vhost_vring_set_addr(d, vq, argp);
1869                 break;
1870         default:
1871                 BUG();
1872         }
1873
1874         mutex_unlock(&vq->mutex);
1875
1876         return r;
1877 }
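
/*
 * Per-virtqueue ioctls.  The vq index is taken from the index field at the
 * start of the userspace argument.  VHOST_SET_VRING_NUM and
 * VHOST_SET_VRING_ADDR are dispatched to the helper above; everything else
 * is handled here under the vq mutex.  Changing the kick eventfd
 * stops/restarts polling, and the device is flushed (after dropping the vq
 * mutex) whenever polling was stopped.
 */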
1878 long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1879 {
1880         struct file *eventfp, *filep = NULL;
1881         bool pollstart = false, pollstop = false;
1882         struct eventfd_ctx *ctx = NULL;
1883         struct vhost_virtqueue *vq;
1884         struct vhost_vring_state s;
1885         struct vhost_vring_file f;
1886         u32 idx;
1887         long r;
1888
1889         r = vhost_get_vq_from_user(d, argp, &vq, &idx);
1890         if (r < 0)
1891                 return r;
1892
1893         if (ioctl == VHOST_SET_VRING_NUM ||
1894             ioctl == VHOST_SET_VRING_ADDR) {
1895                 return vhost_vring_set_num_addr(d, vq, ioctl, argp);
1896         }
1897
1898         mutex_lock(&vq->mutex);
1899
1900         switch (ioctl) {
1901         case VHOST_SET_VRING_BASE:
1902                 /* Moving base with an active backend?
1903                  * You don't want to do that. */
1904                 if (vq->private_data) {
1905                         r = -EBUSY;
1906                         break;
1907                 }
1908                 if (copy_from_user(&s, argp, sizeof s)) {
1909                         r = -EFAULT;
1910                         break;
1911                 }
1912                 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
1913                         vq->last_avail_idx = s.num & 0xffff;
1914                         vq->last_used_idx = (s.num >> 16) & 0xffff;
1915                 } else {
1916                         if (s.num > 0xffff) {
1917                                 r = -EINVAL;
1918                                 break;
1919                         }
1920                         vq->last_avail_idx = s.num;
1921                 }
1922                 /* Forget the cached index value. */
1923                 vq->avail_idx = vq->last_avail_idx;
1924                 break;
1925         case VHOST_GET_VRING_BASE:
1926                 s.index = idx;
1927                 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
1928                         s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16);
1929                 else
1930                         s.num = vq->last_avail_idx;
1931                 if (copy_to_user(argp, &s, sizeof s))
1932                         r = -EFAULT;
1933                 break;
1934         case VHOST_SET_VRING_KICK:
1935                 if (copy_from_user(&f, argp, sizeof f)) {
1936                         r = -EFAULT;
1937                         break;
1938                 }
1939                 eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
1940                 if (IS_ERR(eventfp)) {
1941                         r = PTR_ERR(eventfp);
1942                         break;
1943                 }
1944                 if (eventfp != vq->kick) {
1945                         pollstop = (filep = vq->kick) != NULL;
1946                         pollstart = (vq->kick = eventfp) != NULL;
1947                 } else
1948                         filep = eventfp;
1949                 break;
1950         case VHOST_SET_VRING_CALL:
1951                 if (copy_from_user(&f, argp, sizeof f)) {
1952                         r = -EFAULT;
1953                         break;
1954                 }
1955                 ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
1956                 if (IS_ERR(ctx)) {
1957                         r = PTR_ERR(ctx);
1958                         break;
1959                 }
1960
1961                 swap(ctx, vq->call_ctx.ctx);
1962                 break;
1963         case VHOST_SET_VRING_ERR:
1964                 if (copy_from_user(&f, argp, sizeof f)) {
1965                         r = -EFAULT;
1966                         break;
1967                 }
1968                 ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
1969                 if (IS_ERR(ctx)) {
1970                         r = PTR_ERR(ctx);
1971                         break;
1972                 }
1973                 swap(ctx, vq->error_ctx);
1974                 break;
1975         case VHOST_SET_VRING_ENDIAN:
1976                 r = vhost_set_vring_endian(vq, argp);
1977                 break;
1978         case VHOST_GET_VRING_ENDIAN:
1979                 r = vhost_get_vring_endian(vq, idx, argp);
1980                 break;
1981         case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
1982                 if (copy_from_user(&s, argp, sizeof(s))) {
1983                         r = -EFAULT;
1984                         break;
1985                 }
1986                 vq->busyloop_timeout = s.num;
1987                 break;
1988         case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
1989                 s.index = idx;
1990                 s.num = vq->busyloop_timeout;
1991                 if (copy_to_user(argp, &s, sizeof(s)))
1992                         r = -EFAULT;
1993                 break;
1994         default:
1995                 r = -ENOIOCTLCMD;
1996         }
1997
1998         if (pollstop && vq->handle_kick)
1999                 vhost_poll_stop(&vq->poll);
2000
2001         if (!IS_ERR_OR_NULL(ctx))
2002                 eventfd_ctx_put(ctx);
2003         if (filep)
2004                 fput(filep);
2005
2006         if (pollstart && vq->handle_kick)
2007                 r = vhost_poll_start(&vq->poll, vq->kick);
2008
2009         mutex_unlock(&vq->mutex);
2010
2011         if (pollstop && vq->handle_kick)
2012                 vhost_dev_flush(vq->poll.dev);
2013         return r;
2014 }
2015 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
2016
2017 int vhost_init_device_iotlb(struct vhost_dev *d)
2018 {
2019         struct vhost_iotlb *niotlb, *oiotlb;
2020         int i;
2021
2022         niotlb = iotlb_alloc();
2023         if (!niotlb)
2024                 return -ENOMEM;
2025
2026         oiotlb = d->iotlb;
2027         d->iotlb = niotlb;
2028
2029         for (i = 0; i < d->nvqs; ++i) {
2030                 struct vhost_virtqueue *vq = d->vqs[i];
2031
2032                 mutex_lock(&vq->mutex);
2033                 vq->iotlb = niotlb;
2034                 __vhost_vq_meta_reset(vq);
2035                 mutex_unlock(&vq->mutex);
2036         }
2037
2038         vhost_iotlb_free(oiotlb);
2039
2040         return 0;
2041 }
2042 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
2043
2044 /* Caller must have device mutex */
2045 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
2046 {
2047         struct eventfd_ctx *ctx;
2048         u64 p;
2049         long r;
2050         int i, fd;
2051
2052         /* If you are not the owner, you can become one */
2053         if (ioctl == VHOST_SET_OWNER) {
2054                 r = vhost_dev_set_owner(d);
2055                 goto done;
2056         }
2057
2058         /* You must be the owner to do anything else */
2059         r = vhost_dev_check_owner(d);
2060         if (r)
2061                 goto done;
2062
2063         switch (ioctl) {
2064         case VHOST_SET_MEM_TABLE:
2065                 r = vhost_set_memory(d, argp);
2066                 break;
2067         case VHOST_SET_LOG_BASE:
2068                 if (copy_from_user(&p, argp, sizeof p)) {
2069                         r = -EFAULT;
2070                         break;
2071                 }
2072                 if ((u64)(unsigned long)p != p) {
2073                         r = -EFAULT;
2074                         break;
2075                 }
2076                 for (i = 0; i < d->nvqs; ++i) {
2077                         struct vhost_virtqueue *vq;
2078                         void __user *base = (void __user *)(unsigned long)p;
2079                         vq = d->vqs[i];
2080                         mutex_lock(&vq->mutex);
2081                         /* If the ring is inactive, we will check when it's enabled. */
2082                         if (vq->private_data && !vq_log_access_ok(vq, base))
2083                                 r = -EFAULT;
2084                         else
2085                                 vq->log_base = base;
2086                         mutex_unlock(&vq->mutex);
2087                 }
2088                 break;
2089         case VHOST_SET_LOG_FD:
2090                 r = get_user(fd, (int __user *)argp);
2091                 if (r < 0)
2092                         break;
2093                 ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
2094                 if (IS_ERR(ctx)) {
2095                         r = PTR_ERR(ctx);
2096                         break;
2097                 }
2098                 swap(ctx, d->log_ctx);
2099                 for (i = 0; i < d->nvqs; ++i) {
2100                         mutex_lock(&d->vqs[i]->mutex);
2101                         d->vqs[i]->log_ctx = d->log_ctx;
2102                         mutex_unlock(&d->vqs[i]->mutex);
2103                 }
2104                 if (ctx)
2105                         eventfd_ctx_put(ctx);
2106                 break;
2107         default:
2108                 r = -ENOIOCTLCMD;
2109                 break;
2110         }
2111 done:
2112         return r;
2113 }
2114 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
2115
2116 /* TODO: This is really inefficient.  We need something like get_user()
2117  * (instruction directly accesses the data, with an exception table entry
2118  * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
2119  */
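/*
 * Set bit 'nr' of the byte at user address 'addr': pin the page, compute
 * the bit offset within it (byte offset inside the page times 8, plus nr),
 * set the bit through a temporary kernel mapping, then unpin and mark the
 * page dirty.
 */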
2120 static int set_bit_to_user(int nr, void __user *addr)
2121 {
2122         unsigned long log = (unsigned long)addr;
2123         struct page *page;
2124         void *base;
2125         int bit = nr + (log % PAGE_SIZE) * 8;
2126         int r;
2127
2128         r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
2129         if (r < 0)
2130                 return r;
2131         BUG_ON(r != 1);
2132         base = kmap_atomic(page);
2133         set_bit(bit, base);
2134         kunmap_atomic(base);
2135         unpin_user_pages_dirty_lock(&page, 1, true);
2136         return 0;
2137 }
2138
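/*
 * Mark the pages touched by a write of 'write_length' bytes at guest
 * physical address 'write_address' as dirty in the userspace log bitmap:
 * one bit per VHOST_PAGE_SIZE page, stored at byte (page / 8), bit
 * (page % 8) of log_base.  E.g. with 4K pages, a write at 0x3000 dirties
 * page 3, i.e. bit 3 of the first byte of the log.
 */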
2139 static int log_write(void __user *log_base,
2140                      u64 write_address, u64 write_length)
2141 {
2142         u64 write_page = write_address / VHOST_PAGE_SIZE;
2143         int r;
2144
2145         if (!write_length)
2146                 return 0;
2147         write_length += write_address % VHOST_PAGE_SIZE;
2148         for (;;) {
2149                 u64 base = (u64)(unsigned long)log_base;
2150                 u64 log = base + write_page / 8;
2151                 int bit = write_page % 8;
2152                 if ((u64)(unsigned long)log != log)
2153                         return -EFAULT;
2154                 r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
2155                 if (r < 0)
2156                         return r;
2157                 if (write_length <= VHOST_PAGE_SIZE)
2158                         break;
2159                 write_length -= VHOST_PAGE_SIZE;
2160                 write_page += 1;
2161         }
2162         return r;
2163 }
2164
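/*
 * Log a write given a host virtual address (IOTLB case).  A single HVA can
 * back more than one guest physical range, so walk every umem region that
 * overlaps [hva, hva + len) and log each overlap at its guest physical
 * address; fail if part of the range is not covered at all.
 */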
2165 static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
2166 {
2167         struct vhost_iotlb *umem = vq->umem;
2168         struct vhost_iotlb_map *u;
2169         u64 start, end, l, min;
2170         int r;
2171         bool hit = false;
2172
2173         while (len) {
2174                 min = len;
2175                 /* More than one GPA can be mapped into a single HVA, so
2176                  * iterate over all possible umem regions here to be safe.
2177                  */
2178                 list_for_each_entry(u, &umem->list, link) {
2179                         if (u->addr > hva - 1 + len ||
2180                             u->addr - 1 + u->size < hva)
2181                                 continue;
2182                         start = max(u->addr, hva);
2183                         end = min(u->addr - 1 + u->size, hva - 1 + len);
2184                         l = end - start + 1;
2185                         r = log_write(vq->log_base,
2186                                       u->start + start - u->addr,
2187                                       l);
2188                         if (r < 0)
2189                                 return r;
2190                         hit = true;
2191                         min = min(l, min);
2192                 }
2193
2194                 if (!hit)
2195                         return -EFAULT;
2196
2197                 len -= min;
2198                 hva += min;
2199         }
2200
2201         return 0;
2202 }
2203
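/*
 * Log a write to the used ring itself.  Without an IOTLB the used ring's
 * guest address is known (vq->log_addr), so log directly; with an IOTLB
 * the used-ring address is first translated into host virtual ranges and
 * each of those is logged via log_write_hva().
 */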
2204 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
2205 {
2206         struct iovec *iov = vq->log_iov;
2207         int i, ret;
2208
2209         if (!vq->iotlb)
2210                 return log_write(vq->log_base, vq->log_addr + used_offset, len);
2211
2212         ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
2213                              len, iov, 64, VHOST_ACCESS_WO);
2214         if (ret < 0)
2215                 return ret;
2216
2217         for (i = 0; i < ret; i++) {
2218                 ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
2219                                     iov[i].iov_len);
2220                 if (ret)
2221                         return ret;
2222         }
2223
2224         return 0;
2225 }
2226
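/*
 * Log the guest memory written for a completed request.  'log'/'log_num'
 * are the guest-address ranges recorded by vhost_get_vq_desc(); in the
 * IOTLB case the host-virtual ranges in 'iov'/'count' are logged instead.
 * 'len' is the total number of bytes written and must be covered by the
 * recorded ranges.
 */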
2227 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
2228                     unsigned int log_num, u64 len, struct iovec *iov, int count)
2229 {
2230         int i, r;
2231
2232         /* Make sure data written is seen before log. */
2233         smp_wmb();
2234
2235         if (vq->iotlb) {
2236                 for (i = 0; i < count; i++) {
2237                         r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
2238                                           iov[i].iov_len);
2239                         if (r < 0)
2240                                 return r;
2241                 }
2242                 return 0;
2243         }
2244
2245         for (i = 0; i < log_num; ++i) {
2246                 u64 l = min(log[i].len, len);
2247                 r = log_write(vq->log_base, log[i].addr, l);
2248                 if (r < 0)
2249                         return r;
2250                 len -= l;
2251                 if (!len) {
2252                         if (vq->log_ctx)
2253                                 eventfd_signal(vq->log_ctx, 1);
2254                         return 0;
2255                 }
2256         }
2257         /* Length written exceeds what we have stored. This is a bug. */
2258         BUG();
2259         return 0;
2260 }
2261 EXPORT_SYMBOL_GPL(vhost_log_write);
2262
2263 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
2264 {
2265         void __user *used;
2266         if (vhost_put_used_flags(vq))
2267                 return -EFAULT;
2268         if (unlikely(vq->log_used)) {
2269                 /* Make sure the flag is seen before log. */
2270                 smp_wmb();
2271                 /* Log used flag write. */
2272                 used = &vq->used->flags;
2273                 log_used(vq, (used - (void __user *)vq->used),
2274                          sizeof vq->used->flags);
2275                 if (vq->log_ctx)
2276                         eventfd_signal(vq->log_ctx, 1);
2277         }
2278         return 0;
2279 }
2280
2281 static int vhost_update_avail_event(struct vhost_virtqueue *vq)
2282 {
2283         if (vhost_put_avail_event(vq))
2284                 return -EFAULT;
2285         if (unlikely(vq->log_used)) {
2286                 void __user *used;
2287                 /* Make sure the event is seen before log. */
2288                 smp_wmb();
2289                 /* Log avail event write */
2290                 used = vhost_avail_event(vq);
2291                 log_used(vq, (used - (void __user *)vq->used),
2292                          sizeof *vhost_avail_event(vq));
2293                 if (vq->log_ctx)
2294                         eventfd_signal(vq->log_ctx, 1);
2295         }
2296         return 0;
2297 }
2298
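/*
 * (Re)initialize ring access state, typically when a backend is attached:
 * fix up endianness, push the current used flags, and read back the used
 * index so last_used_idx picks up where the ring already is.  A vq without
 * a backend (private_data == NULL) is left untouched.
 */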
2299 int vhost_vq_init_access(struct vhost_virtqueue *vq)
2300 {
2301         __virtio16 last_used_idx;
2302         int r;
2303         bool is_le = vq->is_le;
2304
2305         if (!vq->private_data)
2306                 return 0;
2307
2308         vhost_init_is_le(vq);
2309
2310         r = vhost_update_used_flags(vq);
2311         if (r)
2312                 goto err;
2313         vq->signalled_used_valid = false;
2314         if (!vq->iotlb &&
2315             !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
2316                 r = -EFAULT;
2317                 goto err;
2318         }
2319         r = vhost_get_used_idx(vq, &last_used_idx);
2320         if (r) {
2321                 vq_err(vq, "Can't access used idx at %p\n",
2322                        &vq->used->idx);
2323                 goto err;
2324         }
2325         vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
2326         return 0;
2327
2328 err:
2329         vq->is_le = is_le;
2330         return r;
2331 }
2332 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
2333
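/*
 * Translate a guest address range into userspace iovecs, using the device
 * IOTLB if one is set up and the static memory table otherwise.  Returns
 * the number of iovecs used, -ENOBUFS if iov_size is too small, -EPERM on
 * a permission mismatch, -EFAULT for a hole in the memory table, or
 * -EAGAIN after queueing an IOTLB miss so userspace can fill the hole.
 */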
2334 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
2335                           struct iovec iov[], int iov_size, int access)
2336 {
2337         const struct vhost_iotlb_map *map;
2338         struct vhost_dev *dev = vq->dev;
2339         struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
2340         struct iovec *_iov;
2341         u64 s = 0, last = addr + len - 1;
2342         int ret = 0;
2343
2344         while ((u64)len > s) {
2345                 u64 size;
2346                 if (unlikely(ret >= iov_size)) {
2347                         ret = -ENOBUFS;
2348                         break;
2349                 }
2350
2351                 map = vhost_iotlb_itree_first(umem, addr, last);
2352                 if (map == NULL || map->start > addr) {
2353                         if (umem != dev->iotlb) {
2354                                 ret = -EFAULT;
2355                                 break;
2356                         }
2357                         ret = -EAGAIN;
2358                         break;
2359                 } else if (!(map->perm & access)) {
2360                         ret = -EPERM;
2361                         break;
2362                 }
2363
2364                 _iov = iov + ret;
2365                 size = map->size - addr + map->start;
2366                 _iov->iov_len = min((u64)len - s, size);
2367                 _iov->iov_base = (void __user *)(unsigned long)
2368                                  (map->addr + addr - map->start);
2369                 s += size;
2370                 addr += size;
2371                 ++ret;
2372         }
2373
2374         if (ret == -EAGAIN)
2375                 vhost_iotlb_miss(vq, addr, access);
2376         return ret;
2377 }
2378
2379 /* Each buffer in the virtqueues is actually a chain of descriptors.  This
2380  * function returns the next descriptor in the chain,
2381  * or -1U if we're at the end. */
2382 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
2383 {
2384         unsigned int next;
2385
2386         /* If this descriptor says it doesn't chain, we're done. */
2387         if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
2388                 return -1U;
2389
2390         /* Check they're not leading us off the end of the descriptors. */
2391         next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
2392         return next;
2393 }
2394
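/*
 * Walk an indirect descriptor: translate the indirect table into the
 * scratch vq->indirect iovec, then iterate the chained descriptors it
 * contains, translating each data buffer into the caller's iov and
 * recording writable ranges for dirty logging.
 */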
2395 static int get_indirect(struct vhost_virtqueue *vq,
2396                         struct iovec iov[], unsigned int iov_size,
2397                         unsigned int *out_num, unsigned int *in_num,
2398                         struct vhost_log *log, unsigned int *log_num,
2399                         struct vring_desc *indirect)
2400 {
2401         struct vring_desc desc;
2402         unsigned int i = 0, count, found = 0;
2403         u32 len = vhost32_to_cpu(vq, indirect->len);
2404         struct iov_iter from;
2405         int ret, access;
2406
2407         /* Sanity check */
2408         if (unlikely(len % sizeof desc)) {
2409                 vq_err(vq, "Invalid length in indirect descriptor: "
2410                        "len 0x%llx not multiple of 0x%zx\n",
2411                        (unsigned long long)len,
2412                        sizeof desc);
2413                 return -EINVAL;
2414         }
2415
2416         ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
2417                              UIO_MAXIOV, VHOST_ACCESS_RO);
2418         if (unlikely(ret < 0)) {
2419                 if (ret != -EAGAIN)
2420                         vq_err(vq, "Translation failure %d in indirect.\n", ret);
2421                 return ret;
2422         }
2423         iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
2424         count = len / sizeof desc;
2425         /* Buffers are chained via a 16-bit next field, so
2426          * we can have at most 2^16 of these. */
2427         if (unlikely(count > USHRT_MAX + 1)) {
2428                 vq_err(vq, "Indirect buffer length too big: %d\n",
2429                        indirect->len);
2430                 return -E2BIG;
2431         }
2432
2433         do {
2434                 unsigned iov_count = *in_num + *out_num;
2435                 if (unlikely(++found > count)) {
2436                         vq_err(vq, "Loop detected: last one at %u "
2437                                "indirect size %u\n",
2438                                i, count);
2439                         return -EINVAL;
2440                 }
2441                 if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
2442                         vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
2443                                i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2444                         return -EINVAL;
2445                 }
2446                 if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
2447                         vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
2448                                i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2449                         return -EINVAL;
2450                 }
2451
2452                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2453                         access = VHOST_ACCESS_WO;
2454                 else
2455                         access = VHOST_ACCESS_RO;
2456
2457                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2458                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2459                                      iov_size - iov_count, access);
2460                 if (unlikely(ret < 0)) {
2461                         if (ret != -EAGAIN)
2462                                 vq_err(vq, "Translation failure %d indirect idx %d\n",
2463                                         ret, i);
2464                         return ret;
2465                 }
2466                 /* If this is an input descriptor, increment that count. */
2467                 if (access == VHOST_ACCESS_WO) {
2468                         *in_num += ret;
2469                         if (unlikely(log && ret)) {
2470                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2471                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2472                                 ++*log_num;
2473                         }
2474                 } else {
2475                         /* If it's an output descriptor, they're all supposed
2476                          * to come before any input descriptors. */
2477                         if (unlikely(*in_num)) {
2478                                 vq_err(vq, "Indirect descriptor "
2479                                        "has out after in: idx %d\n", i);
2480                                 return -EINVAL;
2481                         }
2482                         *out_num += ret;
2483                 }
2484         } while ((i = next_desc(vq, &desc)) != -1);
2485         return 0;
2486 }
2487
2488 /* This looks in the virtqueue for the first available buffer and converts
2489  * it to an iovec for convenient access.  Since descriptors consist of some
2490  * number of output then some number of input descriptors, it's actually two
2491  * iovecs, but we pack them into one and note how many of each there were.
2492  *
2493  * This function returns the descriptor number found, or vq->num (which is
2494  * never a valid descriptor number) if none was found.  A negative code is
2495  * returned on error. */
2496 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
2497                       struct iovec iov[], unsigned int iov_size,
2498                       unsigned int *out_num, unsigned int *in_num,
2499                       struct vhost_log *log, unsigned int *log_num)
2500 {
2501         struct vring_desc desc;
2502         unsigned int i, head, found = 0;
2503         u16 last_avail_idx;
2504         __virtio16 avail_idx;
2505         __virtio16 ring_head;
2506         int ret, access;
2507
2508         /* Check it isn't doing very strange things with descriptor numbers. */
2509         last_avail_idx = vq->last_avail_idx;
2510
2511         if (vq->avail_idx == vq->last_avail_idx) {
2512                 if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
2513                         vq_err(vq, "Failed to access avail idx at %p\n",
2514                                 &vq->avail->idx);
2515                         return -EFAULT;
2516                 }
2517                 vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2518
2519                 if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
2520                         vq_err(vq, "Guest moved avail index from %u to %u",
2521                                 last_avail_idx, vq->avail_idx);
2522                         return -EFAULT;
2523                 }
2524
2525                 /* If there's nothing new since last we looked, return
2526                  * invalid.
2527                  */
2528                 if (vq->avail_idx == last_avail_idx)
2529                         return vq->num;
2530
2531                 /* Only get avail ring entries after they have been
2532                  * exposed by guest.
2533                  */
2534                 smp_rmb();
2535         }
2536
2537         /* Grab the next descriptor number they're advertising, and increment
2538          * the index we've seen. */
2539         if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
2540                 vq_err(vq, "Failed to read head: idx %d address %p\n",
2541                        last_avail_idx,
2542                        &vq->avail->ring[last_avail_idx % vq->num]);
2543                 return -EFAULT;
2544         }
2545
2546         head = vhost16_to_cpu(vq, ring_head);
2547
2548         /* If their number is silly, that's an error. */
2549         if (unlikely(head >= vq->num)) {
2550                 vq_err(vq, "Guest says index %u > %u is available",
2551                        head, vq->num);
2552                 return -EINVAL;
2553         }
2554
2555         /* When we start there are no input or output descriptors. */
2556         *out_num = *in_num = 0;
2557         if (unlikely(log))
2558                 *log_num = 0;
2559
2560         i = head;
2561         do {
2562                 unsigned iov_count = *in_num + *out_num;
2563                 if (unlikely(i >= vq->num)) {
2564                         vq_err(vq, "Desc index is %u > %u, head = %u",
2565                                i, vq->num, head);
2566                         return -EINVAL;
2567                 }
2568                 if (unlikely(++found > vq->num)) {
2569                         vq_err(vq, "Loop detected: last one at %u "
2570                                "vq size %u head %u\n",
2571                                i, vq->num, head);
2572                         return -EINVAL;
2573                 }
2574                 ret = vhost_get_desc(vq, &desc, i);
2575                 if (unlikely(ret)) {
2576                         vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2577                                i, vq->desc + i);
2578                         return -EFAULT;
2579                 }
2580                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2581                         ret = get_indirect(vq, iov, iov_size,
2582                                            out_num, in_num,
2583                                            log, log_num, &desc);
2584                         if (unlikely(ret < 0)) {
2585                                 if (ret != -EAGAIN)
2586                                         vq_err(vq, "Failure detected "
2587                                                 "in indirect descriptor at idx %d\n", i);
2588                                 return ret;
2589                         }
2590                         continue;
2591                 }
2592
2593                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2594                         access = VHOST_ACCESS_WO;
2595                 else
2596                         access = VHOST_ACCESS_RO;
2597                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2598                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2599                                      iov_size - iov_count, access);
2600                 if (unlikely(ret < 0)) {
2601                         if (ret != -EAGAIN)
2602                                 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2603                                         ret, i);
2604                         return ret;
2605                 }
2606                 if (access == VHOST_ACCESS_WO) {
2607                         /* If this is an input descriptor,
2608                          * increment that count. */
2609                         *in_num += ret;
2610                         if (unlikely(log && ret)) {
2611                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2612                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2613                                 ++*log_num;
2614                         }
2615                 } else {
2616                         /* If it's an output descriptor, they're all supposed
2617                          * to come before any input descriptors. */
2618                         if (unlikely(*in_num)) {
2619                                 vq_err(vq, "Descriptor has out after in: "
2620                                        "idx %d\n", i);
2621                                 return -EINVAL;
2622                         }
2623                         *out_num += ret;
2624                 }
2625         } while ((i = next_desc(vq, &desc)) != -1);
2626
2627         /* On success, increment avail index. */
2628         vq->last_avail_idx++;
2629
2630         /* Assume notifications from guest are disabled at this point;
2631          * if they aren't, we would need to update the avail_event index. */
2632         BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2633         return head;
2634 }
2635 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
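
/*
 * Typical backend loop (a sketch, not taken verbatim from any driver;
 * error handling and feature handling omitted):
 *
 *	vhost_disable_notify(&dev, vq);
 *	for (;;) {
 *		head = vhost_get_vq_desc(vq, iov, ARRAY_SIZE(iov),
 *					 &out, &in, NULL, NULL);
 *		if (head < 0)
 *			break;
 *		if (head == vq->num) {
 *			if (unlikely(vhost_enable_notify(&dev, vq))) {
 *				vhost_disable_notify(&dev, vq);
 *				continue;
 *			}
 *			break;
 *		}
 *		... consume iov[0 .. out + in) ...
 *		vhost_add_used_and_signal(&dev, vq, head, len);
 *	}
 */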
2636
2637 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2638 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2639 {
2640         vq->last_avail_idx -= n;
2641 }
2642 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2643
2644 /* After we've used one of their buffers, we tell them about it.  We'll then
2645  * want to notify the guest, using eventfd. */
2646 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2647 {
2648         struct vring_used_elem heads = {
2649                 cpu_to_vhost32(vq, head),
2650                 cpu_to_vhost32(vq, len)
2651         };
2652
2653         return vhost_add_used_n(vq, &heads, 1);
2654 }
2655 EXPORT_SYMBOL_GPL(vhost_add_used);
2656
2657 static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2658                             struct vring_used_elem *heads,
2659                             unsigned count)
2660 {
2661         vring_used_elem_t __user *used;
2662         u16 old, new;
2663         int start;
2664
2665         start = vq->last_used_idx & (vq->num - 1);
2666         used = vq->used->ring + start;
2667         if (vhost_put_used(vq, heads, start, count)) {
2668                 vq_err(vq, "Failed to write used");
2669                 return -EFAULT;
2670         }
2671         if (unlikely(vq->log_used)) {
2672                 /* Make sure data is seen before log. */
2673                 smp_wmb();
2674                 /* Log used ring entry write. */
2675                 log_used(vq, ((void __user *)used - (void __user *)vq->used),
2676                          count * sizeof *used);
2677         }
2678         old = vq->last_used_idx;
2679         new = (vq->last_used_idx += count);
2680         /* If the driver never bothers to signal in a very long while,
2681          * used index might wrap around. If that happens, invalidate
2682          * signalled_used index we stored. TODO: make sure driver
2683          * signals at least once in 2^16 and remove this. */
2684         if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2685                 vq->signalled_used_valid = false;
2686         return 0;
2687 }
2688
2689 /* After we've used one of their buffers, we tell them about it.  We'll then
2690  * want to notify the guest, using eventfd. */
2691 int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2692                      unsigned count)
2693 {
2694         int start, n, r;
2695
2696         start = vq->last_used_idx & (vq->num - 1);
2697         n = vq->num - start;
2698         if (n < count) {
2699                 r = __vhost_add_used_n(vq, heads, n);
2700                 if (r < 0)
2701                         return r;
2702                 heads += n;
2703                 count -= n;
2704         }
2705         r = __vhost_add_used_n(vq, heads, count);
2706
2707         /* Make sure buffer is written before we update index. */
2708         smp_wmb();
2709         if (vhost_put_used_idx(vq)) {
2710                 vq_err(vq, "Failed to increment used idx");
2711                 return -EFAULT;
2712         }
2713         if (unlikely(vq->log_used)) {
2714                 /* Make sure used idx is seen before log. */
2715                 smp_wmb();
2716                 /* Log used index update. */
2717                 log_used(vq, offsetof(struct vring_used, idx),
2718                          sizeof vq->used->idx);
2719                 if (vq->log_ctx)
2720                         eventfd_signal(vq->log_ctx, 1);
2721         }
2722         return r;
2723 }
2724 EXPORT_SYMBOL_GPL(vhost_add_used_n);
2725
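/*
 * Decide whether the guest needs a signal.  Without EVENT_IDX this just
 * honours VRING_AVAIL_F_NO_INTERRUPT; with EVENT_IDX we signal only when
 * the used index has passed the guest's used_event value since the last
 * signal (vring_need_event()), or when we have no valid record of the
 * last signalled index.
 */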
2726 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2727 {
2728         __u16 old, new;
2729         __virtio16 event;
2730         bool v;
2731         /* Flush out used index updates. This is paired
2732          * with the barrier that the Guest executes when enabling
2733          * interrupts. */
2734         smp_mb();
2735
2736         if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2737             unlikely(vq->avail_idx == vq->last_avail_idx))
2738                 return true;
2739
2740         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2741                 __virtio16 flags;
2742                 if (vhost_get_avail_flags(vq, &flags)) {
2743                         vq_err(vq, "Failed to get flags");
2744                         return true;
2745                 }
2746                 return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2747         }
2748         old = vq->signalled_used;
2749         v = vq->signalled_used_valid;
2750         new = vq->signalled_used = vq->last_used_idx;
2751         vq->signalled_used_valid = true;
2752
2753         if (unlikely(!v))
2754                 return true;
2755
2756         if (vhost_get_used_event(vq, &event)) {
2757                 vq_err(vq, "Failed to get used event idx");
2758                 return true;
2759         }
2760         return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2761 }
2762
2763 /* This actually signals the guest, using eventfd. */
2764 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2765 {
2766         /* Signal the Guest to tell them we used something up. */
2767         if (vq->call_ctx.ctx && vhost_notify(dev, vq))
2768                 eventfd_signal(vq->call_ctx.ctx, 1);
2769 }
2770 EXPORT_SYMBOL_GPL(vhost_signal);
2771
2772 /* And here's the combo meal deal.  Supersize me! */
2773 void vhost_add_used_and_signal(struct vhost_dev *dev,
2774                                struct vhost_virtqueue *vq,
2775                                unsigned int head, int len)
2776 {
2777         vhost_add_used(vq, head, len);
2778         vhost_signal(dev, vq);
2779 }
2780 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2781
2782 /* multi-buffer version of vhost_add_used_and_signal */
2783 void vhost_add_used_and_signal_n(struct vhost_dev *dev,
2784                                  struct vhost_virtqueue *vq,
2785                                  struct vring_used_elem *heads, unsigned count)
2786 {
2787         vhost_add_used_n(vq, heads, count);
2788         vhost_signal(dev, vq);
2789 }
2790 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2791
2792 /* Return true if we're sure that the available ring is empty. */
2793 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2794 {
2795         __virtio16 avail_idx;
2796         int r;
2797
2798         if (vq->avail_idx != vq->last_avail_idx)
2799                 return false;
2800
2801         r = vhost_get_avail_idx(vq, &avail_idx);
2802         if (unlikely(r))
2803                 return false;
2804         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2805
2806         return vq->avail_idx == vq->last_avail_idx;
2807 }
2808 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2809
2810 /* OK, now we need to know about added descriptors. */
2811 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2812 {
2813         __virtio16 avail_idx;
2814         int r;
2815
2816         if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2817                 return false;
2818         vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2819         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2820                 r = vhost_update_used_flags(vq);
2821                 if (r) {
2822                         vq_err(vq, "Failed to enable notification at %p: %d\n",
2823                                &vq->used->flags, r);
2824                         return false;
2825                 }
2826         } else {
2827                 r = vhost_update_avail_event(vq);
2828                 if (r) {
2829                         vq_err(vq, "Failed to update avail event index at %p: %d\n",
2830                                vhost_avail_event(vq), r);
2831                         return false;
2832                 }
2833         }
2834         /* They could have slipped one in as we were doing that: make
2835          * sure it's written, then check again. */
2836         smp_mb();
2837         r = vhost_get_avail_idx(vq, &avail_idx);
2838         if (r) {
2839                 vq_err(vq, "Failed to check avail idx at %p: %d\n",
2840                        &vq->avail->idx, r);
2841                 return false;
2842         }
2843         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2844
2845         return vq->avail_idx != vq->last_avail_idx;
2846 }
2847 EXPORT_SYMBOL_GPL(vhost_enable_notify);
2848
2849 /* We don't need to be notified again. */
2850 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2851 {
2852         int r;
2853
2854         if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2855                 return;
2856         vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2857         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2858                 r = vhost_update_used_flags(vq);
2859                 if (r)
2860                         vq_err(vq, "Failed to disable notification at %p: %d\n",
2861                                &vq->used->flags, r);
2862         }
2863 }
2864 EXPORT_SYMBOL_GPL(vhost_disable_notify);
2865
2866 /* Create a new message. */
2867 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2868 {
2869         /* Make sure all padding within the structure is initialized. */
2870         struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
2871         if (!node)
2872                 return NULL;
2873
2874         node->vq = vq;
2875         node->msg.type = type;
2876         return node;
2877 }
2878 EXPORT_SYMBOL_GPL(vhost_new_msg);
2879
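/*
 * Append a message for userspace and wake up readers/pollers of the vhost
 * chardev.  Messages placed on dev->read_list are handed out by
 * vhost_chr_read_iter(); IOTLB miss messages are then parked on
 * dev->pending_list until userspace replies.
 */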
2880 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2881                        struct vhost_msg_node *node)
2882 {
2883         spin_lock(&dev->iotlb_lock);
2884         list_add_tail(&node->node, head);
2885         spin_unlock(&dev->iotlb_lock);
2886
2887         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2888 }
2889 EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2890
2891 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2892                                          struct list_head *head)
2893 {
2894         struct vhost_msg_node *node = NULL;
2895
2896         spin_lock(&dev->iotlb_lock);
2897         if (!list_empty(head)) {
2898                 node = list_first_entry(head, struct vhost_msg_node,
2899                                         node);
2900                 list_del(&node->node);
2901         }
2902         spin_unlock(&dev->iotlb_lock);
2903
2904         return node;
2905 }
2906 EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2907
2908 void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
2909 {
2910         struct vhost_virtqueue *vq;
2911         int i;
2912
2913         mutex_lock(&dev->mutex);
2914         for (i = 0; i < dev->nvqs; ++i) {
2915                 vq = dev->vqs[i];
2916                 mutex_lock(&vq->mutex);
2917                 vq->acked_backend_features = features;
2918                 mutex_unlock(&vq->mutex);
2919         }
2920         mutex_unlock(&dev->mutex);
2921 }
2922 EXPORT_SYMBOL_GPL(vhost_set_backend_features);
2923
2924 static int __init vhost_init(void)
2925 {
2926         return 0;
2927 }
2928
2929 static void __exit vhost_exit(void)
2930 {
2931 }
2932
2933 module_init(vhost_init);
2934 module_exit(vhost_exit);
2935
2936 MODULE_VERSION("0.0.1");
2937 MODULE_LICENSE("GPL v2");
2938 MODULE_AUTHOR("Michael S. Tsirkin");
2939 MODULE_DESCRIPTION("Host kernel accelerator for virtio");