drivers/vhost/vdpa.c
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2018-2020 Intel Corporation.
 * Copyright (C) 2020 Red Hat, Inc.
 *
 * Author: Tiwei Bie <tiwei.bie@intel.com>
 *         Jason Wang <jasowang@redhat.com>
 *
 * Thanks to Michael S. Tsirkin for the valuable comments and
 * suggestions, and thanks to Cunming Liang and Zhihong Wang for all
 * their support.
 */

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <linux/iommu.h>
#include <linux/uuid.h>
#include <linux/vdpa.h>
#include <linux/nospec.h>
#include <linux/vhost.h>

#include "vhost.h"

enum {
        VHOST_VDPA_BACKEND_FEATURES =
        (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
        (1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
        (1ULL << VHOST_BACKEND_F_IOTLB_ASID),
};

#define VHOST_VDPA_DEV_MAX (1U << MINORBITS)

#define VHOST_VDPA_IOTLB_BUCKETS 16

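/*
 * Each address space identifier (ASID) gets its own IOTLB. The
 * vhost_vdpa_as instances are kept in a small hash table on the
 * vhost_vdpa device, keyed by ASID.
 */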
struct vhost_vdpa_as {
        struct hlist_node hash_link;
        struct vhost_iotlb iotlb;
        u32 id;
};

struct vhost_vdpa {
        struct vhost_dev vdev;
        struct iommu_domain *domain;
        struct vhost_virtqueue *vqs;
        struct completion completion;
        struct vdpa_device *vdpa;
        struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
        struct device dev;
        struct cdev cdev;
        atomic_t opened;
        u32 nvqs;
        int virtio_id;
        int minor;
        struct eventfd_ctx *config_ctx;
        int in_batch;
        struct vdpa_iova_range range;
        u32 batch_asid;
};

static DEFINE_IDA(vhost_vdpa_ida);

static dev_t vhost_vdpa_major;

static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
                                   struct vhost_iotlb *iotlb, u64 start,
                                   u64 last, u32 asid);

static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
{
        struct vhost_vdpa_as *as = container_of(iotlb, struct
                                                vhost_vdpa_as, iotlb);
        return as->id;
}

static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
{
        struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
        struct vhost_vdpa_as *as;

        hlist_for_each_entry(as, head, hash_link)
                if (as->id == asid)
                        return as;

        return NULL;
}

static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
{
        struct vhost_vdpa_as *as = asid_to_as(v, asid);

        if (!as)
                return NULL;

        return &as->iotlb;
}

static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
{
        struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
        struct vhost_vdpa_as *as;

        if (asid_to_as(v, asid))
                return NULL;

        if (asid >= v->vdpa->nas)
                return NULL;

        as = kmalloc(sizeof(*as), GFP_KERNEL);
        if (!as)
                return NULL;

        vhost_iotlb_init(&as->iotlb, 0, 0);
        as->id = asid;
        hlist_add_head(&as->hash_link, head);

        return as;
}

static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
                                                      u32 asid)
{
        struct vhost_vdpa_as *as = asid_to_as(v, asid);

        if (as)
                return as;

        return vhost_vdpa_alloc_as(v, asid);
}

static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
{
        struct vhost_vdpa_as *as = asid_to_as(v, asid);

        if (!as)
                return -EINVAL;

        hlist_del(&as->hash_link);
        vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1, asid);
        kfree(as);

        return 0;
}

static void handle_vq_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
        const struct vdpa_config_ops *ops = v->vdpa->config;

        ops->kick_vq(v->vdpa, vq - v->vqs);
}

static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
{
        struct vhost_virtqueue *vq = private;
        struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;

        if (call_ctx)
                eventfd_signal(call_ctx, 1);

        return IRQ_HANDLED;
}

static irqreturn_t vhost_vdpa_config_cb(void *private)
{
        struct vhost_vdpa *v = private;
        struct eventfd_ctx *config_ctx = v->config_ctx;

        if (config_ctx)
                eventfd_signal(config_ctx, 1);

        return IRQ_HANDLED;
}

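/*
 * Try to register the virtqueue's call eventfd as an irq bypass
 * producer so that the vDPA device interrupt can be delivered more
 * directly to the consumer (e.g. KVM) instead of going through the
 * eventfd path. Failure is not fatal; it is only logged.
 */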
static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
{
        struct vhost_virtqueue *vq = &v->vqs[qid];
        const struct vdpa_config_ops *ops = v->vdpa->config;
        struct vdpa_device *vdpa = v->vdpa;
        int ret, irq;

        if (!ops->get_vq_irq)
                return;

        irq = ops->get_vq_irq(vdpa, qid);
        if (irq < 0)
                return;

        irq_bypass_unregister_producer(&vq->call_ctx.producer);
        if (!vq->call_ctx.ctx)
                return;

        vq->call_ctx.producer.token = vq->call_ctx.ctx;
        vq->call_ctx.producer.irq = irq;
        ret = irq_bypass_register_producer(&vq->call_ctx.producer);
        if (unlikely(ret))
                dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret =  %d\n",
                         qid, vq->call_ctx.producer.token, ret);
}

static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
{
        struct vhost_virtqueue *vq = &v->vqs[qid];

        irq_bypass_unregister_producer(&vq->call_ctx.producer);
}

static int vhost_vdpa_reset(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;

        v->in_batch = 0;

        return vdpa_reset(vdpa);
}

static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u32 device_id;

        device_id = ops->get_device_id(vdpa);

        if (copy_to_user(argp, &device_id, sizeof(device_id)))
                return -EFAULT;

        return 0;
}

static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u8 status;

        status = ops->get_status(vdpa);

        if (copy_to_user(statusp, &status, sizeof(status)))
                return -EFAULT;

        return 0;
}

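/*
 * Update the device status. Clearing the status (writing 0) resets the
 * device; any transition that sets or clears DRIVER_OK also sets up or
 * tears down the per-virtqueue irq bypass producers accordingly.
 */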
static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u8 status, status_old;
        u32 nvqs = v->nvqs;
        int ret;
        u16 i;

        if (copy_from_user(&status, statusp, sizeof(status)))
                return -EFAULT;

        status_old = ops->get_status(vdpa);

        /*
         * Userspace shouldn't remove status bits unless it is resetting
         * the status to 0.
         */
        if (status != 0 && (status_old & ~status) != 0)
                return -EINVAL;

        if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
                for (i = 0; i < nvqs; i++)
                        vhost_vdpa_unsetup_vq_irq(v, i);

        if (status == 0) {
                ret = vdpa_reset(vdpa);
                if (ret)
                        return ret;
        } else
                vdpa_set_status(vdpa, status);

        if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
                for (i = 0; i < nvqs; i++)
                        vhost_vdpa_setup_vq_irq(v, i);

        return 0;
}

static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
                                      struct vhost_vdpa_config *c)
{
        struct vdpa_device *vdpa = v->vdpa;
        size_t size = vdpa->config->get_config_size(vdpa);

        if (c->len == 0 || c->off > size)
                return -EINVAL;

        if (c->len > size - c->off)
                return -E2BIG;

        return 0;
}

static long vhost_vdpa_get_config(struct vhost_vdpa *v,
                                  struct vhost_vdpa_config __user *c)
{
        struct vdpa_device *vdpa = v->vdpa;
        struct vhost_vdpa_config config;
        unsigned long size = offsetof(struct vhost_vdpa_config, buf);
        u8 *buf;

        if (copy_from_user(&config, c, size))
                return -EFAULT;
        if (vhost_vdpa_config_validate(v, &config))
                return -EINVAL;
        buf = kvzalloc(config.len, GFP_KERNEL);
        if (!buf)
                return -ENOMEM;

        vdpa_get_config(vdpa, config.off, buf, config.len);

        if (copy_to_user(c->buf, buf, config.len)) {
                kvfree(buf);
                return -EFAULT;
        }

        kvfree(buf);
        return 0;
}

static long vhost_vdpa_set_config(struct vhost_vdpa *v,
                                  struct vhost_vdpa_config __user *c)
{
        struct vdpa_device *vdpa = v->vdpa;
        struct vhost_vdpa_config config;
        unsigned long size = offsetof(struct vhost_vdpa_config, buf);
        u8 *buf;

        if (copy_from_user(&config, c, size))
                return -EFAULT;
        if (vhost_vdpa_config_validate(v, &config))
                return -EINVAL;

        buf = vmemdup_user(c->buf, config.len);
        if (IS_ERR(buf))
                return PTR_ERR(buf);

        vdpa_set_config(vdpa, config.off, buf, config.len);

        kvfree(buf);
        return 0;
}

static bool vhost_vdpa_can_suspend(const struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        return ops->suspend;
}

static bool vhost_vdpa_can_resume(const struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        return ops->resume;
}

static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u64 features;

        features = ops->get_device_features(vdpa);

        if (copy_to_user(featurep, &features, sizeof(features)))
                return -EFAULT;

        return 0;
}

static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u64 features;

        /*
         * It's not allowed to change the features after they have
         * been negotiated.
         */
        if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
                return -EBUSY;

        if (copy_from_user(&features, featurep, sizeof(features)))
                return -EFAULT;

        if (vdpa_set_features(vdpa, features))
                return -EINVAL;

        return 0;
}

static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u16 num;

        num = ops->get_vq_num_max(vdpa);

        if (copy_to_user(argp, &num, sizeof(num)))
                return -EFAULT;

        return 0;
}

static void vhost_vdpa_config_put(struct vhost_vdpa *v)
{
        if (v->config_ctx) {
                eventfd_ctx_put(v->config_ctx);
                v->config_ctx = NULL;
        }
}

static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
{
        struct vdpa_callback cb;
        int fd;
        struct eventfd_ctx *ctx;

        cb.callback = vhost_vdpa_config_cb;
        cb.private = v;
        if (copy_from_user(&fd, argp, sizeof(fd)))
                return -EFAULT;

        ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
        swap(ctx, v->config_ctx);

        if (!IS_ERR_OR_NULL(ctx))
                eventfd_ctx_put(ctx);

        if (IS_ERR(v->config_ctx)) {
                long ret = PTR_ERR(v->config_ctx);

                v->config_ctx = NULL;
                return ret;
        }

        v->vdpa->config->set_config_cb(v->vdpa, &cb);

        return 0;
}

static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
{
        struct vhost_vdpa_iova_range range = {
                .first = v->range.first,
                .last = v->range.last,
        };

        if (copy_to_user(argp, &range, sizeof(range)))
                return -EFAULT;
        return 0;
}

static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u32 size;

        size = ops->get_config_size(vdpa);

        if (copy_to_user(argp, &size, sizeof(size)))
                return -EFAULT;

        return 0;
}

static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;

        if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs)))
                return -EFAULT;

        return 0;
}

/* After a successful return of this ioctl the device must not process more
 * virtqueue descriptors. The device can still answer reads or writes of
 * config fields as if it were not suspended. In particular, writing to
 * "queue_enable" with a value of 1 will not make the device start
 * processing buffers.
 */
static long vhost_vdpa_suspend(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        if (!ops->suspend)
                return -EOPNOTSUPP;

        return ops->suspend(vdpa);
}

/* After a successful return of this ioctl the device resumes processing
 * virtqueue descriptors. The device becomes fully operational the same way it
 * was before it was suspended.
 */
static long vhost_vdpa_resume(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        if (!ops->resume)
                return -EOPNOTSUPP;

        return ops->resume(vdpa);
}

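/*
 * Per-virtqueue ioctls: vDPA-specific commands are handled directly,
 * while the generic vhost vring ioctls are forwarded to the vhost core
 * and their results are then propagated to the device via config ops.
 */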
static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
                                   void __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vdpa_vq_state vq_state;
        struct vdpa_callback cb;
        struct vhost_virtqueue *vq;
        struct vhost_vring_state s;
        u32 idx;
        long r;

        r = get_user(idx, (u32 __user *)argp);
        if (r < 0)
                return r;

        if (idx >= v->nvqs)
                return -ENOBUFS;

        idx = array_index_nospec(idx, v->nvqs);
        vq = &v->vqs[idx];

        switch (cmd) {
        case VHOST_VDPA_SET_VRING_ENABLE:
                if (copy_from_user(&s, argp, sizeof(s)))
                        return -EFAULT;
                ops->set_vq_ready(vdpa, idx, s.num);
                return 0;
        case VHOST_VDPA_GET_VRING_GROUP:
                if (!ops->get_vq_group)
                        return -EOPNOTSUPP;
                s.index = idx;
                s.num = ops->get_vq_group(vdpa, idx);
                if (s.num >= vdpa->ngroups)
                        return -EIO;
                else if (copy_to_user(argp, &s, sizeof(s)))
                        return -EFAULT;
                return 0;
        case VHOST_VDPA_SET_GROUP_ASID:
                if (copy_from_user(&s, argp, sizeof(s)))
                        return -EFAULT;
                if (s.num >= vdpa->nas)
                        return -EINVAL;
                if (!ops->set_group_asid)
                        return -EOPNOTSUPP;
                return ops->set_group_asid(vdpa, idx, s.num);
        case VHOST_GET_VRING_BASE:
                r = ops->get_vq_state(v->vdpa, idx, &vq_state);
                if (r)
                        return r;

                vq->last_avail_idx = vq_state.split.avail_index;
                break;
        }

        r = vhost_vring_ioctl(&v->vdev, cmd, argp);
        if (r)
                return r;

        switch (cmd) {
        case VHOST_SET_VRING_ADDR:
                if (ops->set_vq_address(vdpa, idx,
                                        (u64)(uintptr_t)vq->desc,
                                        (u64)(uintptr_t)vq->avail,
                                        (u64)(uintptr_t)vq->used))
                        r = -EINVAL;
                break;

        case VHOST_SET_VRING_BASE:
                vq_state.split.avail_index = vq->last_avail_idx;
                if (ops->set_vq_state(vdpa, idx, &vq_state))
                        r = -EINVAL;
                break;

        case VHOST_SET_VRING_CALL:
                if (vq->call_ctx.ctx) {
                        cb.callback = vhost_vdpa_virtqueue_cb;
                        cb.private = vq;
                } else {
                        cb.callback = NULL;
                        cb.private = NULL;
                }
                ops->set_vq_cb(vdpa, idx, &cb);
                vhost_vdpa_setup_vq_irq(v, idx);
                break;

        case VHOST_SET_VRING_NUM:
                ops->set_vq_num(vdpa, idx, vq->num);
                break;
        }

        return r;
}

static long vhost_vdpa_unlocked_ioctl(struct file *filep,
                                      unsigned int cmd, unsigned long arg)
{
        struct vhost_vdpa *v = filep->private_data;
        struct vhost_dev *d = &v->vdev;
        void __user *argp = (void __user *)arg;
        u64 __user *featurep = argp;
        u64 features;
        long r = 0;

        if (cmd == VHOST_SET_BACKEND_FEATURES) {
                if (copy_from_user(&features, featurep, sizeof(features)))
                        return -EFAULT;
                if (features & ~(VHOST_VDPA_BACKEND_FEATURES |
                                 BIT_ULL(VHOST_BACKEND_F_SUSPEND) |
                                 BIT_ULL(VHOST_BACKEND_F_RESUME)))
                        return -EOPNOTSUPP;
                if ((features & BIT_ULL(VHOST_BACKEND_F_SUSPEND)) &&
                     !vhost_vdpa_can_suspend(v))
                        return -EOPNOTSUPP;
                if ((features & BIT_ULL(VHOST_BACKEND_F_RESUME)) &&
                     !vhost_vdpa_can_resume(v))
                        return -EOPNOTSUPP;
                vhost_set_backend_features(&v->vdev, features);
                return 0;
        }

        mutex_lock(&d->mutex);

        switch (cmd) {
        case VHOST_VDPA_GET_DEVICE_ID:
                r = vhost_vdpa_get_device_id(v, argp);
                break;
        case VHOST_VDPA_GET_STATUS:
                r = vhost_vdpa_get_status(v, argp);
                break;
        case VHOST_VDPA_SET_STATUS:
                r = vhost_vdpa_set_status(v, argp);
                break;
        case VHOST_VDPA_GET_CONFIG:
                r = vhost_vdpa_get_config(v, argp);
                break;
        case VHOST_VDPA_SET_CONFIG:
                r = vhost_vdpa_set_config(v, argp);
                break;
        case VHOST_GET_FEATURES:
                r = vhost_vdpa_get_features(v, argp);
                break;
        case VHOST_SET_FEATURES:
                r = vhost_vdpa_set_features(v, argp);
                break;
        case VHOST_VDPA_GET_VRING_NUM:
                r = vhost_vdpa_get_vring_num(v, argp);
                break;
        case VHOST_VDPA_GET_GROUP_NUM:
                if (copy_to_user(argp, &v->vdpa->ngroups,
                                 sizeof(v->vdpa->ngroups)))
                        r = -EFAULT;
                break;
        case VHOST_VDPA_GET_AS_NUM:
                if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
                        r = -EFAULT;
                break;
        case VHOST_SET_LOG_BASE:
        case VHOST_SET_LOG_FD:
                r = -ENOIOCTLCMD;
                break;
        case VHOST_VDPA_SET_CONFIG_CALL:
                r = vhost_vdpa_set_config_call(v, argp);
                break;
        case VHOST_GET_BACKEND_FEATURES:
                features = VHOST_VDPA_BACKEND_FEATURES;
                if (vhost_vdpa_can_suspend(v))
                        features |= BIT_ULL(VHOST_BACKEND_F_SUSPEND);
                if (vhost_vdpa_can_resume(v))
                        features |= BIT_ULL(VHOST_BACKEND_F_RESUME);
                if (copy_to_user(featurep, &features, sizeof(features)))
                        r = -EFAULT;
                break;
        case VHOST_VDPA_GET_IOVA_RANGE:
                r = vhost_vdpa_get_iova_range(v, argp);
                break;
        case VHOST_VDPA_GET_CONFIG_SIZE:
                r = vhost_vdpa_get_config_size(v, argp);
                break;
        case VHOST_VDPA_GET_VQS_COUNT:
                r = vhost_vdpa_get_vqs_count(v, argp);
                break;
        case VHOST_VDPA_SUSPEND:
                r = vhost_vdpa_suspend(v);
                break;
        case VHOST_VDPA_RESUME:
                r = vhost_vdpa_resume(v);
                break;
        default:
                r = vhost_dev_ioctl(&v->vdev, cmd, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vdpa_vring_ioctl(v, cmd, argp);
                break;
        }

        mutex_unlock(&d->mutex);
        return r;
}

static void vhost_vdpa_general_unmap(struct vhost_vdpa *v,
                                     struct vhost_iotlb_map *map, u32 asid)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        if (ops->dma_map) {
                ops->dma_unmap(vdpa, asid, map->start, map->size);
        } else if (ops->set_map == NULL) {
                iommu_unmap(v->domain, map->start, map->size);
        }
}

static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
                                u64 start, u64 last, u32 asid)
{
        struct vhost_dev *dev = &v->vdev;
        struct vhost_iotlb_map *map;
        struct page *page;
        unsigned long pfn, pinned;

        while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
                pinned = PFN_DOWN(map->size);
                for (pfn = PFN_DOWN(map->addr);
                     pinned > 0; pfn++, pinned--) {
                        page = pfn_to_page(pfn);
                        if (map->perm & VHOST_ACCESS_WO)
                                set_page_dirty_lock(page);
                        unpin_user_page(page);
                }
                atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
                vhost_vdpa_general_unmap(v, map, asid);
                vhost_iotlb_map_free(iotlb, map);
        }
}

static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
                                u64 start, u64 last, u32 asid)
{
        struct vhost_iotlb_map *map;
        struct vdpa_map_file *map_file;

        while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
                map_file = (struct vdpa_map_file *)map->opaque;
                fput(map_file->file);
                kfree(map_file);
                vhost_vdpa_general_unmap(v, map, asid);
                vhost_iotlb_map_free(iotlb, map);
        }
}

static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
                                   struct vhost_iotlb *iotlb, u64 start,
                                   u64 last, u32 asid)
{
        struct vdpa_device *vdpa = v->vdpa;

        if (vdpa->use_va)
                return vhost_vdpa_va_unmap(v, iotlb, start, last, asid);

        return vhost_vdpa_pa_unmap(v, iotlb, start, last, asid);
}

static int perm_to_iommu_flags(u32 perm)
{
        int flags = 0;

        switch (perm) {
        case VHOST_ACCESS_WO:
                flags |= IOMMU_WRITE;
                break;
        case VHOST_ACCESS_RO:
                flags |= IOMMU_READ;
                break;
        case VHOST_ACCESS_RW:
                flags |= (IOMMU_WRITE | IOMMU_READ);
                break;
        default:
                WARN(1, "invalid vhost IOTLB permission\n");
                break;
        }

        return flags | IOMMU_CACHE;
}

static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
                          u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
{
        struct vhost_dev *dev = &v->vdev;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u32 asid = iotlb_to_asid(iotlb);
        int r = 0;

        r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
                                      pa, perm, opaque);
        if (r)
                return r;

        if (ops->dma_map) {
                r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
        } else if (ops->set_map) {
                if (!v->in_batch)
                        r = ops->set_map(vdpa, asid, iotlb);
        } else {
                r = iommu_map(v->domain, iova, pa, size,
                              perm_to_iommu_flags(perm), GFP_KERNEL);
        }
        if (r) {
                vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
                return r;
        }

        if (!vdpa->use_va)
                atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);

        return 0;
}

static void vhost_vdpa_unmap(struct vhost_vdpa *v,
                             struct vhost_iotlb *iotlb,
                             u64 iova, u64 size)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u32 asid = iotlb_to_asid(iotlb);

        vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1, asid);

        if (ops->set_map) {
                if (!v->in_batch)
                        ops->set_map(vdpa, asid, iotlb);
        }
        /* If we are in the middle of batch processing, delay the free
         * of AS until BATCH_END.
         */
        if (!v->in_batch && !iotlb->nmaps)
                vhost_vdpa_remove_as(v, asid);
}

static int vhost_vdpa_va_map(struct vhost_vdpa *v,
                             struct vhost_iotlb *iotlb,
                             u64 iova, u64 size, u64 uaddr, u32 perm)
{
        struct vhost_dev *dev = &v->vdev;
        u64 offset, map_size, map_iova = iova;
        struct vdpa_map_file *map_file;
        struct vm_area_struct *vma;
        int ret = 0;

        mmap_read_lock(dev->mm);

        while (size) {
                vma = find_vma(dev->mm, uaddr);
                if (!vma) {
                        ret = -EINVAL;
                        break;
                }
                map_size = min(size, vma->vm_end - uaddr);
                if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
                        !(vma->vm_flags & (VM_IO | VM_PFNMAP))))
                        goto next;

                map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
                if (!map_file) {
                        ret = -ENOMEM;
                        break;
                }
                offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
                map_file->offset = offset;
                map_file->file = get_file(vma->vm_file);
                ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
                                     perm, map_file);
                if (ret) {
                        fput(map_file->file);
                        kfree(map_file);
                        break;
                }
next:
                size -= map_size;
                uaddr += map_size;
                map_iova += map_size;
        }
        if (ret)
                vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);

        mmap_read_unlock(dev->mm);

        return ret;
}

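/*
 * Map a range backed by physical addresses: pin the userspace pages
 * (accounted against RLIMIT_MEMLOCK) and map each physically
 * contiguous run of pinned pages with a single vhost_vdpa_map() call.
 */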
static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
                             struct vhost_iotlb *iotlb,
                             u64 iova, u64 size, u64 uaddr, u32 perm)
{
        struct vhost_dev *dev = &v->vdev;
        struct page **page_list;
        unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
        unsigned int gup_flags = FOLL_LONGTERM;
        unsigned long npages, cur_base, map_pfn, last_pfn = 0;
        unsigned long lock_limit, sz2pin, nchunks, i;
        u64 start = iova;
        long pinned;
        int ret = 0;

        /* Limit the use of memory for bookkeeping */
        page_list = (struct page **) __get_free_page(GFP_KERNEL);
        if (!page_list)
                return -ENOMEM;

        if (perm & VHOST_ACCESS_WO)
                gup_flags |= FOLL_WRITE;

        npages = PFN_UP(size + (iova & ~PAGE_MASK));
        if (!npages) {
                ret = -EINVAL;
                goto free;
        }

        mmap_read_lock(dev->mm);

        lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
        if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
                ret = -ENOMEM;
                goto unlock;
        }

        cur_base = uaddr & PAGE_MASK;
        iova &= PAGE_MASK;
        nchunks = 0;

        while (npages) {
                sz2pin = min_t(unsigned long, npages, list_size);
                pinned = pin_user_pages(cur_base, sz2pin,
                                        gup_flags, page_list, NULL);
                if (sz2pin != pinned) {
                        if (pinned < 0) {
                                ret = pinned;
                        } else {
                                unpin_user_pages(page_list, pinned);
                                ret = -ENOMEM;
                        }
                        goto out;
                }
                nchunks++;

                if (!last_pfn)
                        map_pfn = page_to_pfn(page_list[0]);

                for (i = 0; i < pinned; i++) {
                        unsigned long this_pfn = page_to_pfn(page_list[i]);
                        u64 csize;

                        if (last_pfn && (this_pfn != last_pfn + 1)) {
                                /* Map the contiguous chunk of pinned memory */
                                csize = PFN_PHYS(last_pfn - map_pfn + 1);
                                ret = vhost_vdpa_map(v, iotlb, iova, csize,
                                                     PFN_PHYS(map_pfn),
                                                     perm, NULL);
                                if (ret) {
                                        /*
                                         * Unpin the pages that are left unmapped
                                         * from this point on in the current
                                         * page_list. The remaining outstanding
                                         * ones which may stride across several
                                         * chunks will be covered in the common
                                         * error path subsequently.
                                         */
                                        unpin_user_pages(&page_list[i],
                                                         pinned - i);
                                        goto out;
                                }

                                map_pfn = this_pfn;
                                iova += csize;
                                nchunks = 0;
                        }

                        last_pfn = this_pfn;
                }

                cur_base += PFN_PHYS(pinned);
                npages -= pinned;
        }

        /* Map the remaining chunk */
        ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
                             PFN_PHYS(map_pfn), perm, NULL);
out:
        if (ret) {
                if (nchunks) {
                        unsigned long pfn;

                        /*
                         * Unpin the outstanding pages which are yet to be
                         * mapped but haven't due to vdpa_map() or
                         * pin_user_pages() failure.
                         *
                         * Mapped pages are accounted in vdpa_map(), hence
                         * the corresponding unpinning will be handled by
                         * vdpa_unmap().
                         */
                        WARN_ON(!last_pfn);
                        for (pfn = map_pfn; pfn <= last_pfn; pfn++)
                                unpin_user_page(pfn_to_page(pfn));
                }
                vhost_vdpa_unmap(v, iotlb, start, size);
        }
unlock:
        mmap_read_unlock(dev->mm);
free:
        free_page((unsigned long)page_list);
        return ret;
}

static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
                                           struct vhost_iotlb *iotlb,
                                           struct vhost_iotlb_msg *msg)
{
        struct vdpa_device *vdpa = v->vdpa;

        if (msg->iova < v->range.first || !msg->size ||
            msg->iova > U64_MAX - msg->size + 1 ||
            msg->iova + msg->size - 1 > v->range.last)
                return -EINVAL;

        if (vhost_iotlb_itree_first(iotlb, msg->iova,
                                    msg->iova + msg->size - 1))
                return -EEXIST;

        if (vdpa->use_va)
                return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
                                         msg->uaddr, msg->perm);

        return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
                                 msg->perm);
}

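/*
 * Handle IOTLB messages from userspace. The address space is selected
 * by ASID; UPDATE/BATCH_BEGIN may allocate a new address space, and a
 * batch is committed to the device with set_map() at BATCH_END.
 */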
static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
                                        struct vhost_iotlb_msg *msg)
{
        struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vhost_iotlb *iotlb = NULL;
        struct vhost_vdpa_as *as = NULL;
        int r = 0;

        mutex_lock(&dev->mutex);

        r = vhost_dev_check_owner(dev);
        if (r)
                goto unlock;

        if (msg->type == VHOST_IOTLB_UPDATE ||
            msg->type == VHOST_IOTLB_BATCH_BEGIN) {
                as = vhost_vdpa_find_alloc_as(v, asid);
                if (!as) {
                        dev_err(&v->dev, "can't find and alloc asid %d\n",
                                asid);
                        r = -EINVAL;
                        goto unlock;
                }
                iotlb = &as->iotlb;
        } else
                iotlb = asid_to_iotlb(v, asid);

        if ((v->in_batch && v->batch_asid != asid) || !iotlb) {
                if (v->in_batch && v->batch_asid != asid) {
                        dev_info(&v->dev, "batch id %d asid %d\n",
                                 v->batch_asid, asid);
                }
                if (!iotlb)
                        dev_err(&v->dev, "no iotlb for asid %d\n", asid);
                r = -EINVAL;
                goto unlock;
        }

        switch (msg->type) {
        case VHOST_IOTLB_UPDATE:
                r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
                break;
        case VHOST_IOTLB_INVALIDATE:
                vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
                break;
        case VHOST_IOTLB_BATCH_BEGIN:
                v->batch_asid = asid;
                v->in_batch = true;
                break;
        case VHOST_IOTLB_BATCH_END:
                if (v->in_batch && ops->set_map)
                        ops->set_map(vdpa, asid, iotlb);
                v->in_batch = false;
                if (!iotlb->nmaps)
                        vhost_vdpa_remove_as(v, asid);
                break;
        default:
                r = -EINVAL;
                break;
        }
unlock:
        mutex_unlock(&dev->mutex);

        return r;
}

static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
                                         struct iov_iter *from)
{
        struct file *file = iocb->ki_filp;
        struct vhost_vdpa *v = file->private_data;
        struct vhost_dev *dev = &v->vdev;

        return vhost_chr_write_iter(dev, from);
}

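/*
 * If the vDPA device does not provide its own DMA translation
 * (no set_map()/dma_map() ops), allocate and attach a platform IOMMU
 * domain so that vhost_vdpa_map() can program the IOMMU directly.
 */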
static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct device *dma_dev = vdpa_get_dma_dev(vdpa);
        struct bus_type *bus;
        int ret;

        /* Device wants to do DMA by itself */
        if (ops->set_map || ops->dma_map)
                return 0;

        bus = dma_dev->bus;
        if (!bus)
                return -EFAULT;

        if (!device_iommu_capable(dma_dev, IOMMU_CAP_CACHE_COHERENCY)) {
                dev_warn_once(&v->dev,
                              "Failed to allocate domain, device is not IOMMU cache coherent capable\n");
                return -ENOTSUPP;
        }

        v->domain = iommu_domain_alloc(bus);
        if (!v->domain)
                return -EIO;

        ret = iommu_attach_device(v->domain, dma_dev);
        if (ret)
                goto err_attach;

        return 0;

err_attach:
        iommu_domain_free(v->domain);
        v->domain = NULL;
        return ret;
}

static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        struct device *dma_dev = vdpa_get_dma_dev(vdpa);

        if (v->domain) {
                iommu_detach_device(v->domain, dma_dev);
                iommu_domain_free(v->domain);
        }

        v->domain = NULL;
}

static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
{
        struct vdpa_iova_range *range = &v->range;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        if (ops->get_iova_range) {
                *range = ops->get_iova_range(vdpa);
        } else if (v->domain && v->domain->geometry.force_aperture) {
                range->first = v->domain->geometry.aperture_start;
                range->last = v->domain->geometry.aperture_end;
        } else {
                range->first = 0;
                range->last = ULLONG_MAX;
        }
}

static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
{
        struct vhost_vdpa_as *as;
        u32 asid;

        for (asid = 0; asid < v->vdpa->nas; asid++) {
                as = asid_to_as(v, asid);
                if (as)
                        vhost_vdpa_remove_as(v, asid);
        }

        vhost_vdpa_free_domain(v);
        vhost_dev_cleanup(&v->vdev);
        kfree(v->vdev.vqs);
}

static int vhost_vdpa_open(struct inode *inode, struct file *filep)
{
        struct vhost_vdpa *v;
        struct vhost_dev *dev;
        struct vhost_virtqueue **vqs;
        int r, opened;
        u32 i, nvqs;

        v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);

        opened = atomic_cmpxchg(&v->opened, 0, 1);
        if (opened)
                return -EBUSY;

        nvqs = v->nvqs;
        r = vhost_vdpa_reset(v);
        if (r)
                goto err;

        vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                r = -ENOMEM;
                goto err;
        }

        dev = &v->vdev;
        for (i = 0; i < nvqs; i++) {
                vqs[i] = &v->vqs[i];
                vqs[i]->handle_kick = handle_vq_kick;
        }
        vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
                       vhost_vdpa_process_iotlb_msg);

        r = vhost_vdpa_alloc_domain(v);
        if (r)
                goto err_alloc_domain;

        vhost_vdpa_set_iova_range(v);

        filep->private_data = v;

        return 0;

err_alloc_domain:
        vhost_vdpa_cleanup(v);
err:
        atomic_dec(&v->opened);
        return r;
}

static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
{
        u32 i;

        for (i = 0; i < v->nvqs; i++)
                vhost_vdpa_unsetup_vq_irq(v, i);
}

static int vhost_vdpa_release(struct inode *inode, struct file *filep)
{
        struct vhost_vdpa *v = filep->private_data;
        struct vhost_dev *d = &v->vdev;

        mutex_lock(&d->mutex);
        filep->private_data = NULL;
        vhost_vdpa_clean_irq(v);
        vhost_vdpa_reset(v);
        vhost_dev_stop(&v->vdev);
        vhost_vdpa_config_put(v);
        vhost_vdpa_cleanup(v);
        mutex_unlock(&d->mutex);

        atomic_dec(&v->opened);
        complete(&v->completion);

        return 0;
}

#ifdef CONFIG_MMU
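/*
 * Fault handler for the doorbell mapping: remap the virtqueue's
 * notification (doorbell) page into the faulting VMA as uncached I/O
 * memory.
 */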
static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
{
        struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vdpa_notification_area notify;
        struct vm_area_struct *vma = vmf->vma;
        u16 index = vma->vm_pgoff;

        notify = ops->get_vq_notification(vdpa, index);

        vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
        if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
                            PFN_DOWN(notify.addr), PAGE_SIZE,
                            vma->vm_page_prot))
                return VM_FAULT_SIGBUS;

        return VM_FAULT_NOPAGE;
}

static const struct vm_operations_struct vhost_vdpa_vm_ops = {
        .fault = vhost_vdpa_fault,
};

static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
{
        struct vhost_vdpa *v = vma->vm_file->private_data;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vdpa_notification_area notify;
        unsigned long index = vma->vm_pgoff;

        if (vma->vm_end - vma->vm_start != PAGE_SIZE)
                return -EINVAL;
        if ((vma->vm_flags & VM_SHARED) == 0)
                return -EINVAL;
        if (vma->vm_flags & VM_READ)
                return -EINVAL;
        if (index > 65535)
                return -EINVAL;
        if (!ops->get_vq_notification)
                return -ENOTSUPP;

        /* To be safe and easily modelled by userspace, we only
         * support a doorbell that sits on a page boundary and
         * does not share its page with other registers.
         */
        notify = ops->get_vq_notification(vdpa, index);
        if (notify.addr & (PAGE_SIZE - 1))
                return -EINVAL;
        if (vma->vm_end - vma->vm_start != notify.size)
                return -ENOTSUPP;

        vm_flags_set(vma, VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP);
        vma->vm_ops = &vhost_vdpa_vm_ops;
        return 0;
}
#endif /* CONFIG_MMU */

static const struct file_operations vhost_vdpa_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vdpa_open,
        .release        = vhost_vdpa_release,
        .write_iter     = vhost_vdpa_chr_write_iter,
        .unlocked_ioctl = vhost_vdpa_unlocked_ioctl,
#ifdef CONFIG_MMU
        .mmap           = vhost_vdpa_mmap,
#endif /* CONFIG_MMU */
        .compat_ioctl   = compat_ptr_ioctl,
};

static void vhost_vdpa_release_dev(struct device *device)
{
        struct vhost_vdpa *v =
               container_of(device, struct vhost_vdpa, dev);

        ida_simple_remove(&vhost_vdpa_ida, v->minor);
        kfree(v->vqs);
        kfree(v);
}

static int vhost_vdpa_probe(struct vdpa_device *vdpa)
{
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vhost_vdpa *v;
        int minor;
        int i, r;

        /* We can't support a platform IOMMU device with more than
         * one group or address space.
         */
        if (!ops->set_map && !ops->dma_map &&
            (vdpa->ngroups > 1 || vdpa->nas > 1))
                return -EOPNOTSUPP;

        v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        if (!v)
                return -ENOMEM;

        minor = ida_simple_get(&vhost_vdpa_ida, 0,
                               VHOST_VDPA_DEV_MAX, GFP_KERNEL);
        if (minor < 0) {
                kfree(v);
                return minor;
        }

        atomic_set(&v->opened, 0);
        v->minor = minor;
        v->vdpa = vdpa;
        v->nvqs = vdpa->nvqs;
        v->virtio_id = ops->get_device_id(vdpa);

        device_initialize(&v->dev);
        v->dev.release = vhost_vdpa_release_dev;
        v->dev.parent = &vdpa->dev;
        v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
        v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
                               GFP_KERNEL);
        if (!v->vqs) {
                r = -ENOMEM;
                goto err;
        }

        r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
        if (r)
                goto err;

        cdev_init(&v->cdev, &vhost_vdpa_fops);
        v->cdev.owner = THIS_MODULE;

        r = cdev_device_add(&v->cdev, &v->dev);
        if (r)
                goto err;

        init_completion(&v->completion);
        vdpa_set_drvdata(vdpa, v);

        for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
                INIT_HLIST_HEAD(&v->as[i]);

        return 0;

err:
        put_device(&v->dev);
        ida_simple_remove(&vhost_vdpa_ida, v->minor);
        return r;
}

static void vhost_vdpa_remove(struct vdpa_device *vdpa)
{
        struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
        int opened;

        cdev_device_del(&v->cdev, &v->dev);

        do {
                opened = atomic_cmpxchg(&v->opened, 0, 1);
                if (!opened)
                        break;
                wait_for_completion(&v->completion);
        } while (1);

        put_device(&v->dev);
}

static struct vdpa_driver vhost_vdpa_driver = {
        .driver = {
                .name   = "vhost_vdpa",
        },
        .probe  = vhost_vdpa_probe,
        .remove = vhost_vdpa_remove,
};

static int __init vhost_vdpa_init(void)
{
        int r;

        r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
                                "vhost-vdpa");
        if (r)
                goto err_alloc_chrdev;

        r = vdpa_register_driver(&vhost_vdpa_driver);
        if (r)
                goto err_vdpa_register_driver;

        return 0;

err_vdpa_register_driver:
        unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
err_alloc_chrdev:
        return r;
}
module_init(vhost_vdpa_init);

static void __exit vhost_vdpa_exit(void)
{
        vdpa_unregister_driver(&vhost_vdpa_driver);
        unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
}
module_exit(vhost_vdpa_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Intel Corporation");
MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");