Merge tag 'armsoc-soc' of git://git.kernel.org/pub/scm/linux/kernel/git/arm/arm-soc
[linux-2.6-block.git] / drivers / vfio / vfio_iommu_type1.c
1 /*
2  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
3  *
4  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
5  *     Author: Alex Williamson <alex.williamson@redhat.com>
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License version 2 as
9  * published by the Free Software Foundation.
10  *
11  * Derived from original vfio:
12  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
13  * Author: Tom Lyon, pugs@cisco.com
14  *
15  * We arbitrarily define a Type1 IOMMU as one matching the below code.
16  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
17  * VT-d, but that makes it harder to re-use as theoretically anyone
18  * implementing a similar IOMMU could make use of this.  We expect the
19  * IOMMU to support the IOMMU API and have few to no restrictions around
20  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
21  * optimized for relatively static mappings of a userspace process with
22  * userpsace pages pinned into memory.  We also assume devices and IOMMU
23  * domains are PCI based as the IOMMU API is still centered around a
24  * device/bus interface rather than a group interface.
25  */
26
27 #include <linux/compat.h>
28 #include <linux/device.h>
29 #include <linux/fs.h>
30 #include <linux/iommu.h>
31 #include <linux/module.h>
32 #include <linux/mm.h>
33 #include <linux/rbtree.h>
34 #include <linux/sched/signal.h>
35 #include <linux/sched/mm.h>
36 #include <linux/slab.h>
37 #include <linux/uaccess.h>
38 #include <linux/vfio.h>
39 #include <linux/workqueue.h>
40 #include <linux/mdev.h>
41 #include <linux/notifier.h>
42 #include <linux/dma-iommu.h>
43 #include <linux/irqdomain.h>
44
45 #define DRIVER_VERSION  "0.2"
46 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
47 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
48
49 static bool allow_unsafe_interrupts;
50 module_param_named(allow_unsafe_interrupts,
51                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
52 MODULE_PARM_DESC(allow_unsafe_interrupts,
53                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
54
55 static bool disable_hugepages;
56 module_param_named(disable_hugepages,
57                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
58 MODULE_PARM_DESC(disable_hugepages,
59                  "Disable VFIO IOMMU support for IOMMU hugepages.");
60
61 struct vfio_iommu {
62         struct list_head        domain_list;
63         struct vfio_domain      *external_domain; /* domain for external user */
64         struct mutex            lock;
65         struct rb_root          dma_list;
66         struct blocking_notifier_head notifier;
67         bool                    v2;
68         bool                    nesting;
69 };
70
71 struct vfio_domain {
72         struct iommu_domain     *domain;
73         struct list_head        next;
74         struct list_head        group_list;
75         int                     prot;           /* IOMMU_CACHE */
76         bool                    fgsp;           /* Fine-grained super pages */
77 };
78
79 struct vfio_dma {
80         struct rb_node          node;
81         dma_addr_t              iova;           /* Device address */
82         unsigned long           vaddr;          /* Process virtual addr */
83         size_t                  size;           /* Map size (bytes) */
84         int                     prot;           /* IOMMU_READ/WRITE */
85         bool                    iommu_mapped;
86         struct task_struct      *task;
87         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
88 };
89
90 struct vfio_group {
91         struct iommu_group      *iommu_group;
92         struct list_head        next;
93 };
94
95 /*
96  * Guest RAM pinning working set or DMA target
97  */
98 struct vfio_pfn {
99         struct rb_node          node;
100         dma_addr_t              iova;           /* Device address */
101         unsigned long           pfn;            /* Host pfn */
102         atomic_t                ref_count;
103 };
104
105 struct vfio_regions {
106         struct list_head list;
107         dma_addr_t iova;
108         phys_addr_t phys;
109         size_t len;
110 };
111
112 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
113                                         (!list_empty(&iommu->domain_list))
114
115 static int put_pfn(unsigned long pfn, int prot);
116
117 /*
118  * This code handles mapping and unmapping of user data buffers
119  * into DMA'ble space using the IOMMU
120  */
121
122 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
123                                       dma_addr_t start, size_t size)
124 {
125         struct rb_node *node = iommu->dma_list.rb_node;
126
127         while (node) {
128                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
129
130                 if (start + size <= dma->iova)
131                         node = node->rb_left;
132                 else if (start >= dma->iova + dma->size)
133                         node = node->rb_right;
134                 else
135                         return dma;
136         }
137
138         return NULL;
139 }
140
141 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
142 {
143         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
144         struct vfio_dma *dma;
145
146         while (*link) {
147                 parent = *link;
148                 dma = rb_entry(parent, struct vfio_dma, node);
149
150                 if (new->iova + new->size <= dma->iova)
151                         link = &(*link)->rb_left;
152                 else
153                         link = &(*link)->rb_right;
154         }
155
156         rb_link_node(&new->node, parent, link);
157         rb_insert_color(&new->node, &iommu->dma_list);
158 }
159
160 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
161 {
162         rb_erase(&old->node, &iommu->dma_list);
163 }
164
165 /*
166  * Helper Functions for host iova-pfn list
167  */
168 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
169 {
170         struct vfio_pfn *vpfn;
171         struct rb_node *node = dma->pfn_list.rb_node;
172
173         while (node) {
174                 vpfn = rb_entry(node, struct vfio_pfn, node);
175
176                 if (iova < vpfn->iova)
177                         node = node->rb_left;
178                 else if (iova > vpfn->iova)
179                         node = node->rb_right;
180                 else
181                         return vpfn;
182         }
183         return NULL;
184 }
185
186 static void vfio_link_pfn(struct vfio_dma *dma,
187                           struct vfio_pfn *new)
188 {
189         struct rb_node **link, *parent = NULL;
190         struct vfio_pfn *vpfn;
191
192         link = &dma->pfn_list.rb_node;
193         while (*link) {
194                 parent = *link;
195                 vpfn = rb_entry(parent, struct vfio_pfn, node);
196
197                 if (new->iova < vpfn->iova)
198                         link = &(*link)->rb_left;
199                 else
200                         link = &(*link)->rb_right;
201         }
202
203         rb_link_node(&new->node, parent, link);
204         rb_insert_color(&new->node, &dma->pfn_list);
205 }
206
207 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
208 {
209         rb_erase(&old->node, &dma->pfn_list);
210 }
211
212 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
213                                 unsigned long pfn)
214 {
215         struct vfio_pfn *vpfn;
216
217         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
218         if (!vpfn)
219                 return -ENOMEM;
220
221         vpfn->iova = iova;
222         vpfn->pfn = pfn;
223         atomic_set(&vpfn->ref_count, 1);
224         vfio_link_pfn(dma, vpfn);
225         return 0;
226 }
227
228 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
229                                       struct vfio_pfn *vpfn)
230 {
231         vfio_unlink_pfn(dma, vpfn);
232         kfree(vpfn);
233 }
234
235 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
236                                                unsigned long iova)
237 {
238         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
239
240         if (vpfn)
241                 atomic_inc(&vpfn->ref_count);
242         return vpfn;
243 }
244
245 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
246 {
247         int ret = 0;
248
249         if (atomic_dec_and_test(&vpfn->ref_count)) {
250                 ret = put_pfn(vpfn->pfn, dma->prot);
251                 vfio_remove_from_pfn_list(dma, vpfn);
252         }
253         return ret;
254 }
255
256 static int vfio_lock_acct(struct task_struct *task, long npage, bool *lock_cap)
257 {
258         struct mm_struct *mm;
259         bool is_current;
260         int ret;
261
262         if (!npage)
263                 return 0;
264
265         is_current = (task->mm == current->mm);
266
267         mm = is_current ? task->mm : get_task_mm(task);
268         if (!mm)
269                 return -ESRCH; /* process exited */
270
271         ret = down_write_killable(&mm->mmap_sem);
272         if (!ret) {
273                 if (npage > 0) {
274                         if (lock_cap ? !*lock_cap :
275                             !has_capability(task, CAP_IPC_LOCK)) {
276                                 unsigned long limit;
277
278                                 limit = task_rlimit(task,
279                                                 RLIMIT_MEMLOCK) >> PAGE_SHIFT;
280
281                                 if (mm->locked_vm + npage > limit)
282                                         ret = -ENOMEM;
283                         }
284                 }
285
286                 if (!ret)
287                         mm->locked_vm += npage;
288
289                 up_write(&mm->mmap_sem);
290         }
291
292         if (!is_current)
293                 mmput(mm);
294
295         return ret;
296 }
297
298 /*
299  * Some mappings aren't backed by a struct page, for example an mmap'd
300  * MMIO range for our own or another device.  These use a different
301  * pfn conversion and shouldn't be tracked as locked pages.
302  */
303 static bool is_invalid_reserved_pfn(unsigned long pfn)
304 {
305         if (pfn_valid(pfn)) {
306                 bool reserved;
307                 struct page *tail = pfn_to_page(pfn);
308                 struct page *head = compound_head(tail);
309                 reserved = !!(PageReserved(head));
310                 if (head != tail) {
311                         /*
312                          * "head" is not a dangling pointer
313                          * (compound_head takes care of that)
314                          * but the hugepage may have been split
315                          * from under us (and we may not hold a
316                          * reference count on the head page so it can
317                          * be reused before we run PageReferenced), so
318                          * we've to check PageTail before returning
319                          * what we just read.
320                          */
321                         smp_rmb();
322                         if (PageTail(tail))
323                                 return reserved;
324                 }
325                 return PageReserved(tail);
326         }
327
328         return true;
329 }
330
331 static int put_pfn(unsigned long pfn, int prot)
332 {
333         if (!is_invalid_reserved_pfn(pfn)) {
334                 struct page *page = pfn_to_page(pfn);
335                 if (prot & IOMMU_WRITE)
336                         SetPageDirty(page);
337                 put_page(page);
338                 return 1;
339         }
340         return 0;
341 }
342
343 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
344                          int prot, unsigned long *pfn)
345 {
346         struct page *page[1];
347         struct vm_area_struct *vma;
348         struct vm_area_struct *vmas[1];
349         int ret;
350
351         if (mm == current->mm) {
352                 ret = get_user_pages_longterm(vaddr, 1, !!(prot & IOMMU_WRITE),
353                                               page, vmas);
354         } else {
355                 unsigned int flags = 0;
356
357                 if (prot & IOMMU_WRITE)
358                         flags |= FOLL_WRITE;
359
360                 down_read(&mm->mmap_sem);
361                 ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
362                                             vmas, NULL);
363                 /*
364                  * The lifetime of a vaddr_get_pfn() page pin is
365                  * userspace-controlled. In the fs-dax case this could
366                  * lead to indefinite stalls in filesystem operations.
367                  * Disallow attempts to pin fs-dax pages via this
368                  * interface.
369                  */
370                 if (ret > 0 && vma_is_fsdax(vmas[0])) {
371                         ret = -EOPNOTSUPP;
372                         put_page(page[0]);
373                 }
374                 up_read(&mm->mmap_sem);
375         }
376
377         if (ret == 1) {
378                 *pfn = page_to_pfn(page[0]);
379                 return 0;
380         }
381
382         down_read(&mm->mmap_sem);
383
384         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
385
386         if (vma && vma->vm_flags & VM_PFNMAP) {
387                 *pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) + vma->vm_pgoff;
388                 if (is_invalid_reserved_pfn(*pfn))
389                         ret = 0;
390         }
391
392         up_read(&mm->mmap_sem);
393         return ret;
394 }
395
396 /*
397  * Attempt to pin pages.  We really don't want to track all the pfns and
398  * the iommu can only map chunks of consecutive pfns anyway, so get the
399  * first page and all consecutive pages with the same locking.
400  */
401 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
402                                   long npage, unsigned long *pfn_base,
403                                   bool lock_cap, unsigned long limit)
404 {
405         unsigned long pfn = 0;
406         long ret, pinned = 0, lock_acct = 0;
407         bool rsvd;
408         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
409
410         /* This code path is only user initiated */
411         if (!current->mm)
412                 return -ENODEV;
413
414         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
415         if (ret)
416                 return ret;
417
418         pinned++;
419         rsvd = is_invalid_reserved_pfn(*pfn_base);
420
421         /*
422          * Reserved pages aren't counted against the user, externally pinned
423          * pages are already counted against the user.
424          */
425         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
426                 if (!lock_cap && current->mm->locked_vm + 1 > limit) {
427                         put_pfn(*pfn_base, dma->prot);
428                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
429                                         limit << PAGE_SHIFT);
430                         return -ENOMEM;
431                 }
432                 lock_acct++;
433         }
434
435         if (unlikely(disable_hugepages))
436                 goto out;
437
438         /* Lock all the consecutive pages from pfn_base */
439         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
440              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
441                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
442                 if (ret)
443                         break;
444
445                 if (pfn != *pfn_base + pinned ||
446                     rsvd != is_invalid_reserved_pfn(pfn)) {
447                         put_pfn(pfn, dma->prot);
448                         break;
449                 }
450
451                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
452                         if (!lock_cap &&
453                             current->mm->locked_vm + lock_acct + 1 > limit) {
454                                 put_pfn(pfn, dma->prot);
455                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
456                                         __func__, limit << PAGE_SHIFT);
457                                 ret = -ENOMEM;
458                                 goto unpin_out;
459                         }
460                         lock_acct++;
461                 }
462         }
463
464 out:
465         ret = vfio_lock_acct(current, lock_acct, &lock_cap);
466
467 unpin_out:
468         if (ret) {
469                 if (!rsvd) {
470                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
471                                 put_pfn(pfn, dma->prot);
472                 }
473
474                 return ret;
475         }
476
477         return pinned;
478 }
479
480 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
481                                     unsigned long pfn, long npage,
482                                     bool do_accounting)
483 {
484         long unlocked = 0, locked = 0;
485         long i;
486
487         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
488                 if (put_pfn(pfn++, dma->prot)) {
489                         unlocked++;
490                         if (vfio_find_vpfn(dma, iova))
491                                 locked++;
492                 }
493         }
494
495         if (do_accounting)
496                 vfio_lock_acct(dma->task, locked - unlocked, NULL);
497
498         return unlocked;
499 }
500
501 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
502                                   unsigned long *pfn_base, bool do_accounting)
503 {
504         struct mm_struct *mm;
505         int ret;
506
507         mm = get_task_mm(dma->task);
508         if (!mm)
509                 return -ENODEV;
510
511         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
512         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
513                 ret = vfio_lock_acct(dma->task, 1, NULL);
514                 if (ret) {
515                         put_pfn(*pfn_base, dma->prot);
516                         if (ret == -ENOMEM)
517                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
518                                         "(%ld) exceeded\n", __func__,
519                                         dma->task->comm, task_pid_nr(dma->task),
520                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
521                 }
522         }
523
524         mmput(mm);
525         return ret;
526 }
527
528 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
529                                     bool do_accounting)
530 {
531         int unlocked;
532         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
533
534         if (!vpfn)
535                 return 0;
536
537         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
538
539         if (do_accounting)
540                 vfio_lock_acct(dma->task, -unlocked, NULL);
541
542         return unlocked;
543 }
544
545 static int vfio_iommu_type1_pin_pages(void *iommu_data,
546                                       unsigned long *user_pfn,
547                                       int npage, int prot,
548                                       unsigned long *phys_pfn)
549 {
550         struct vfio_iommu *iommu = iommu_data;
551         int i, j, ret;
552         unsigned long remote_vaddr;
553         struct vfio_dma *dma;
554         bool do_accounting;
555
556         if (!iommu || !user_pfn || !phys_pfn)
557                 return -EINVAL;
558
559         /* Supported for v2 version only */
560         if (!iommu->v2)
561                 return -EACCES;
562
563         mutex_lock(&iommu->lock);
564
565         /* Fail if notifier list is empty */
566         if ((!iommu->external_domain) || (!iommu->notifier.head)) {
567                 ret = -EINVAL;
568                 goto pin_done;
569         }
570
571         /*
572          * If iommu capable domain exist in the container then all pages are
573          * already pinned and accounted. Accouting should be done if there is no
574          * iommu capable domain in the container.
575          */
576         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
577
578         for (i = 0; i < npage; i++) {
579                 dma_addr_t iova;
580                 struct vfio_pfn *vpfn;
581
582                 iova = user_pfn[i] << PAGE_SHIFT;
583                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
584                 if (!dma) {
585                         ret = -EINVAL;
586                         goto pin_unwind;
587                 }
588
589                 if ((dma->prot & prot) != prot) {
590                         ret = -EPERM;
591                         goto pin_unwind;
592                 }
593
594                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
595                 if (vpfn) {
596                         phys_pfn[i] = vpfn->pfn;
597                         continue;
598                 }
599
600                 remote_vaddr = dma->vaddr + iova - dma->iova;
601                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
602                                              do_accounting);
603                 if (ret)
604                         goto pin_unwind;
605
606                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
607                 if (ret) {
608                         vfio_unpin_page_external(dma, iova, do_accounting);
609                         goto pin_unwind;
610                 }
611         }
612
613         ret = i;
614         goto pin_done;
615
616 pin_unwind:
617         phys_pfn[i] = 0;
618         for (j = 0; j < i; j++) {
619                 dma_addr_t iova;
620
621                 iova = user_pfn[j] << PAGE_SHIFT;
622                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
623                 vfio_unpin_page_external(dma, iova, do_accounting);
624                 phys_pfn[j] = 0;
625         }
626 pin_done:
627         mutex_unlock(&iommu->lock);
628         return ret;
629 }
630
631 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
632                                         unsigned long *user_pfn,
633                                         int npage)
634 {
635         struct vfio_iommu *iommu = iommu_data;
636         bool do_accounting;
637         int i;
638
639         if (!iommu || !user_pfn)
640                 return -EINVAL;
641
642         /* Supported for v2 version only */
643         if (!iommu->v2)
644                 return -EACCES;
645
646         mutex_lock(&iommu->lock);
647
648         if (!iommu->external_domain) {
649                 mutex_unlock(&iommu->lock);
650                 return -EINVAL;
651         }
652
653         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
654         for (i = 0; i < npage; i++) {
655                 struct vfio_dma *dma;
656                 dma_addr_t iova;
657
658                 iova = user_pfn[i] << PAGE_SHIFT;
659                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
660                 if (!dma)
661                         goto unpin_exit;
662                 vfio_unpin_page_external(dma, iova, do_accounting);
663         }
664
665 unpin_exit:
666         mutex_unlock(&iommu->lock);
667         return i > npage ? npage : (i > 0 ? i : -EINVAL);
668 }
669
670 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
671                                 struct list_head *regions)
672 {
673         long unlocked = 0;
674         struct vfio_regions *entry, *next;
675
676         iommu_tlb_sync(domain->domain);
677
678         list_for_each_entry_safe(entry, next, regions, list) {
679                 unlocked += vfio_unpin_pages_remote(dma,
680                                                     entry->iova,
681                                                     entry->phys >> PAGE_SHIFT,
682                                                     entry->len >> PAGE_SHIFT,
683                                                     false);
684                 list_del(&entry->list);
685                 kfree(entry);
686         }
687
688         cond_resched();
689
690         return unlocked;
691 }
692
693 /*
694  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
695  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
696  * of these regions (currently using a list).
697  *
698  * This value specifies maximum number of regions for each IOTLB flush sync.
699  */
700 #define VFIO_IOMMU_TLB_SYNC_MAX         512
701
702 static size_t unmap_unpin_fast(struct vfio_domain *domain,
703                                struct vfio_dma *dma, dma_addr_t *iova,
704                                size_t len, phys_addr_t phys, long *unlocked,
705                                struct list_head *unmapped_list,
706                                int *unmapped_cnt)
707 {
708         size_t unmapped = 0;
709         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
710
711         if (entry) {
712                 unmapped = iommu_unmap_fast(domain->domain, *iova, len);
713
714                 if (!unmapped) {
715                         kfree(entry);
716                 } else {
717                         iommu_tlb_range_add(domain->domain, *iova, unmapped);
718                         entry->iova = *iova;
719                         entry->phys = phys;
720                         entry->len  = unmapped;
721                         list_add_tail(&entry->list, unmapped_list);
722
723                         *iova += unmapped;
724                         (*unmapped_cnt)++;
725                 }
726         }
727
728         /*
729          * Sync if the number of fast-unmap regions hits the limit
730          * or in case of errors.
731          */
732         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
733                 *unlocked += vfio_sync_unpin(dma, domain,
734                                              unmapped_list);
735                 *unmapped_cnt = 0;
736         }
737
738         return unmapped;
739 }
740
741 static size_t unmap_unpin_slow(struct vfio_domain *domain,
742                                struct vfio_dma *dma, dma_addr_t *iova,
743                                size_t len, phys_addr_t phys,
744                                long *unlocked)
745 {
746         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
747
748         if (unmapped) {
749                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
750                                                      phys >> PAGE_SHIFT,
751                                                      unmapped >> PAGE_SHIFT,
752                                                      false);
753                 *iova += unmapped;
754                 cond_resched();
755         }
756         return unmapped;
757 }
758
759 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
760                              bool do_accounting)
761 {
762         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
763         struct vfio_domain *domain, *d;
764         LIST_HEAD(unmapped_region_list);
765         int unmapped_region_cnt = 0;
766         long unlocked = 0;
767
768         if (!dma->size)
769                 return 0;
770
771         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
772                 return 0;
773
774         /*
775          * We use the IOMMU to track the physical addresses, otherwise we'd
776          * need a much more complicated tracking system.  Unfortunately that
777          * means we need to use one of the iommu domains to figure out the
778          * pfns to unpin.  The rest need to be unmapped in advance so we have
779          * no iommu translations remaining when the pages are unpinned.
780          */
781         domain = d = list_first_entry(&iommu->domain_list,
782                                       struct vfio_domain, next);
783
784         list_for_each_entry_continue(d, &iommu->domain_list, next) {
785                 iommu_unmap(d->domain, dma->iova, dma->size);
786                 cond_resched();
787         }
788
789         while (iova < end) {
790                 size_t unmapped, len;
791                 phys_addr_t phys, next;
792
793                 phys = iommu_iova_to_phys(domain->domain, iova);
794                 if (WARN_ON(!phys)) {
795                         iova += PAGE_SIZE;
796                         continue;
797                 }
798
799                 /*
800                  * To optimize for fewer iommu_unmap() calls, each of which
801                  * may require hardware cache flushing, try to find the
802                  * largest contiguous physical memory chunk to unmap.
803                  */
804                 for (len = PAGE_SIZE;
805                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
806                         next = iommu_iova_to_phys(domain->domain, iova + len);
807                         if (next != phys + len)
808                                 break;
809                 }
810
811                 /*
812                  * First, try to use fast unmap/unpin. In case of failure,
813                  * switch to slow unmap/unpin path.
814                  */
815                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
816                                             &unlocked, &unmapped_region_list,
817                                             &unmapped_region_cnt);
818                 if (!unmapped) {
819                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
820                                                     phys, &unlocked);
821                         if (WARN_ON(!unmapped))
822                                 break;
823                 }
824         }
825
826         dma->iommu_mapped = false;
827
828         if (unmapped_region_cnt)
829                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list);
830
831         if (do_accounting) {
832                 vfio_lock_acct(dma->task, -unlocked, NULL);
833                 return 0;
834         }
835         return unlocked;
836 }
837
838 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
839 {
840         vfio_unmap_unpin(iommu, dma, true);
841         vfio_unlink_dma(iommu, dma);
842         put_task_struct(dma->task);
843         kfree(dma);
844 }
845
846 static unsigned long vfio_pgsize_bitmap(struct vfio_iommu *iommu)
847 {
848         struct vfio_domain *domain;
849         unsigned long bitmap = ULONG_MAX;
850
851         mutex_lock(&iommu->lock);
852         list_for_each_entry(domain, &iommu->domain_list, next)
853                 bitmap &= domain->domain->pgsize_bitmap;
854         mutex_unlock(&iommu->lock);
855
856         /*
857          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
858          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
859          * That way the user will be able to map/unmap buffers whose size/
860          * start address is aligned with PAGE_SIZE. Pinning code uses that
861          * granularity while iommu driver can use the sub-PAGE_SIZE size
862          * to map the buffer.
863          */
864         if (bitmap & ~PAGE_MASK) {
865                 bitmap &= PAGE_MASK;
866                 bitmap |= PAGE_SIZE;
867         }
868
869         return bitmap;
870 }
871
872 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
873                              struct vfio_iommu_type1_dma_unmap *unmap)
874 {
875         uint64_t mask;
876         struct vfio_dma *dma, *dma_last = NULL;
877         size_t unmapped = 0;
878         int ret = 0, retries = 0;
879
880         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
881
882         if (unmap->iova & mask)
883                 return -EINVAL;
884         if (!unmap->size || unmap->size & mask)
885                 return -EINVAL;
886         if (unmap->iova + unmap->size < unmap->iova ||
887             unmap->size > SIZE_MAX)
888                 return -EINVAL;
889
890         WARN_ON(mask & PAGE_MASK);
891 again:
892         mutex_lock(&iommu->lock);
893
894         /*
895          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
896          * avoid tracking individual mappings.  This means that the granularity
897          * of the original mapping was lost and the user was allowed to attempt
898          * to unmap any range.  Depending on the contiguousness of physical
899          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
900          * or may not have worked.  We only guaranteed unmap granularity
901          * matching the original mapping; even though it was untracked here,
902          * the original mappings are reflected in IOMMU mappings.  This
903          * resulted in a couple unusual behaviors.  First, if a range is not
904          * able to be unmapped, ex. a set of 4k pages that was mapped as a
905          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
906          * a zero sized unmap.  Also, if an unmap request overlaps the first
907          * address of a hugepage, the IOMMU will unmap the entire hugepage.
908          * This also returns success and the returned unmap size reflects the
909          * actual size unmapped.
910          *
911          * We attempt to maintain compatibility with this "v1" interface, but
912          * we take control out of the hands of the IOMMU.  Therefore, an unmap
913          * request offset from the beginning of the original mapping will
914          * return success with zero sized unmap.  And an unmap request covering
915          * the first iova of mapping will unmap the entire range.
916          *
917          * The v2 version of this interface intends to be more deterministic.
918          * Unmap requests must fully cover previous mappings.  Multiple
919          * mappings may still be unmaped by specifying large ranges, but there
920          * must not be any previous mappings bisected by the range.  An error
921          * will be returned if these conditions are not met.  The v2 interface
922          * will only return success and a size of zero if there were no
923          * mappings within the range.
924          */
925         if (iommu->v2) {
926                 dma = vfio_find_dma(iommu, unmap->iova, 1);
927                 if (dma && dma->iova != unmap->iova) {
928                         ret = -EINVAL;
929                         goto unlock;
930                 }
931                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
932                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
933                         ret = -EINVAL;
934                         goto unlock;
935                 }
936         }
937
938         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
939                 if (!iommu->v2 && unmap->iova > dma->iova)
940                         break;
941                 /*
942                  * Task with same address space who mapped this iova range is
943                  * allowed to unmap the iova range.
944                  */
945                 if (dma->task->mm != current->mm)
946                         break;
947
948                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
949                         struct vfio_iommu_type1_dma_unmap nb_unmap;
950
951                         if (dma_last == dma) {
952                                 BUG_ON(++retries > 10);
953                         } else {
954                                 dma_last = dma;
955                                 retries = 0;
956                         }
957
958                         nb_unmap.iova = dma->iova;
959                         nb_unmap.size = dma->size;
960
961                         /*
962                          * Notify anyone (mdev vendor drivers) to invalidate and
963                          * unmap iovas within the range we're about to unmap.
964                          * Vendor drivers MUST unpin pages in response to an
965                          * invalidation.
966                          */
967                         mutex_unlock(&iommu->lock);
968                         blocking_notifier_call_chain(&iommu->notifier,
969                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
970                                                     &nb_unmap);
971                         goto again;
972                 }
973                 unmapped += dma->size;
974                 vfio_remove_dma(iommu, dma);
975         }
976
977 unlock:
978         mutex_unlock(&iommu->lock);
979
980         /* Report how much was unmapped */
981         unmap->size = unmapped;
982
983         return ret;
984 }
985
986 /*
987  * Turns out AMD IOMMU has a page table bug where it won't map large pages
988  * to a region that previously mapped smaller pages.  This should be fixed
989  * soon, so this is just a temporary workaround to break mappings down into
990  * PAGE_SIZE.  Better to map smaller pages than nothing.
991  */
992 static int map_try_harder(struct vfio_domain *domain, dma_addr_t iova,
993                           unsigned long pfn, long npage, int prot)
994 {
995         long i;
996         int ret = 0;
997
998         for (i = 0; i < npage; i++, pfn++, iova += PAGE_SIZE) {
999                 ret = iommu_map(domain->domain, iova,
1000                                 (phys_addr_t)pfn << PAGE_SHIFT,
1001                                 PAGE_SIZE, prot | domain->prot);
1002                 if (ret)
1003                         break;
1004         }
1005
1006         for (; i < npage && i > 0; i--, iova -= PAGE_SIZE)
1007                 iommu_unmap(domain->domain, iova, PAGE_SIZE);
1008
1009         return ret;
1010 }
1011
1012 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1013                           unsigned long pfn, long npage, int prot)
1014 {
1015         struct vfio_domain *d;
1016         int ret;
1017
1018         list_for_each_entry(d, &iommu->domain_list, next) {
1019                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1020                                 npage << PAGE_SHIFT, prot | d->prot);
1021                 if (ret) {
1022                         if (ret != -EBUSY ||
1023                             map_try_harder(d, iova, pfn, npage, prot))
1024                                 goto unwind;
1025                 }
1026
1027                 cond_resched();
1028         }
1029
1030         return 0;
1031
1032 unwind:
1033         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
1034                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1035
1036         return ret;
1037 }
1038
1039 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1040                             size_t map_size)
1041 {
1042         dma_addr_t iova = dma->iova;
1043         unsigned long vaddr = dma->vaddr;
1044         size_t size = map_size;
1045         long npage;
1046         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1047         bool lock_cap = capable(CAP_IPC_LOCK);
1048         int ret = 0;
1049
1050         while (size) {
1051                 /* Pin a contiguous chunk of memory */
1052                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1053                                               size >> PAGE_SHIFT, &pfn,
1054                                               lock_cap, limit);
1055                 if (npage <= 0) {
1056                         WARN_ON(!npage);
1057                         ret = (int)npage;
1058                         break;
1059                 }
1060
1061                 /* Map it! */
1062                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1063                                      dma->prot);
1064                 if (ret) {
1065                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1066                                                 npage, true);
1067                         break;
1068                 }
1069
1070                 size -= npage << PAGE_SHIFT;
1071                 dma->size += npage << PAGE_SHIFT;
1072         }
1073
1074         dma->iommu_mapped = true;
1075
1076         if (ret)
1077                 vfio_remove_dma(iommu, dma);
1078
1079         return ret;
1080 }
1081
1082 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1083                            struct vfio_iommu_type1_dma_map *map)
1084 {
1085         dma_addr_t iova = map->iova;
1086         unsigned long vaddr = map->vaddr;
1087         size_t size = map->size;
1088         int ret = 0, prot = 0;
1089         uint64_t mask;
1090         struct vfio_dma *dma;
1091
1092         /* Verify that none of our __u64 fields overflow */
1093         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1094                 return -EINVAL;
1095
1096         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
1097
1098         WARN_ON(mask & PAGE_MASK);
1099
1100         /* READ/WRITE from device perspective */
1101         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1102                 prot |= IOMMU_WRITE;
1103         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1104                 prot |= IOMMU_READ;
1105
1106         if (!prot || !size || (size | iova | vaddr) & mask)
1107                 return -EINVAL;
1108
1109         /* Don't allow IOVA or virtual address wrap */
1110         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr)
1111                 return -EINVAL;
1112
1113         mutex_lock(&iommu->lock);
1114
1115         if (vfio_find_dma(iommu, iova, size)) {
1116                 ret = -EEXIST;
1117                 goto out_unlock;
1118         }
1119
1120         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1121         if (!dma) {
1122                 ret = -ENOMEM;
1123                 goto out_unlock;
1124         }
1125
1126         dma->iova = iova;
1127         dma->vaddr = vaddr;
1128         dma->prot = prot;
1129         get_task_struct(current);
1130         dma->task = current;
1131         dma->pfn_list = RB_ROOT;
1132
1133         /* Insert zero-sized and grow as we map chunks of it */
1134         vfio_link_dma(iommu, dma);
1135
1136         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1137         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1138                 dma->size = size;
1139         else
1140                 ret = vfio_pin_map_dma(iommu, dma, size);
1141
1142 out_unlock:
1143         mutex_unlock(&iommu->lock);
1144         return ret;
1145 }
1146
1147 static int vfio_bus_type(struct device *dev, void *data)
1148 {
1149         struct bus_type **bus = data;
1150
1151         if (*bus && *bus != dev->bus)
1152                 return -EINVAL;
1153
1154         *bus = dev->bus;
1155
1156         return 0;
1157 }
1158
1159 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1160                              struct vfio_domain *domain)
1161 {
1162         struct vfio_domain *d;
1163         struct rb_node *n;
1164         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1165         bool lock_cap = capable(CAP_IPC_LOCK);
1166         int ret;
1167
1168         /* Arbitrarily pick the first domain in the list for lookups */
1169         d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1170         n = rb_first(&iommu->dma_list);
1171
1172         for (; n; n = rb_next(n)) {
1173                 struct vfio_dma *dma;
1174                 dma_addr_t iova;
1175
1176                 dma = rb_entry(n, struct vfio_dma, node);
1177                 iova = dma->iova;
1178
1179                 while (iova < dma->iova + dma->size) {
1180                         phys_addr_t phys;
1181                         size_t size;
1182
1183                         if (dma->iommu_mapped) {
1184                                 phys_addr_t p;
1185                                 dma_addr_t i;
1186
1187                                 phys = iommu_iova_to_phys(d->domain, iova);
1188
1189                                 if (WARN_ON(!phys)) {
1190                                         iova += PAGE_SIZE;
1191                                         continue;
1192                                 }
1193
1194                                 size = PAGE_SIZE;
1195                                 p = phys + size;
1196                                 i = iova + size;
1197                                 while (i < dma->iova + dma->size &&
1198                                        p == iommu_iova_to_phys(d->domain, i)) {
1199                                         size += PAGE_SIZE;
1200                                         p += PAGE_SIZE;
1201                                         i += PAGE_SIZE;
1202                                 }
1203                         } else {
1204                                 unsigned long pfn;
1205                                 unsigned long vaddr = dma->vaddr +
1206                                                      (iova - dma->iova);
1207                                 size_t n = dma->iova + dma->size - iova;
1208                                 long npage;
1209
1210                                 npage = vfio_pin_pages_remote(dma, vaddr,
1211                                                               n >> PAGE_SHIFT,
1212                                                               &pfn, lock_cap,
1213                                                               limit);
1214                                 if (npage <= 0) {
1215                                         WARN_ON(!npage);
1216                                         ret = (int)npage;
1217                                         return ret;
1218                                 }
1219
1220                                 phys = pfn << PAGE_SHIFT;
1221                                 size = npage << PAGE_SHIFT;
1222                         }
1223
1224                         ret = iommu_map(domain->domain, iova, phys,
1225                                         size, dma->prot | domain->prot);
1226                         if (ret)
1227                                 return ret;
1228
1229                         iova += size;
1230                 }
1231                 dma->iommu_mapped = true;
1232         }
1233         return 0;
1234 }
1235
1236 /*
1237  * We change our unmap behavior slightly depending on whether the IOMMU
1238  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1239  * for practically any contiguous power-of-two mapping we give it.  This means
1240  * we don't need to look for contiguous chunks ourselves to make unmapping
1241  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1242  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1243  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1244  * hugetlbfs is in use.
1245  */
1246 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1247 {
1248         struct page *pages;
1249         int ret, order = get_order(PAGE_SIZE * 2);
1250
1251         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1252         if (!pages)
1253                 return;
1254
1255         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1256                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1257         if (!ret) {
1258                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1259
1260                 if (unmapped == PAGE_SIZE)
1261                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1262                 else
1263                         domain->fgsp = true;
1264         }
1265
1266         __free_pages(pages, order);
1267 }
1268
1269 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1270                                            struct iommu_group *iommu_group)
1271 {
1272         struct vfio_group *g;
1273
1274         list_for_each_entry(g, &domain->group_list, next) {
1275                 if (g->iommu_group == iommu_group)
1276                         return g;
1277         }
1278
1279         return NULL;
1280 }
1281
1282 static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
1283 {
1284         struct list_head group_resv_regions;
1285         struct iommu_resv_region *region, *next;
1286         bool ret = false;
1287
1288         INIT_LIST_HEAD(&group_resv_regions);
1289         iommu_get_group_resv_regions(group, &group_resv_regions);
1290         list_for_each_entry(region, &group_resv_regions, list) {
1291                 /*
1292                  * The presence of any 'real' MSI regions should take
1293                  * precedence over the software-managed one if the
1294                  * IOMMU driver happens to advertise both types.
1295                  */
1296                 if (region->type == IOMMU_RESV_MSI) {
1297                         ret = false;
1298                         break;
1299                 }
1300
1301                 if (region->type == IOMMU_RESV_SW_MSI) {
1302                         *base = region->start;
1303                         ret = true;
1304                 }
1305         }
1306         list_for_each_entry_safe(region, next, &group_resv_regions, list)
1307                 kfree(region);
1308         return ret;
1309 }
1310
1311 static int vfio_iommu_type1_attach_group(void *iommu_data,
1312                                          struct iommu_group *iommu_group)
1313 {
1314         struct vfio_iommu *iommu = iommu_data;
1315         struct vfio_group *group;
1316         struct vfio_domain *domain, *d;
1317         struct bus_type *bus = NULL, *mdev_bus;
1318         int ret;
1319         bool resv_msi, msi_remap;
1320         phys_addr_t resv_msi_base;
1321
1322         mutex_lock(&iommu->lock);
1323
1324         list_for_each_entry(d, &iommu->domain_list, next) {
1325                 if (find_iommu_group(d, iommu_group)) {
1326                         mutex_unlock(&iommu->lock);
1327                         return -EINVAL;
1328                 }
1329         }
1330
1331         if (iommu->external_domain) {
1332                 if (find_iommu_group(iommu->external_domain, iommu_group)) {
1333                         mutex_unlock(&iommu->lock);
1334                         return -EINVAL;
1335                 }
1336         }
1337
1338         group = kzalloc(sizeof(*group), GFP_KERNEL);
1339         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1340         if (!group || !domain) {
1341                 ret = -ENOMEM;
1342                 goto out_free;
1343         }
1344
1345         group->iommu_group = iommu_group;
1346
1347         /* Determine bus_type in order to allocate a domain */
1348         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1349         if (ret)
1350                 goto out_free;
1351
1352         mdev_bus = symbol_get(mdev_bus_type);
1353
1354         if (mdev_bus) {
1355                 if ((bus == mdev_bus) && !iommu_present(bus)) {
1356                         symbol_put(mdev_bus_type);
1357                         if (!iommu->external_domain) {
1358                                 INIT_LIST_HEAD(&domain->group_list);
1359                                 iommu->external_domain = domain;
1360                         } else
1361                                 kfree(domain);
1362
1363                         list_add(&group->next,
1364                                  &iommu->external_domain->group_list);
1365                         mutex_unlock(&iommu->lock);
1366                         return 0;
1367                 }
1368                 symbol_put(mdev_bus_type);
1369         }
1370
1371         domain->domain = iommu_domain_alloc(bus);
1372         if (!domain->domain) {
1373                 ret = -EIO;
1374                 goto out_free;
1375         }
1376
1377         if (iommu->nesting) {
1378                 int attr = 1;
1379
1380                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
1381                                             &attr);
1382                 if (ret)
1383                         goto out_domain;
1384         }
1385
1386         ret = iommu_attach_group(domain->domain, iommu_group);
1387         if (ret)
1388                 goto out_domain;
1389
1390         resv_msi = vfio_iommu_has_sw_msi(iommu_group, &resv_msi_base);
1391
1392         INIT_LIST_HEAD(&domain->group_list);
1393         list_add(&group->next, &domain->group_list);
1394
1395         msi_remap = irq_domain_check_msi_remap() ||
1396                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
1397
1398         if (!allow_unsafe_interrupts && !msi_remap) {
1399                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
1400                        __func__);
1401                 ret = -EPERM;
1402                 goto out_detach;
1403         }
1404
1405         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
1406                 domain->prot |= IOMMU_CACHE;
1407
1408         /*
1409          * Try to match an existing compatible domain.  We don't want to
1410          * preclude an IOMMU driver supporting multiple bus_types and being
1411          * able to include different bus_types in the same IOMMU domain, so
1412          * we test whether the domains use the same iommu_ops rather than
1413          * testing if they're on the same bus_type.
1414          */
1415         list_for_each_entry(d, &iommu->domain_list, next) {
1416                 if (d->domain->ops == domain->domain->ops &&
1417                     d->prot == domain->prot) {
1418                         iommu_detach_group(domain->domain, iommu_group);
1419                         if (!iommu_attach_group(d->domain, iommu_group)) {
1420                                 list_add(&group->next, &d->group_list);
1421                                 iommu_domain_free(domain->domain);
1422                                 kfree(domain);
1423                                 mutex_unlock(&iommu->lock);
1424                                 return 0;
1425                         }
1426
1427                         ret = iommu_attach_group(domain->domain, iommu_group);
1428                         if (ret)
1429                                 goto out_domain;
1430                 }
1431         }
1432
1433         vfio_test_domain_fgsp(domain);
1434
1435         /* replay mappings on new domains */
1436         ret = vfio_iommu_replay(iommu, domain);
1437         if (ret)
1438                 goto out_detach;
1439
1440         if (resv_msi) {
1441                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
1442                 if (ret)
1443                         goto out_detach;
1444         }
1445
1446         list_add(&domain->next, &iommu->domain_list);
1447
1448         mutex_unlock(&iommu->lock);
1449
1450         return 0;
1451
1452 out_detach:
1453         iommu_detach_group(domain->domain, iommu_group);
1454 out_domain:
1455         iommu_domain_free(domain->domain);
1456 out_free:
1457         kfree(domain);
1458         kfree(group);
1459         mutex_unlock(&iommu->lock);
1460         return ret;
1461 }
1462
1463 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
1464 {
1465         struct rb_node *node;
1466
1467         while ((node = rb_first(&iommu->dma_list)))
1468                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
1469 }
1470
1471 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
1472 {
1473         struct rb_node *n, *p;
1474
1475         n = rb_first(&iommu->dma_list);
1476         for (; n; n = rb_next(n)) {
1477                 struct vfio_dma *dma;
1478                 long locked = 0, unlocked = 0;
1479
1480                 dma = rb_entry(n, struct vfio_dma, node);
1481                 unlocked += vfio_unmap_unpin(iommu, dma, false);
1482                 p = rb_first(&dma->pfn_list);
1483                 for (; p; p = rb_next(p)) {
1484                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
1485                                                          node);
1486
1487                         if (!is_invalid_reserved_pfn(vpfn->pfn))
1488                                 locked++;
1489                 }
1490                 vfio_lock_acct(dma->task, locked - unlocked, NULL);
1491         }
1492 }
1493
1494 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
1495 {
1496         struct rb_node *n;
1497
1498         n = rb_first(&iommu->dma_list);
1499         for (; n; n = rb_next(n)) {
1500                 struct vfio_dma *dma;
1501
1502                 dma = rb_entry(n, struct vfio_dma, node);
1503
1504                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
1505                         break;
1506         }
1507         /* mdev vendor driver must unregister notifier */
1508         WARN_ON(iommu->notifier.head);
1509 }
1510
1511 static void vfio_iommu_type1_detach_group(void *iommu_data,
1512                                           struct iommu_group *iommu_group)
1513 {
1514         struct vfio_iommu *iommu = iommu_data;
1515         struct vfio_domain *domain;
1516         struct vfio_group *group;
1517
1518         mutex_lock(&iommu->lock);
1519
1520         if (iommu->external_domain) {
1521                 group = find_iommu_group(iommu->external_domain, iommu_group);
1522                 if (group) {
1523                         list_del(&group->next);
1524                         kfree(group);
1525
1526                         if (list_empty(&iommu->external_domain->group_list)) {
1527                                 vfio_sanity_check_pfn_list(iommu);
1528
1529                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1530                                         vfio_iommu_unmap_unpin_all(iommu);
1531
1532                                 kfree(iommu->external_domain);
1533                                 iommu->external_domain = NULL;
1534                         }
1535                         goto detach_group_done;
1536                 }
1537         }
1538
1539         list_for_each_entry(domain, &iommu->domain_list, next) {
1540                 group = find_iommu_group(domain, iommu_group);
1541                 if (!group)
1542                         continue;
1543
1544                 iommu_detach_group(domain->domain, iommu_group);
1545                 list_del(&group->next);
1546                 kfree(group);
1547                 /*
1548                  * Group ownership provides privilege, if the group list is
1549                  * empty, the domain goes away. If it's the last domain with
1550                  * iommu and external domain doesn't exist, then all the
1551                  * mappings go away too. If it's the last domain with iommu and
1552                  * external domain exist, update accounting
1553                  */
1554                 if (list_empty(&domain->group_list)) {
1555                         if (list_is_singular(&iommu->domain_list)) {
1556                                 if (!iommu->external_domain)
1557                                         vfio_iommu_unmap_unpin_all(iommu);
1558                                 else
1559                                         vfio_iommu_unmap_unpin_reaccount(iommu);
1560                         }
1561                         iommu_domain_free(domain->domain);
1562                         list_del(&domain->next);
1563                         kfree(domain);
1564                 }
1565                 break;
1566         }
1567
1568 detach_group_done:
1569         mutex_unlock(&iommu->lock);
1570 }
1571
1572 static void *vfio_iommu_type1_open(unsigned long arg)
1573 {
1574         struct vfio_iommu *iommu;
1575
1576         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
1577         if (!iommu)
1578                 return ERR_PTR(-ENOMEM);
1579
1580         switch (arg) {
1581         case VFIO_TYPE1_IOMMU:
1582                 break;
1583         case VFIO_TYPE1_NESTING_IOMMU:
1584                 iommu->nesting = true;
1585         case VFIO_TYPE1v2_IOMMU:
1586                 iommu->v2 = true;
1587                 break;
1588         default:
1589                 kfree(iommu);
1590                 return ERR_PTR(-EINVAL);
1591         }
1592
1593         INIT_LIST_HEAD(&iommu->domain_list);
1594         iommu->dma_list = RB_ROOT;
1595         mutex_init(&iommu->lock);
1596         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
1597
1598         return iommu;
1599 }
1600
1601 static void vfio_release_domain(struct vfio_domain *domain, bool external)
1602 {
1603         struct vfio_group *group, *group_tmp;
1604
1605         list_for_each_entry_safe(group, group_tmp,
1606                                  &domain->group_list, next) {
1607                 if (!external)
1608                         iommu_detach_group(domain->domain, group->iommu_group);
1609                 list_del(&group->next);
1610                 kfree(group);
1611         }
1612
1613         if (!external)
1614                 iommu_domain_free(domain->domain);
1615 }
1616
1617 static void vfio_iommu_type1_release(void *iommu_data)
1618 {
1619         struct vfio_iommu *iommu = iommu_data;
1620         struct vfio_domain *domain, *domain_tmp;
1621
1622         if (iommu->external_domain) {
1623                 vfio_release_domain(iommu->external_domain, true);
1624                 vfio_sanity_check_pfn_list(iommu);
1625                 kfree(iommu->external_domain);
1626         }
1627
1628         vfio_iommu_unmap_unpin_all(iommu);
1629
1630         list_for_each_entry_safe(domain, domain_tmp,
1631                                  &iommu->domain_list, next) {
1632                 vfio_release_domain(domain, false);
1633                 list_del(&domain->next);
1634                 kfree(domain);
1635         }
1636         kfree(iommu);
1637 }
1638
1639 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
1640 {
1641         struct vfio_domain *domain;
1642         int ret = 1;
1643
1644         mutex_lock(&iommu->lock);
1645         list_for_each_entry(domain, &iommu->domain_list, next) {
1646                 if (!(domain->prot & IOMMU_CACHE)) {
1647                         ret = 0;
1648                         break;
1649                 }
1650         }
1651         mutex_unlock(&iommu->lock);
1652
1653         return ret;
1654 }
1655
1656 static long vfio_iommu_type1_ioctl(void *iommu_data,
1657                                    unsigned int cmd, unsigned long arg)
1658 {
1659         struct vfio_iommu *iommu = iommu_data;
1660         unsigned long minsz;
1661
1662         if (cmd == VFIO_CHECK_EXTENSION) {
1663                 switch (arg) {
1664                 case VFIO_TYPE1_IOMMU:
1665                 case VFIO_TYPE1v2_IOMMU:
1666                 case VFIO_TYPE1_NESTING_IOMMU:
1667                         return 1;
1668                 case VFIO_DMA_CC_IOMMU:
1669                         if (!iommu)
1670                                 return 0;
1671                         return vfio_domains_have_iommu_cache(iommu);
1672                 default:
1673                         return 0;
1674                 }
1675         } else if (cmd == VFIO_IOMMU_GET_INFO) {
1676                 struct vfio_iommu_type1_info info;
1677
1678                 minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
1679
1680                 if (copy_from_user(&info, (void __user *)arg, minsz))
1681                         return -EFAULT;
1682
1683                 if (info.argsz < minsz)
1684                         return -EINVAL;
1685
1686                 info.flags = VFIO_IOMMU_INFO_PGSIZES;
1687
1688                 info.iova_pgsizes = vfio_pgsize_bitmap(iommu);
1689
1690                 return copy_to_user((void __user *)arg, &info, minsz) ?
1691                         -EFAULT : 0;
1692
1693         } else if (cmd == VFIO_IOMMU_MAP_DMA) {
1694                 struct vfio_iommu_type1_dma_map map;
1695                 uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
1696                                 VFIO_DMA_MAP_FLAG_WRITE;
1697
1698                 minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
1699
1700                 if (copy_from_user(&map, (void __user *)arg, minsz))
1701                         return -EFAULT;
1702
1703                 if (map.argsz < minsz || map.flags & ~mask)
1704                         return -EINVAL;
1705
1706                 return vfio_dma_do_map(iommu, &map);
1707
1708         } else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
1709                 struct vfio_iommu_type1_dma_unmap unmap;
1710                 long ret;
1711
1712                 minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
1713
1714                 if (copy_from_user(&unmap, (void __user *)arg, minsz))
1715                         return -EFAULT;
1716
1717                 if (unmap.argsz < minsz || unmap.flags)
1718                         return -EINVAL;
1719
1720                 ret = vfio_dma_do_unmap(iommu, &unmap);
1721                 if (ret)
1722                         return ret;
1723
1724                 return copy_to_user((void __user *)arg, &unmap, minsz) ?
1725                         -EFAULT : 0;
1726         }
1727
1728         return -ENOTTY;
1729 }
1730
1731 static int vfio_iommu_type1_register_notifier(void *iommu_data,
1732                                               unsigned long *events,
1733                                               struct notifier_block *nb)
1734 {
1735         struct vfio_iommu *iommu = iommu_data;
1736
1737         /* clear known events */
1738         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
1739
1740         /* refuse to register if still events remaining */
1741         if (*events)
1742                 return -EINVAL;
1743
1744         return blocking_notifier_chain_register(&iommu->notifier, nb);
1745 }
1746
1747 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
1748                                                 struct notifier_block *nb)
1749 {
1750         struct vfio_iommu *iommu = iommu_data;
1751
1752         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
1753 }
1754
1755 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
1756         .name                   = "vfio-iommu-type1",
1757         .owner                  = THIS_MODULE,
1758         .open                   = vfio_iommu_type1_open,
1759         .release                = vfio_iommu_type1_release,
1760         .ioctl                  = vfio_iommu_type1_ioctl,
1761         .attach_group           = vfio_iommu_type1_attach_group,
1762         .detach_group           = vfio_iommu_type1_detach_group,
1763         .pin_pages              = vfio_iommu_type1_pin_pages,
1764         .unpin_pages            = vfio_iommu_type1_unpin_pages,
1765         .register_notifier      = vfio_iommu_type1_register_notifier,
1766         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
1767 };
1768
1769 static int __init vfio_iommu_type1_init(void)
1770 {
1771         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
1772 }
1773
1774 static void __exit vfio_iommu_type1_cleanup(void)
1775 {
1776         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
1777 }
1778
1779 module_init(vfio_iommu_type1_init);
1780 module_exit(vfio_iommu_type1_cleanup);
1781
1782 MODULE_VERSION(DRIVER_VERSION);
1783 MODULE_LICENSE("GPL v2");
1784 MODULE_AUTHOR(DRIVER_AUTHOR);
1785 MODULE_DESCRIPTION(DRIVER_DESC);