mmap locking API: use coccinelle to convert mmap_sem rwsem call sites
[linux-block.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userpsace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/iommu.h>
28 #include <linux/module.h>
29 #include <linux/mm.h>
30 #include <linux/mmu_context.h>
31 #include <linux/rbtree.h>
32 #include <linux/sched/signal.h>
33 #include <linux/sched/mm.h>
34 #include <linux/slab.h>
35 #include <linux/uaccess.h>
36 #include <linux/vfio.h>
37 #include <linux/workqueue.h>
38 #include <linux/mdev.h>
39 #include <linux/notifier.h>
40 #include <linux/dma-iommu.h>
41 #include <linux/irqdomain.h>
42
43 #define DRIVER_VERSION  "0.2"
44 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
45 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
46
47 static bool allow_unsafe_interrupts;
48 module_param_named(allow_unsafe_interrupts,
49                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
50 MODULE_PARM_DESC(allow_unsafe_interrupts,
51                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
52
53 static bool disable_hugepages;
54 module_param_named(disable_hugepages,
55                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
56 MODULE_PARM_DESC(disable_hugepages,
57                  "Disable VFIO IOMMU support for IOMMU hugepages.");
58
59 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
60 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
61 MODULE_PARM_DESC(dma_entry_limit,
62                  "Maximum number of user DMA mappings per container (65535).");
63
64 struct vfio_iommu {
65         struct list_head        domain_list;
66         struct list_head        iova_list;
67         struct vfio_domain      *external_domain; /* domain for external user */
68         struct mutex            lock;
69         struct rb_root          dma_list;
70         struct blocking_notifier_head notifier;
71         unsigned int            dma_avail;
72         uint64_t                pgsize_bitmap;
73         bool                    v2;
74         bool                    nesting;
75         bool                    dirty_page_tracking;
76         bool                    pinned_page_dirty_scope;
77 };
78
79 struct vfio_domain {
80         struct iommu_domain     *domain;
81         struct list_head        next;
82         struct list_head        group_list;
83         int                     prot;           /* IOMMU_CACHE */
84         bool                    fgsp;           /* Fine-grained super pages */
85 };
86
87 struct vfio_dma {
88         struct rb_node          node;
89         dma_addr_t              iova;           /* Device address */
90         unsigned long           vaddr;          /* Process virtual addr */
91         size_t                  size;           /* Map size (bytes) */
92         int                     prot;           /* IOMMU_READ/WRITE */
93         bool                    iommu_mapped;
94         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
95         struct task_struct      *task;
96         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
97         unsigned long           *bitmap;
98 };
99
100 struct vfio_group {
101         struct iommu_group      *iommu_group;
102         struct list_head        next;
103         bool                    mdev_group;     /* An mdev group */
104         bool                    pinned_page_dirty_scope;
105 };
106
107 struct vfio_iova {
108         struct list_head        list;
109         dma_addr_t              start;
110         dma_addr_t              end;
111 };
112
113 /*
114  * Guest RAM pinning working set or DMA target
115  */
116 struct vfio_pfn {
117         struct rb_node          node;
118         dma_addr_t              iova;           /* Device address */
119         unsigned long           pfn;            /* Host pfn */
120         unsigned int            ref_count;
121 };
122
123 struct vfio_regions {
124         struct list_head list;
125         dma_addr_t iova;
126         phys_addr_t phys;
127         size_t len;
128 };
129
130 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
131                                         (!list_empty(&iommu->domain_list))
132
133 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
134
135 /*
136  * Input argument of number of bits to bitmap_set() is unsigned integer, which
137  * further casts to signed integer for unaligned multi-bit operation,
138  * __bitmap_set().
139  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
140  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
141  * system.
142  */
143 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
144 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
145
146 static int put_pfn(unsigned long pfn, int prot);
147
148 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
149                                                struct iommu_group *iommu_group);
150
151 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu);
152 /*
153  * This code handles mapping and unmapping of user data buffers
154  * into DMA'ble space using the IOMMU
155  */
156
157 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
158                                       dma_addr_t start, size_t size)
159 {
160         struct rb_node *node = iommu->dma_list.rb_node;
161
162         while (node) {
163                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
164
165                 if (start + size <= dma->iova)
166                         node = node->rb_left;
167                 else if (start >= dma->iova + dma->size)
168                         node = node->rb_right;
169                 else
170                         return dma;
171         }
172
173         return NULL;
174 }
175
176 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
177 {
178         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
179         struct vfio_dma *dma;
180
181         while (*link) {
182                 parent = *link;
183                 dma = rb_entry(parent, struct vfio_dma, node);
184
185                 if (new->iova + new->size <= dma->iova)
186                         link = &(*link)->rb_left;
187                 else
188                         link = &(*link)->rb_right;
189         }
190
191         rb_link_node(&new->node, parent, link);
192         rb_insert_color(&new->node, &iommu->dma_list);
193 }
194
195 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
196 {
197         rb_erase(&old->node, &iommu->dma_list);
198 }
199
200
201 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
202 {
203         uint64_t npages = dma->size / pgsize;
204
205         if (npages > DIRTY_BITMAP_PAGES_MAX)
206                 return -EINVAL;
207
208         /*
209          * Allocate extra 64 bits that are used to calculate shift required for
210          * bitmap_shift_left() to manipulate and club unaligned number of pages
211          * in adjacent vfio_dma ranges.
212          */
213         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
214                                GFP_KERNEL);
215         if (!dma->bitmap)
216                 return -ENOMEM;
217
218         return 0;
219 }
220
221 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
222 {
223         kfree(dma->bitmap);
224         dma->bitmap = NULL;
225 }
226
227 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
228 {
229         struct rb_node *p;
230         unsigned long pgshift = __ffs(pgsize);
231
232         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
233                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
234
235                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
236         }
237 }
238
239 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
240 {
241         struct rb_node *n;
242
243         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
244                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
245                 int ret;
246
247                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
248                 if (ret) {
249                         struct rb_node *p;
250
251                         for (p = rb_prev(n); p; p = rb_prev(p)) {
252                                 struct vfio_dma *dma = rb_entry(n,
253                                                         struct vfio_dma, node);
254
255                                 vfio_dma_bitmap_free(dma);
256                         }
257                         return ret;
258                 }
259                 vfio_dma_populate_bitmap(dma, pgsize);
260         }
261         return 0;
262 }
263
264 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
265 {
266         struct rb_node *n;
267
268         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
269                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
270
271                 vfio_dma_bitmap_free(dma);
272         }
273 }
274
275 /*
276  * Helper Functions for host iova-pfn list
277  */
278 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
279 {
280         struct vfio_pfn *vpfn;
281         struct rb_node *node = dma->pfn_list.rb_node;
282
283         while (node) {
284                 vpfn = rb_entry(node, struct vfio_pfn, node);
285
286                 if (iova < vpfn->iova)
287                         node = node->rb_left;
288                 else if (iova > vpfn->iova)
289                         node = node->rb_right;
290                 else
291                         return vpfn;
292         }
293         return NULL;
294 }
295
296 static void vfio_link_pfn(struct vfio_dma *dma,
297                           struct vfio_pfn *new)
298 {
299         struct rb_node **link, *parent = NULL;
300         struct vfio_pfn *vpfn;
301
302         link = &dma->pfn_list.rb_node;
303         while (*link) {
304                 parent = *link;
305                 vpfn = rb_entry(parent, struct vfio_pfn, node);
306
307                 if (new->iova < vpfn->iova)
308                         link = &(*link)->rb_left;
309                 else
310                         link = &(*link)->rb_right;
311         }
312
313         rb_link_node(&new->node, parent, link);
314         rb_insert_color(&new->node, &dma->pfn_list);
315 }
316
317 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
318 {
319         rb_erase(&old->node, &dma->pfn_list);
320 }
321
322 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
323                                 unsigned long pfn)
324 {
325         struct vfio_pfn *vpfn;
326
327         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
328         if (!vpfn)
329                 return -ENOMEM;
330
331         vpfn->iova = iova;
332         vpfn->pfn = pfn;
333         vpfn->ref_count = 1;
334         vfio_link_pfn(dma, vpfn);
335         return 0;
336 }
337
338 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
339                                       struct vfio_pfn *vpfn)
340 {
341         vfio_unlink_pfn(dma, vpfn);
342         kfree(vpfn);
343 }
344
345 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
346                                                unsigned long iova)
347 {
348         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
349
350         if (vpfn)
351                 vpfn->ref_count++;
352         return vpfn;
353 }
354
355 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
356 {
357         int ret = 0;
358
359         vpfn->ref_count--;
360         if (!vpfn->ref_count) {
361                 ret = put_pfn(vpfn->pfn, dma->prot);
362                 vfio_remove_from_pfn_list(dma, vpfn);
363         }
364         return ret;
365 }
366
367 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
368 {
369         struct mm_struct *mm;
370         int ret;
371
372         if (!npage)
373                 return 0;
374
375         mm = async ? get_task_mm(dma->task) : dma->task->mm;
376         if (!mm)
377                 return -ESRCH; /* process exited */
378
379         ret = mmap_write_lock_killable(mm);
380         if (!ret) {
381                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
382                                           dma->lock_cap);
383                 mmap_write_unlock(mm);
384         }
385
386         if (async)
387                 mmput(mm);
388
389         return ret;
390 }
391
392 /*
393  * Some mappings aren't backed by a struct page, for example an mmap'd
394  * MMIO range for our own or another device.  These use a different
395  * pfn conversion and shouldn't be tracked as locked pages.
396  * For compound pages, any driver that sets the reserved bit in head
397  * page needs to set the reserved bit in all subpages to be safe.
398  */
399 static bool is_invalid_reserved_pfn(unsigned long pfn)
400 {
401         if (pfn_valid(pfn))
402                 return PageReserved(pfn_to_page(pfn));
403
404         return true;
405 }
406
407 static int put_pfn(unsigned long pfn, int prot)
408 {
409         if (!is_invalid_reserved_pfn(pfn)) {
410                 struct page *page = pfn_to_page(pfn);
411
412                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
413                 return 1;
414         }
415         return 0;
416 }
417
418 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
419                             unsigned long vaddr, unsigned long *pfn,
420                             bool write_fault)
421 {
422         int ret;
423
424         ret = follow_pfn(vma, vaddr, pfn);
425         if (ret) {
426                 bool unlocked = false;
427
428                 ret = fixup_user_fault(NULL, mm, vaddr,
429                                        FAULT_FLAG_REMOTE |
430                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
431                                        &unlocked);
432                 if (unlocked)
433                         return -EAGAIN;
434
435                 if (ret)
436                         return ret;
437
438                 ret = follow_pfn(vma, vaddr, pfn);
439         }
440
441         return ret;
442 }
443
444 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
445                          int prot, unsigned long *pfn)
446 {
447         struct page *page[1];
448         struct vm_area_struct *vma;
449         unsigned int flags = 0;
450         int ret;
451
452         if (prot & IOMMU_WRITE)
453                 flags |= FOLL_WRITE;
454
455         mmap_read_lock(mm);
456         ret = pin_user_pages_remote(NULL, mm, vaddr, 1, flags | FOLL_LONGTERM,
457                                     page, NULL, NULL);
458         if (ret == 1) {
459                 *pfn = page_to_pfn(page[0]);
460                 ret = 0;
461                 goto done;
462         }
463
464         vaddr = untagged_addr(vaddr);
465
466 retry:
467         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
468
469         if (vma && vma->vm_flags & VM_PFNMAP) {
470                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
471                 if (ret == -EAGAIN)
472                         goto retry;
473
474                 if (!ret && !is_invalid_reserved_pfn(*pfn))
475                         ret = -EFAULT;
476         }
477 done:
478         mmap_read_unlock(mm);
479         return ret;
480 }
481
482 /*
483  * Attempt to pin pages.  We really don't want to track all the pfns and
484  * the iommu can only map chunks of consecutive pfns anyway, so get the
485  * first page and all consecutive pages with the same locking.
486  */
487 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
488                                   long npage, unsigned long *pfn_base,
489                                   unsigned long limit)
490 {
491         unsigned long pfn = 0;
492         long ret, pinned = 0, lock_acct = 0;
493         bool rsvd;
494         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
495
496         /* This code path is only user initiated */
497         if (!current->mm)
498                 return -ENODEV;
499
500         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
501         if (ret)
502                 return ret;
503
504         pinned++;
505         rsvd = is_invalid_reserved_pfn(*pfn_base);
506
507         /*
508          * Reserved pages aren't counted against the user, externally pinned
509          * pages are already counted against the user.
510          */
511         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
512                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
513                         put_pfn(*pfn_base, dma->prot);
514                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
515                                         limit << PAGE_SHIFT);
516                         return -ENOMEM;
517                 }
518                 lock_acct++;
519         }
520
521         if (unlikely(disable_hugepages))
522                 goto out;
523
524         /* Lock all the consecutive pages from pfn_base */
525         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
526              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
527                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
528                 if (ret)
529                         break;
530
531                 if (pfn != *pfn_base + pinned ||
532                     rsvd != is_invalid_reserved_pfn(pfn)) {
533                         put_pfn(pfn, dma->prot);
534                         break;
535                 }
536
537                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
538                         if (!dma->lock_cap &&
539                             current->mm->locked_vm + lock_acct + 1 > limit) {
540                                 put_pfn(pfn, dma->prot);
541                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
542                                         __func__, limit << PAGE_SHIFT);
543                                 ret = -ENOMEM;
544                                 goto unpin_out;
545                         }
546                         lock_acct++;
547                 }
548         }
549
550 out:
551         ret = vfio_lock_acct(dma, lock_acct, false);
552
553 unpin_out:
554         if (ret) {
555                 if (!rsvd) {
556                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
557                                 put_pfn(pfn, dma->prot);
558                 }
559
560                 return ret;
561         }
562
563         return pinned;
564 }
565
566 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
567                                     unsigned long pfn, long npage,
568                                     bool do_accounting)
569 {
570         long unlocked = 0, locked = 0;
571         long i;
572
573         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
574                 if (put_pfn(pfn++, dma->prot)) {
575                         unlocked++;
576                         if (vfio_find_vpfn(dma, iova))
577                                 locked++;
578                 }
579         }
580
581         if (do_accounting)
582                 vfio_lock_acct(dma, locked - unlocked, true);
583
584         return unlocked;
585 }
586
587 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
588                                   unsigned long *pfn_base, bool do_accounting)
589 {
590         struct mm_struct *mm;
591         int ret;
592
593         mm = get_task_mm(dma->task);
594         if (!mm)
595                 return -ENODEV;
596
597         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
598         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
599                 ret = vfio_lock_acct(dma, 1, true);
600                 if (ret) {
601                         put_pfn(*pfn_base, dma->prot);
602                         if (ret == -ENOMEM)
603                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
604                                         "(%ld) exceeded\n", __func__,
605                                         dma->task->comm, task_pid_nr(dma->task),
606                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
607                 }
608         }
609
610         mmput(mm);
611         return ret;
612 }
613
614 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
615                                     bool do_accounting)
616 {
617         int unlocked;
618         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
619
620         if (!vpfn)
621                 return 0;
622
623         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
624
625         if (do_accounting)
626                 vfio_lock_acct(dma, -unlocked, true);
627
628         return unlocked;
629 }
630
631 static int vfio_iommu_type1_pin_pages(void *iommu_data,
632                                       struct iommu_group *iommu_group,
633                                       unsigned long *user_pfn,
634                                       int npage, int prot,
635                                       unsigned long *phys_pfn)
636 {
637         struct vfio_iommu *iommu = iommu_data;
638         struct vfio_group *group;
639         int i, j, ret;
640         unsigned long remote_vaddr;
641         struct vfio_dma *dma;
642         bool do_accounting;
643
644         if (!iommu || !user_pfn || !phys_pfn)
645                 return -EINVAL;
646
647         /* Supported for v2 version only */
648         if (!iommu->v2)
649                 return -EACCES;
650
651         mutex_lock(&iommu->lock);
652
653         /* Fail if notifier list is empty */
654         if (!iommu->notifier.head) {
655                 ret = -EINVAL;
656                 goto pin_done;
657         }
658
659         /*
660          * If iommu capable domain exist in the container then all pages are
661          * already pinned and accounted. Accouting should be done if there is no
662          * iommu capable domain in the container.
663          */
664         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
665
666         for (i = 0; i < npage; i++) {
667                 dma_addr_t iova;
668                 struct vfio_pfn *vpfn;
669
670                 iova = user_pfn[i] << PAGE_SHIFT;
671                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
672                 if (!dma) {
673                         ret = -EINVAL;
674                         goto pin_unwind;
675                 }
676
677                 if ((dma->prot & prot) != prot) {
678                         ret = -EPERM;
679                         goto pin_unwind;
680                 }
681
682                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
683                 if (vpfn) {
684                         phys_pfn[i] = vpfn->pfn;
685                         continue;
686                 }
687
688                 remote_vaddr = dma->vaddr + (iova - dma->iova);
689                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
690                                              do_accounting);
691                 if (ret)
692                         goto pin_unwind;
693
694                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
695                 if (ret) {
696                         vfio_unpin_page_external(dma, iova, do_accounting);
697                         goto pin_unwind;
698                 }
699
700                 if (iommu->dirty_page_tracking) {
701                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
702
703                         /*
704                          * Bitmap populated with the smallest supported page
705                          * size
706                          */
707                         bitmap_set(dma->bitmap,
708                                    (iova - dma->iova) >> pgshift, 1);
709                 }
710         }
711         ret = i;
712
713         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
714         if (!group->pinned_page_dirty_scope) {
715                 group->pinned_page_dirty_scope = true;
716                 update_pinned_page_dirty_scope(iommu);
717         }
718
719         goto pin_done;
720
721 pin_unwind:
722         phys_pfn[i] = 0;
723         for (j = 0; j < i; j++) {
724                 dma_addr_t iova;
725
726                 iova = user_pfn[j] << PAGE_SHIFT;
727                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
728                 vfio_unpin_page_external(dma, iova, do_accounting);
729                 phys_pfn[j] = 0;
730         }
731 pin_done:
732         mutex_unlock(&iommu->lock);
733         return ret;
734 }
735
736 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
737                                         unsigned long *user_pfn,
738                                         int npage)
739 {
740         struct vfio_iommu *iommu = iommu_data;
741         bool do_accounting;
742         int i;
743
744         if (!iommu || !user_pfn)
745                 return -EINVAL;
746
747         /* Supported for v2 version only */
748         if (!iommu->v2)
749                 return -EACCES;
750
751         mutex_lock(&iommu->lock);
752
753         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
754         for (i = 0; i < npage; i++) {
755                 struct vfio_dma *dma;
756                 dma_addr_t iova;
757
758                 iova = user_pfn[i] << PAGE_SHIFT;
759                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
760                 if (!dma)
761                         goto unpin_exit;
762                 vfio_unpin_page_external(dma, iova, do_accounting);
763         }
764
765 unpin_exit:
766         mutex_unlock(&iommu->lock);
767         return i > npage ? npage : (i > 0 ? i : -EINVAL);
768 }
769
770 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
771                             struct list_head *regions,
772                             struct iommu_iotlb_gather *iotlb_gather)
773 {
774         long unlocked = 0;
775         struct vfio_regions *entry, *next;
776
777         iommu_tlb_sync(domain->domain, iotlb_gather);
778
779         list_for_each_entry_safe(entry, next, regions, list) {
780                 unlocked += vfio_unpin_pages_remote(dma,
781                                                     entry->iova,
782                                                     entry->phys >> PAGE_SHIFT,
783                                                     entry->len >> PAGE_SHIFT,
784                                                     false);
785                 list_del(&entry->list);
786                 kfree(entry);
787         }
788
789         cond_resched();
790
791         return unlocked;
792 }
793
794 /*
795  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
796  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
797  * of these regions (currently using a list).
798  *
799  * This value specifies maximum number of regions for each IOTLB flush sync.
800  */
801 #define VFIO_IOMMU_TLB_SYNC_MAX         512
802
803 static size_t unmap_unpin_fast(struct vfio_domain *domain,
804                                struct vfio_dma *dma, dma_addr_t *iova,
805                                size_t len, phys_addr_t phys, long *unlocked,
806                                struct list_head *unmapped_list,
807                                int *unmapped_cnt,
808                                struct iommu_iotlb_gather *iotlb_gather)
809 {
810         size_t unmapped = 0;
811         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
812
813         if (entry) {
814                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
815                                             iotlb_gather);
816
817                 if (!unmapped) {
818                         kfree(entry);
819                 } else {
820                         entry->iova = *iova;
821                         entry->phys = phys;
822                         entry->len  = unmapped;
823                         list_add_tail(&entry->list, unmapped_list);
824
825                         *iova += unmapped;
826                         (*unmapped_cnt)++;
827                 }
828         }
829
830         /*
831          * Sync if the number of fast-unmap regions hits the limit
832          * or in case of errors.
833          */
834         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
835                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
836                                              iotlb_gather);
837                 *unmapped_cnt = 0;
838         }
839
840         return unmapped;
841 }
842
843 static size_t unmap_unpin_slow(struct vfio_domain *domain,
844                                struct vfio_dma *dma, dma_addr_t *iova,
845                                size_t len, phys_addr_t phys,
846                                long *unlocked)
847 {
848         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
849
850         if (unmapped) {
851                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
852                                                      phys >> PAGE_SHIFT,
853                                                      unmapped >> PAGE_SHIFT,
854                                                      false);
855                 *iova += unmapped;
856                 cond_resched();
857         }
858         return unmapped;
859 }
860
861 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
862                              bool do_accounting)
863 {
864         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
865         struct vfio_domain *domain, *d;
866         LIST_HEAD(unmapped_region_list);
867         struct iommu_iotlb_gather iotlb_gather;
868         int unmapped_region_cnt = 0;
869         long unlocked = 0;
870
871         if (!dma->size)
872                 return 0;
873
874         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
875                 return 0;
876
877         /*
878          * We use the IOMMU to track the physical addresses, otherwise we'd
879          * need a much more complicated tracking system.  Unfortunately that
880          * means we need to use one of the iommu domains to figure out the
881          * pfns to unpin.  The rest need to be unmapped in advance so we have
882          * no iommu translations remaining when the pages are unpinned.
883          */
884         domain = d = list_first_entry(&iommu->domain_list,
885                                       struct vfio_domain, next);
886
887         list_for_each_entry_continue(d, &iommu->domain_list, next) {
888                 iommu_unmap(d->domain, dma->iova, dma->size);
889                 cond_resched();
890         }
891
892         iommu_iotlb_gather_init(&iotlb_gather);
893         while (iova < end) {
894                 size_t unmapped, len;
895                 phys_addr_t phys, next;
896
897                 phys = iommu_iova_to_phys(domain->domain, iova);
898                 if (WARN_ON(!phys)) {
899                         iova += PAGE_SIZE;
900                         continue;
901                 }
902
903                 /*
904                  * To optimize for fewer iommu_unmap() calls, each of which
905                  * may require hardware cache flushing, try to find the
906                  * largest contiguous physical memory chunk to unmap.
907                  */
908                 for (len = PAGE_SIZE;
909                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
910                         next = iommu_iova_to_phys(domain->domain, iova + len);
911                         if (next != phys + len)
912                                 break;
913                 }
914
915                 /*
916                  * First, try to use fast unmap/unpin. In case of failure,
917                  * switch to slow unmap/unpin path.
918                  */
919                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
920                                             &unlocked, &unmapped_region_list,
921                                             &unmapped_region_cnt,
922                                             &iotlb_gather);
923                 if (!unmapped) {
924                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
925                                                     phys, &unlocked);
926                         if (WARN_ON(!unmapped))
927                                 break;
928                 }
929         }
930
931         dma->iommu_mapped = false;
932
933         if (unmapped_region_cnt) {
934                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
935                                             &iotlb_gather);
936         }
937
938         if (do_accounting) {
939                 vfio_lock_acct(dma, -unlocked, true);
940                 return 0;
941         }
942         return unlocked;
943 }
944
945 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
946 {
947         vfio_unmap_unpin(iommu, dma, true);
948         vfio_unlink_dma(iommu, dma);
949         put_task_struct(dma->task);
950         vfio_dma_bitmap_free(dma);
951         kfree(dma);
952         iommu->dma_avail++;
953 }
954
955 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
956 {
957         struct vfio_domain *domain;
958
959         iommu->pgsize_bitmap = ULONG_MAX;
960
961         list_for_each_entry(domain, &iommu->domain_list, next)
962                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
963
964         /*
965          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
966          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
967          * That way the user will be able to map/unmap buffers whose size/
968          * start address is aligned with PAGE_SIZE. Pinning code uses that
969          * granularity while iommu driver can use the sub-PAGE_SIZE size
970          * to map the buffer.
971          */
972         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
973                 iommu->pgsize_bitmap &= PAGE_MASK;
974                 iommu->pgsize_bitmap |= PAGE_SIZE;
975         }
976 }
977
978 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
979                               struct vfio_dma *dma, dma_addr_t base_iova,
980                               size_t pgsize)
981 {
982         unsigned long pgshift = __ffs(pgsize);
983         unsigned long nbits = dma->size >> pgshift;
984         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
985         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
986         unsigned long shift = bit_offset % BITS_PER_LONG;
987         unsigned long leftover;
988
989         /*
990          * mark all pages dirty if any IOMMU capable device is not able
991          * to report dirty pages and all pages are pinned and mapped.
992          */
993         if (!iommu->pinned_page_dirty_scope && dma->iommu_mapped)
994                 bitmap_set(dma->bitmap, 0, nbits);
995
996         if (shift) {
997                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
998                                   nbits + shift);
999
1000                 if (copy_from_user(&leftover,
1001                                    (void __user *)(bitmap + copy_offset),
1002                                    sizeof(leftover)))
1003                         return -EFAULT;
1004
1005                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1006         }
1007
1008         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1009                          DIRTY_BITMAP_BYTES(nbits + shift)))
1010                 return -EFAULT;
1011
1012         return 0;
1013 }
1014
1015 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1016                                   dma_addr_t iova, size_t size, size_t pgsize)
1017 {
1018         struct vfio_dma *dma;
1019         struct rb_node *n;
1020         unsigned long pgshift = __ffs(pgsize);
1021         int ret;
1022
1023         /*
1024          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1025          * vfio_dma mappings may be clubbed by specifying large ranges, but
1026          * there must not be any previous mappings bisected by the range.
1027          * An error will be returned if these conditions are not met.
1028          */
1029         dma = vfio_find_dma(iommu, iova, 1);
1030         if (dma && dma->iova != iova)
1031                 return -EINVAL;
1032
1033         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1034         if (dma && dma->iova + dma->size != iova + size)
1035                 return -EINVAL;
1036
1037         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1038                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1039
1040                 if (dma->iova < iova)
1041                         continue;
1042
1043                 if (dma->iova > iova + size - 1)
1044                         break;
1045
1046                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1047                 if (ret)
1048                         return ret;
1049
1050                 /*
1051                  * Re-populate bitmap to include all pinned pages which are
1052                  * considered as dirty but exclude pages which are unpinned and
1053                  * pages which are marked dirty by vfio_dma_rw()
1054                  */
1055                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1056                 vfio_dma_populate_bitmap(dma, pgsize);
1057         }
1058         return 0;
1059 }
1060
1061 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1062 {
1063         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1064             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1065                 return -EINVAL;
1066
1067         return 0;
1068 }
1069
1070 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1071                              struct vfio_iommu_type1_dma_unmap *unmap,
1072                              struct vfio_bitmap *bitmap)
1073 {
1074         struct vfio_dma *dma, *dma_last = NULL;
1075         size_t unmapped = 0, pgsize;
1076         int ret = 0, retries = 0;
1077         unsigned long pgshift;
1078
1079         mutex_lock(&iommu->lock);
1080
1081         pgshift = __ffs(iommu->pgsize_bitmap);
1082         pgsize = (size_t)1 << pgshift;
1083
1084         if (unmap->iova & (pgsize - 1)) {
1085                 ret = -EINVAL;
1086                 goto unlock;
1087         }
1088
1089         if (!unmap->size || unmap->size & (pgsize - 1)) {
1090                 ret = -EINVAL;
1091                 goto unlock;
1092         }
1093
1094         if (unmap->iova + unmap->size - 1 < unmap->iova ||
1095             unmap->size > SIZE_MAX) {
1096                 ret = -EINVAL;
1097                 goto unlock;
1098         }
1099
1100         /* When dirty tracking is enabled, allow only min supported pgsize */
1101         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1102             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1103                 ret = -EINVAL;
1104                 goto unlock;
1105         }
1106
1107         WARN_ON((pgsize - 1) & PAGE_MASK);
1108 again:
1109         /*
1110          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1111          * avoid tracking individual mappings.  This means that the granularity
1112          * of the original mapping was lost and the user was allowed to attempt
1113          * to unmap any range.  Depending on the contiguousness of physical
1114          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1115          * or may not have worked.  We only guaranteed unmap granularity
1116          * matching the original mapping; even though it was untracked here,
1117          * the original mappings are reflected in IOMMU mappings.  This
1118          * resulted in a couple unusual behaviors.  First, if a range is not
1119          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1120          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1121          * a zero sized unmap.  Also, if an unmap request overlaps the first
1122          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1123          * This also returns success and the returned unmap size reflects the
1124          * actual size unmapped.
1125          *
1126          * We attempt to maintain compatibility with this "v1" interface, but
1127          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1128          * request offset from the beginning of the original mapping will
1129          * return success with zero sized unmap.  And an unmap request covering
1130          * the first iova of mapping will unmap the entire range.
1131          *
1132          * The v2 version of this interface intends to be more deterministic.
1133          * Unmap requests must fully cover previous mappings.  Multiple
1134          * mappings may still be unmaped by specifying large ranges, but there
1135          * must not be any previous mappings bisected by the range.  An error
1136          * will be returned if these conditions are not met.  The v2 interface
1137          * will only return success and a size of zero if there were no
1138          * mappings within the range.
1139          */
1140         if (iommu->v2) {
1141                 dma = vfio_find_dma(iommu, unmap->iova, 1);
1142                 if (dma && dma->iova != unmap->iova) {
1143                         ret = -EINVAL;
1144                         goto unlock;
1145                 }
1146                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
1147                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
1148                         ret = -EINVAL;
1149                         goto unlock;
1150                 }
1151         }
1152
1153         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
1154                 if (!iommu->v2 && unmap->iova > dma->iova)
1155                         break;
1156                 /*
1157                  * Task with same address space who mapped this iova range is
1158                  * allowed to unmap the iova range.
1159                  */
1160                 if (dma->task->mm != current->mm)
1161                         break;
1162
1163                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1164                         struct vfio_iommu_type1_dma_unmap nb_unmap;
1165
1166                         if (dma_last == dma) {
1167                                 BUG_ON(++retries > 10);
1168                         } else {
1169                                 dma_last = dma;
1170                                 retries = 0;
1171                         }
1172
1173                         nb_unmap.iova = dma->iova;
1174                         nb_unmap.size = dma->size;
1175
1176                         /*
1177                          * Notify anyone (mdev vendor drivers) to invalidate and
1178                          * unmap iovas within the range we're about to unmap.
1179                          * Vendor drivers MUST unpin pages in response to an
1180                          * invalidation.
1181                          */
1182                         mutex_unlock(&iommu->lock);
1183                         blocking_notifier_call_chain(&iommu->notifier,
1184                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1185                                                     &nb_unmap);
1186                         mutex_lock(&iommu->lock);
1187                         goto again;
1188                 }
1189
1190                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1191                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1192                                                  unmap->iova, pgsize);
1193                         if (ret)
1194                                 break;
1195                 }
1196
1197                 unmapped += dma->size;
1198                 vfio_remove_dma(iommu, dma);
1199         }
1200
1201 unlock:
1202         mutex_unlock(&iommu->lock);
1203
1204         /* Report how much was unmapped */
1205         unmap->size = unmapped;
1206
1207         return ret;
1208 }
1209
1210 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1211                           unsigned long pfn, long npage, int prot)
1212 {
1213         struct vfio_domain *d;
1214         int ret;
1215
1216         list_for_each_entry(d, &iommu->domain_list, next) {
1217                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1218                                 npage << PAGE_SHIFT, prot | d->prot);
1219                 if (ret)
1220                         goto unwind;
1221
1222                 cond_resched();
1223         }
1224
1225         return 0;
1226
1227 unwind:
1228         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
1229                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1230
1231         return ret;
1232 }
1233
1234 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1235                             size_t map_size)
1236 {
1237         dma_addr_t iova = dma->iova;
1238         unsigned long vaddr = dma->vaddr;
1239         size_t size = map_size;
1240         long npage;
1241         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1242         int ret = 0;
1243
1244         while (size) {
1245                 /* Pin a contiguous chunk of memory */
1246                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1247                                               size >> PAGE_SHIFT, &pfn, limit);
1248                 if (npage <= 0) {
1249                         WARN_ON(!npage);
1250                         ret = (int)npage;
1251                         break;
1252                 }
1253
1254                 /* Map it! */
1255                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1256                                      dma->prot);
1257                 if (ret) {
1258                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1259                                                 npage, true);
1260                         break;
1261                 }
1262
1263                 size -= npage << PAGE_SHIFT;
1264                 dma->size += npage << PAGE_SHIFT;
1265         }
1266
1267         dma->iommu_mapped = true;
1268
1269         if (ret)
1270                 vfio_remove_dma(iommu, dma);
1271
1272         return ret;
1273 }
1274
1275 /*
1276  * Check dma map request is within a valid iova range
1277  */
1278 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1279                                       dma_addr_t start, dma_addr_t end)
1280 {
1281         struct list_head *iova = &iommu->iova_list;
1282         struct vfio_iova *node;
1283
1284         list_for_each_entry(node, iova, list) {
1285                 if (start >= node->start && end <= node->end)
1286                         return true;
1287         }
1288
1289         /*
1290          * Check for list_empty() as well since a container with
1291          * a single mdev device will have an empty list.
1292          */
1293         return list_empty(iova);
1294 }
1295
1296 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1297                            struct vfio_iommu_type1_dma_map *map)
1298 {
1299         dma_addr_t iova = map->iova;
1300         unsigned long vaddr = map->vaddr;
1301         size_t size = map->size;
1302         int ret = 0, prot = 0;
1303         size_t pgsize;
1304         struct vfio_dma *dma;
1305
1306         /* Verify that none of our __u64 fields overflow */
1307         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1308                 return -EINVAL;
1309
1310         /* READ/WRITE from device perspective */
1311         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1312                 prot |= IOMMU_WRITE;
1313         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1314                 prot |= IOMMU_READ;
1315
1316         mutex_lock(&iommu->lock);
1317
1318         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1319
1320         WARN_ON((pgsize - 1) & PAGE_MASK);
1321
1322         if (!prot || !size || (size | iova | vaddr) & (pgsize - 1)) {
1323                 ret = -EINVAL;
1324                 goto out_unlock;
1325         }
1326
1327         /* Don't allow IOVA or virtual address wrap */
1328         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1329                 ret = -EINVAL;
1330                 goto out_unlock;
1331         }
1332
1333         if (vfio_find_dma(iommu, iova, size)) {
1334                 ret = -EEXIST;
1335                 goto out_unlock;
1336         }
1337
1338         if (!iommu->dma_avail) {
1339                 ret = -ENOSPC;
1340                 goto out_unlock;
1341         }
1342
1343         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1344                 ret = -EINVAL;
1345                 goto out_unlock;
1346         }
1347
1348         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1349         if (!dma) {
1350                 ret = -ENOMEM;
1351                 goto out_unlock;
1352         }
1353
1354         iommu->dma_avail--;
1355         dma->iova = iova;
1356         dma->vaddr = vaddr;
1357         dma->prot = prot;
1358
1359         /*
1360          * We need to be able to both add to a task's locked memory and test
1361          * against the locked memory limit and we need to be able to do both
1362          * outside of this call path as pinning can be asynchronous via the
1363          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1364          * task_struct and VM locked pages requires an mm_struct, however
1365          * holding an indefinite mm reference is not recommended, therefore we
1366          * only hold a reference to a task.  We could hold a reference to
1367          * current, however QEMU uses this call path through vCPU threads,
1368          * which can be killed resulting in a NULL mm and failure in the unmap
1369          * path when called via a different thread.  Avoid this problem by
1370          * using the group_leader as threads within the same group require
1371          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1372          * mm_struct.
1373          *
1374          * Previously we also used the task for testing CAP_IPC_LOCK at the
1375          * time of pinning and accounting, however has_capability() makes use
1376          * of real_cred, a copy-on-write field, so we can't guarantee that it
1377          * matches group_leader, or in fact that it might not change by the
1378          * time it's evaluated.  If a process were to call MAP_DMA with
1379          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1380          * possibly see different results for an iommu_mapped vfio_dma vs
1381          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1382          * time of calling MAP_DMA.
1383          */
1384         get_task_struct(current->group_leader);
1385         dma->task = current->group_leader;
1386         dma->lock_cap = capable(CAP_IPC_LOCK);
1387
1388         dma->pfn_list = RB_ROOT;
1389
1390         /* Insert zero-sized and grow as we map chunks of it */
1391         vfio_link_dma(iommu, dma);
1392
1393         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1394         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1395                 dma->size = size;
1396         else
1397                 ret = vfio_pin_map_dma(iommu, dma, size);
1398
1399         if (!ret && iommu->dirty_page_tracking) {
1400                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1401                 if (ret)
1402                         vfio_remove_dma(iommu, dma);
1403         }
1404
1405 out_unlock:
1406         mutex_unlock(&iommu->lock);
1407         return ret;
1408 }
1409
1410 static int vfio_bus_type(struct device *dev, void *data)
1411 {
1412         struct bus_type **bus = data;
1413
1414         if (*bus && *bus != dev->bus)
1415                 return -EINVAL;
1416
1417         *bus = dev->bus;
1418
1419         return 0;
1420 }
1421
1422 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1423                              struct vfio_domain *domain)
1424 {
1425         struct vfio_domain *d;
1426         struct rb_node *n;
1427         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1428         int ret;
1429
1430         /* Arbitrarily pick the first domain in the list for lookups */
1431         d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1432         n = rb_first(&iommu->dma_list);
1433
1434         for (; n; n = rb_next(n)) {
1435                 struct vfio_dma *dma;
1436                 dma_addr_t iova;
1437
1438                 dma = rb_entry(n, struct vfio_dma, node);
1439                 iova = dma->iova;
1440
1441                 while (iova < dma->iova + dma->size) {
1442                         phys_addr_t phys;
1443                         size_t size;
1444
1445                         if (dma->iommu_mapped) {
1446                                 phys_addr_t p;
1447                                 dma_addr_t i;
1448
1449                                 phys = iommu_iova_to_phys(d->domain, iova);
1450
1451                                 if (WARN_ON(!phys)) {
1452                                         iova += PAGE_SIZE;
1453                                         continue;
1454                                 }
1455
1456                                 size = PAGE_SIZE;
1457                                 p = phys + size;
1458                                 i = iova + size;
1459                                 while (i < dma->iova + dma->size &&
1460                                        p == iommu_iova_to_phys(d->domain, i)) {
1461                                         size += PAGE_SIZE;
1462                                         p += PAGE_SIZE;
1463                                         i += PAGE_SIZE;
1464                                 }
1465                         } else {
1466                                 unsigned long pfn;
1467                                 unsigned long vaddr = dma->vaddr +
1468                                                      (iova - dma->iova);
1469                                 size_t n = dma->iova + dma->size - iova;
1470                                 long npage;
1471
1472                                 npage = vfio_pin_pages_remote(dma, vaddr,
1473                                                               n >> PAGE_SHIFT,
1474                                                               &pfn, limit);
1475                                 if (npage <= 0) {
1476                                         WARN_ON(!npage);
1477                                         ret = (int)npage;
1478                                         return ret;
1479                                 }
1480
1481                                 phys = pfn << PAGE_SHIFT;
1482                                 size = npage << PAGE_SHIFT;
1483                         }
1484
1485                         ret = iommu_map(domain->domain, iova, phys,
1486                                         size, dma->prot | domain->prot);
1487                         if (ret)
1488                                 return ret;
1489
1490                         iova += size;
1491                 }
1492                 dma->iommu_mapped = true;
1493         }
1494         return 0;
1495 }
1496
1497 /*
1498  * We change our unmap behavior slightly depending on whether the IOMMU
1499  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1500  * for practically any contiguous power-of-two mapping we give it.  This means
1501  * we don't need to look for contiguous chunks ourselves to make unmapping
1502  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1503  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1504  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1505  * hugetlbfs is in use.
1506  */
1507 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1508 {
1509         struct page *pages;
1510         int ret, order = get_order(PAGE_SIZE * 2);
1511
1512         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1513         if (!pages)
1514                 return;
1515
1516         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1517                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1518         if (!ret) {
1519                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1520
1521                 if (unmapped == PAGE_SIZE)
1522                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1523                 else
1524                         domain->fgsp = true;
1525         }
1526
1527         __free_pages(pages, order);
1528 }
1529
1530 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1531                                            struct iommu_group *iommu_group)
1532 {
1533         struct vfio_group *g;
1534
1535         list_for_each_entry(g, &domain->group_list, next) {
1536                 if (g->iommu_group == iommu_group)
1537                         return g;
1538         }
1539
1540         return NULL;
1541 }
1542
1543 static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1544                                                struct iommu_group *iommu_group)
1545 {
1546         struct vfio_domain *domain;
1547         struct vfio_group *group = NULL;
1548
1549         list_for_each_entry(domain, &iommu->domain_list, next) {
1550                 group = find_iommu_group(domain, iommu_group);
1551                 if (group)
1552                         return group;
1553         }
1554
1555         if (iommu->external_domain)
1556                 group = find_iommu_group(iommu->external_domain, iommu_group);
1557
1558         return group;
1559 }
1560
1561 static void update_pinned_page_dirty_scope(struct vfio_iommu *iommu)
1562 {
1563         struct vfio_domain *domain;
1564         struct vfio_group *group;
1565
1566         list_for_each_entry(domain, &iommu->domain_list, next) {
1567                 list_for_each_entry(group, &domain->group_list, next) {
1568                         if (!group->pinned_page_dirty_scope) {
1569                                 iommu->pinned_page_dirty_scope = false;
1570                                 return;
1571                         }
1572                 }
1573         }
1574
1575         if (iommu->external_domain) {
1576                 domain = iommu->external_domain;
1577                 list_for_each_entry(group, &domain->group_list, next) {
1578                         if (!group->pinned_page_dirty_scope) {
1579                                 iommu->pinned_page_dirty_scope = false;
1580                                 return;
1581                         }
1582                 }
1583         }
1584
1585         iommu->pinned_page_dirty_scope = true;
1586 }
1587
1588 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1589                                   phys_addr_t *base)
1590 {
1591         struct iommu_resv_region *region;
1592         bool ret = false;
1593
1594         list_for_each_entry(region, group_resv_regions, list) {
1595                 /*
1596                  * The presence of any 'real' MSI regions should take
1597                  * precedence over the software-managed one if the
1598                  * IOMMU driver happens to advertise both types.
1599                  */
1600                 if (region->type == IOMMU_RESV_MSI) {
1601                         ret = false;
1602                         break;
1603                 }
1604
1605                 if (region->type == IOMMU_RESV_SW_MSI) {
1606                         *base = region->start;
1607                         ret = true;
1608                 }
1609         }
1610
1611         return ret;
1612 }
1613
1614 static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1615 {
1616         struct device *(*fn)(struct device *dev);
1617         struct device *iommu_device;
1618
1619         fn = symbol_get(mdev_get_iommu_device);
1620         if (fn) {
1621                 iommu_device = fn(dev);
1622                 symbol_put(mdev_get_iommu_device);
1623
1624                 return iommu_device;
1625         }
1626
1627         return NULL;
1628 }
1629
1630 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1631 {
1632         struct iommu_domain *domain = data;
1633         struct device *iommu_device;
1634
1635         iommu_device = vfio_mdev_get_iommu_device(dev);
1636         if (iommu_device) {
1637                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1638                         return iommu_aux_attach_device(domain, iommu_device);
1639                 else
1640                         return iommu_attach_device(domain, iommu_device);
1641         }
1642
1643         return -EINVAL;
1644 }
1645
1646 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1647 {
1648         struct iommu_domain *domain = data;
1649         struct device *iommu_device;
1650
1651         iommu_device = vfio_mdev_get_iommu_device(dev);
1652         if (iommu_device) {
1653                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1654                         iommu_aux_detach_device(domain, iommu_device);
1655                 else
1656                         iommu_detach_device(domain, iommu_device);
1657         }
1658
1659         return 0;
1660 }
1661
1662 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1663                                    struct vfio_group *group)
1664 {
1665         if (group->mdev_group)
1666                 return iommu_group_for_each_dev(group->iommu_group,
1667                                                 domain->domain,
1668                                                 vfio_mdev_attach_domain);
1669         else
1670                 return iommu_attach_group(domain->domain, group->iommu_group);
1671 }
1672
1673 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1674                                     struct vfio_group *group)
1675 {
1676         if (group->mdev_group)
1677                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1678                                          vfio_mdev_detach_domain);
1679         else
1680                 iommu_detach_group(domain->domain, group->iommu_group);
1681 }
1682
1683 static bool vfio_bus_is_mdev(struct bus_type *bus)
1684 {
1685         struct bus_type *mdev_bus;
1686         bool ret = false;
1687
1688         mdev_bus = symbol_get(mdev_bus_type);
1689         if (mdev_bus) {
1690                 ret = (bus == mdev_bus);
1691                 symbol_put(mdev_bus_type);
1692         }
1693
1694         return ret;
1695 }
1696
1697 static int vfio_mdev_iommu_device(struct device *dev, void *data)
1698 {
1699         struct device **old = data, *new;
1700
1701         new = vfio_mdev_get_iommu_device(dev);
1702         if (!new || (*old && *old != new))
1703                 return -EINVAL;
1704
1705         *old = new;
1706
1707         return 0;
1708 }
1709
1710 /*
1711  * This is a helper function to insert an address range to iova list.
1712  * The list is initially created with a single entry corresponding to
1713  * the IOMMU domain geometry to which the device group is attached.
1714  * The list aperture gets modified when a new domain is added to the
1715  * container if the new aperture doesn't conflict with the current one
1716  * or with any existing dma mappings. The list is also modified to
1717  * exclude any reserved regions associated with the device group.
1718  */
1719 static int vfio_iommu_iova_insert(struct list_head *head,
1720                                   dma_addr_t start, dma_addr_t end)
1721 {
1722         struct vfio_iova *region;
1723
1724         region = kmalloc(sizeof(*region), GFP_KERNEL);
1725         if (!region)
1726                 return -ENOMEM;
1727
1728         INIT_LIST_HEAD(&region->list);
1729         region->start = start;
1730         region->end = end;
1731
1732         list_add_tail(&region->list, head);
1733         return 0;
1734 }
1735
1736 /*
1737  * Check the new iommu aperture conflicts with existing aper or with any
1738  * existing dma mappings.
1739  */
1740 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
1741                                      dma_addr_t start, dma_addr_t end)
1742 {
1743         struct vfio_iova *first, *last;
1744         struct list_head *iova = &iommu->iova_list;
1745
1746         if (list_empty(iova))
1747                 return false;
1748
1749         /* Disjoint sets, return conflict */
1750         first = list_first_entry(iova, struct vfio_iova, list);
1751         last = list_last_entry(iova, struct vfio_iova, list);
1752         if (start > last->end || end < first->start)
1753                 return true;
1754
1755         /* Check for any existing dma mappings below the new start */
1756         if (start > first->start) {
1757                 if (vfio_find_dma(iommu, first->start, start - first->start))
1758                         return true;
1759         }
1760
1761         /* Check for any existing dma mappings beyond the new end */
1762         if (end < last->end) {
1763                 if (vfio_find_dma(iommu, end + 1, last->end - end))
1764                         return true;
1765         }
1766
1767         return false;
1768 }
1769
1770 /*
1771  * Resize iommu iova aperture window. This is called only if the new
1772  * aperture has no conflict with existing aperture and dma mappings.
1773  */
1774 static int vfio_iommu_aper_resize(struct list_head *iova,
1775                                   dma_addr_t start, dma_addr_t end)
1776 {
1777         struct vfio_iova *node, *next;
1778
1779         if (list_empty(iova))
1780                 return vfio_iommu_iova_insert(iova, start, end);
1781
1782         /* Adjust iova list start */
1783         list_for_each_entry_safe(node, next, iova, list) {
1784                 if (start < node->start)
1785                         break;
1786                 if (start >= node->start && start < node->end) {
1787                         node->start = start;
1788                         break;
1789                 }
1790                 /* Delete nodes before new start */
1791                 list_del(&node->list);
1792                 kfree(node);
1793         }
1794
1795         /* Adjust iova list end */
1796         list_for_each_entry_safe(node, next, iova, list) {
1797                 if (end > node->end)
1798                         continue;
1799                 if (end > node->start && end <= node->end) {
1800                         node->end = end;
1801                         continue;
1802                 }
1803                 /* Delete nodes after new end */
1804                 list_del(&node->list);
1805                 kfree(node);
1806         }
1807
1808         return 0;
1809 }
1810
1811 /*
1812  * Check reserved region conflicts with existing dma mappings
1813  */
1814 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
1815                                      struct list_head *resv_regions)
1816 {
1817         struct iommu_resv_region *region;
1818
1819         /* Check for conflict with existing dma mappings */
1820         list_for_each_entry(region, resv_regions, list) {
1821                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
1822                         continue;
1823
1824                 if (vfio_find_dma(iommu, region->start, region->length))
1825                         return true;
1826         }
1827
1828         return false;
1829 }
1830
1831 /*
1832  * Check iova region overlap with  reserved regions and
1833  * exclude them from the iommu iova range
1834  */
1835 static int vfio_iommu_resv_exclude(struct list_head *iova,
1836                                    struct list_head *resv_regions)
1837 {
1838         struct iommu_resv_region *resv;
1839         struct vfio_iova *n, *next;
1840
1841         list_for_each_entry(resv, resv_regions, list) {
1842                 phys_addr_t start, end;
1843
1844                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
1845                         continue;
1846
1847                 start = resv->start;
1848                 end = resv->start + resv->length - 1;
1849
1850                 list_for_each_entry_safe(n, next, iova, list) {
1851                         int ret = 0;
1852
1853                         /* No overlap */
1854                         if (start > n->end || end < n->start)
1855                                 continue;
1856                         /*
1857                          * Insert a new node if current node overlaps with the
1858                          * reserve region to exlude that from valid iova range.
1859                          * Note that, new node is inserted before the current
1860                          * node and finally the current node is deleted keeping
1861                          * the list updated and sorted.
1862                          */
1863                         if (start > n->start)
1864                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
1865                                                              start - 1);
1866                         if (!ret && end < n->end)
1867                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
1868                                                              n->end);
1869                         if (ret)
1870                                 return ret;
1871
1872                         list_del(&n->list);
1873                         kfree(n);
1874                 }
1875         }
1876
1877         if (list_empty(iova))
1878                 return -EINVAL;
1879
1880         return 0;
1881 }
1882
1883 static void vfio_iommu_resv_free(struct list_head *resv_regions)
1884 {
1885         struct iommu_resv_region *n, *next;
1886
1887         list_for_each_entry_safe(n, next, resv_regions, list) {
1888                 list_del(&n->list);
1889                 kfree(n);
1890         }
1891 }
1892
1893 static void vfio_iommu_iova_free(struct list_head *iova)
1894 {
1895         struct vfio_iova *n, *next;
1896
1897         list_for_each_entry_safe(n, next, iova, list) {
1898                 list_del(&n->list);
1899                 kfree(n);
1900         }
1901 }
1902
1903 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
1904                                     struct list_head *iova_copy)
1905 {
1906         struct list_head *iova = &iommu->iova_list;
1907         struct vfio_iova *n;
1908         int ret;
1909
1910         list_for_each_entry(n, iova, list) {
1911                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
1912                 if (ret)
1913                         goto out_free;
1914         }
1915
1916         return 0;
1917
1918 out_free:
1919         vfio_iommu_iova_free(iova_copy);
1920         return ret;
1921 }
1922
1923 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
1924                                         struct list_head *iova_copy)
1925 {
1926         struct list_head *iova = &iommu->iova_list;
1927
1928         vfio_iommu_iova_free(iova);
1929
1930         list_splice_tail(iova_copy, iova);
1931 }
1932 static int vfio_iommu_type1_attach_group(void *iommu_data,
1933                                          struct iommu_group *iommu_group)
1934 {
1935         struct vfio_iommu *iommu = iommu_data;
1936         struct vfio_group *group;
1937         struct vfio_domain *domain, *d;
1938         struct bus_type *bus = NULL;
1939         int ret;
1940         bool resv_msi, msi_remap;
1941         phys_addr_t resv_msi_base = 0;
1942         struct iommu_domain_geometry geo;
1943         LIST_HEAD(iova_copy);
1944         LIST_HEAD(group_resv_regions);
1945
1946         mutex_lock(&iommu->lock);
1947
1948         list_for_each_entry(d, &iommu->domain_list, next) {
1949                 if (find_iommu_group(d, iommu_group)) {
1950                         mutex_unlock(&iommu->lock);
1951                         return -EINVAL;
1952                 }
1953         }
1954
1955         if (iommu->external_domain) {
1956                 if (find_iommu_group(iommu->external_domain, iommu_group)) {
1957                         mutex_unlock(&iommu->lock);
1958                         return -EINVAL;
1959                 }
1960         }
1961
1962         group = kzalloc(sizeof(*group), GFP_KERNEL);
1963         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1964         if (!group || !domain) {
1965                 ret = -ENOMEM;
1966                 goto out_free;
1967         }
1968
1969         group->iommu_group = iommu_group;
1970
1971         /* Determine bus_type in order to allocate a domain */
1972         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1973         if (ret)
1974                 goto out_free;
1975
1976         if (vfio_bus_is_mdev(bus)) {
1977                 struct device *iommu_device = NULL;
1978
1979                 group->mdev_group = true;
1980
1981                 /* Determine the isolation type */
1982                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
1983                                                vfio_mdev_iommu_device);
1984                 if (ret || !iommu_device) {
1985                         if (!iommu->external_domain) {
1986                                 INIT_LIST_HEAD(&domain->group_list);
1987                                 iommu->external_domain = domain;
1988                                 vfio_update_pgsize_bitmap(iommu);
1989                         } else {
1990                                 kfree(domain);
1991                         }
1992
1993                         list_add(&group->next,
1994                                  &iommu->external_domain->group_list);
1995                         /*
1996                          * Non-iommu backed group cannot dirty memory directly,
1997                          * it can only use interfaces that provide dirty
1998                          * tracking.
1999                          * The iommu scope can only be promoted with the
2000                          * addition of a dirty tracking group.
2001                          */
2002                         group->pinned_page_dirty_scope = true;
2003                         if (!iommu->pinned_page_dirty_scope)
2004                                 update_pinned_page_dirty_scope(iommu);
2005                         mutex_unlock(&iommu->lock);
2006
2007                         return 0;
2008                 }
2009
2010                 bus = iommu_device->bus;
2011         }
2012
2013         domain->domain = iommu_domain_alloc(bus);
2014         if (!domain->domain) {
2015                 ret = -EIO;
2016                 goto out_free;
2017         }
2018
2019         if (iommu->nesting) {
2020                 int attr = 1;
2021
2022                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
2023                                             &attr);
2024                 if (ret)
2025                         goto out_domain;
2026         }
2027
2028         ret = vfio_iommu_attach_group(domain, group);
2029         if (ret)
2030                 goto out_domain;
2031
2032         /* Get aperture info */
2033         iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY, &geo);
2034
2035         if (vfio_iommu_aper_conflict(iommu, geo.aperture_start,
2036                                      geo.aperture_end)) {
2037                 ret = -EINVAL;
2038                 goto out_detach;
2039         }
2040
2041         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2042         if (ret)
2043                 goto out_detach;
2044
2045         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2046                 ret = -EINVAL;
2047                 goto out_detach;
2048         }
2049
2050         /*
2051          * We don't want to work on the original iova list as the list
2052          * gets modified and in case of failure we have to retain the
2053          * original list. Get a copy here.
2054          */
2055         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2056         if (ret)
2057                 goto out_detach;
2058
2059         ret = vfio_iommu_aper_resize(&iova_copy, geo.aperture_start,
2060                                      geo.aperture_end);
2061         if (ret)
2062                 goto out_detach;
2063
2064         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2065         if (ret)
2066                 goto out_detach;
2067
2068         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2069
2070         INIT_LIST_HEAD(&domain->group_list);
2071         list_add(&group->next, &domain->group_list);
2072
2073         msi_remap = irq_domain_check_msi_remap() ||
2074                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2075
2076         if (!allow_unsafe_interrupts && !msi_remap) {
2077                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2078                        __func__);
2079                 ret = -EPERM;
2080                 goto out_detach;
2081         }
2082
2083         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2084                 domain->prot |= IOMMU_CACHE;
2085
2086         /*
2087          * Try to match an existing compatible domain.  We don't want to
2088          * preclude an IOMMU driver supporting multiple bus_types and being
2089          * able to include different bus_types in the same IOMMU domain, so
2090          * we test whether the domains use the same iommu_ops rather than
2091          * testing if they're on the same bus_type.
2092          */
2093         list_for_each_entry(d, &iommu->domain_list, next) {
2094                 if (d->domain->ops == domain->domain->ops &&
2095                     d->prot == domain->prot) {
2096                         vfio_iommu_detach_group(domain, group);
2097                         if (!vfio_iommu_attach_group(d, group)) {
2098                                 list_add(&group->next, &d->group_list);
2099                                 iommu_domain_free(domain->domain);
2100                                 kfree(domain);
2101                                 goto done;
2102                         }
2103
2104                         ret = vfio_iommu_attach_group(domain, group);
2105                         if (ret)
2106                                 goto out_domain;
2107                 }
2108         }
2109
2110         vfio_test_domain_fgsp(domain);
2111
2112         /* replay mappings on new domains */
2113         ret = vfio_iommu_replay(iommu, domain);
2114         if (ret)
2115                 goto out_detach;
2116
2117         if (resv_msi) {
2118                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2119                 if (ret && ret != -ENODEV)
2120                         goto out_detach;
2121         }
2122
2123         list_add(&domain->next, &iommu->domain_list);
2124         vfio_update_pgsize_bitmap(iommu);
2125 done:
2126         /* Delete the old one and insert new iova list */
2127         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2128
2129         /*
2130          * An iommu backed group can dirty memory directly and therefore
2131          * demotes the iommu scope until it declares itself dirty tracking
2132          * capable via the page pinning interface.
2133          */
2134         iommu->pinned_page_dirty_scope = false;
2135         mutex_unlock(&iommu->lock);
2136         vfio_iommu_resv_free(&group_resv_regions);
2137
2138         return 0;
2139
2140 out_detach:
2141         vfio_iommu_detach_group(domain, group);
2142 out_domain:
2143         iommu_domain_free(domain->domain);
2144         vfio_iommu_iova_free(&iova_copy);
2145         vfio_iommu_resv_free(&group_resv_regions);
2146 out_free:
2147         kfree(domain);
2148         kfree(group);
2149         mutex_unlock(&iommu->lock);
2150         return ret;
2151 }
2152
2153 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2154 {
2155         struct rb_node *node;
2156
2157         while ((node = rb_first(&iommu->dma_list)))
2158                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2159 }
2160
2161 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2162 {
2163         struct rb_node *n, *p;
2164
2165         n = rb_first(&iommu->dma_list);
2166         for (; n; n = rb_next(n)) {
2167                 struct vfio_dma *dma;
2168                 long locked = 0, unlocked = 0;
2169
2170                 dma = rb_entry(n, struct vfio_dma, node);
2171                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2172                 p = rb_first(&dma->pfn_list);
2173                 for (; p; p = rb_next(p)) {
2174                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2175                                                          node);
2176
2177                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2178                                 locked++;
2179                 }
2180                 vfio_lock_acct(dma, locked - unlocked, true);
2181         }
2182 }
2183
2184 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
2185 {
2186         struct rb_node *n;
2187
2188         n = rb_first(&iommu->dma_list);
2189         for (; n; n = rb_next(n)) {
2190                 struct vfio_dma *dma;
2191
2192                 dma = rb_entry(n, struct vfio_dma, node);
2193
2194                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
2195                         break;
2196         }
2197         /* mdev vendor driver must unregister notifier */
2198         WARN_ON(iommu->notifier.head);
2199 }
2200
2201 /*
2202  * Called when a domain is removed in detach. It is possible that
2203  * the removed domain decided the iova aperture window. Modify the
2204  * iova aperture with the smallest window among existing domains.
2205  */
2206 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2207                                    struct list_head *iova_copy)
2208 {
2209         struct vfio_domain *domain;
2210         struct iommu_domain_geometry geo;
2211         struct vfio_iova *node;
2212         dma_addr_t start = 0;
2213         dma_addr_t end = (dma_addr_t)~0;
2214
2215         if (list_empty(iova_copy))
2216                 return;
2217
2218         list_for_each_entry(domain, &iommu->domain_list, next) {
2219                 iommu_domain_get_attr(domain->domain, DOMAIN_ATTR_GEOMETRY,
2220                                       &geo);
2221                 if (geo.aperture_start > start)
2222                         start = geo.aperture_start;
2223                 if (geo.aperture_end < end)
2224                         end = geo.aperture_end;
2225         }
2226
2227         /* Modify aperture limits. The new aper is either same or bigger */
2228         node = list_first_entry(iova_copy, struct vfio_iova, list);
2229         node->start = start;
2230         node = list_last_entry(iova_copy, struct vfio_iova, list);
2231         node->end = end;
2232 }
2233
2234 /*
2235  * Called when a group is detached. The reserved regions for that
2236  * group can be part of valid iova now. But since reserved regions
2237  * may be duplicated among groups, populate the iova valid regions
2238  * list again.
2239  */
2240 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2241                                    struct list_head *iova_copy)
2242 {
2243         struct vfio_domain *d;
2244         struct vfio_group *g;
2245         struct vfio_iova *node;
2246         dma_addr_t start, end;
2247         LIST_HEAD(resv_regions);
2248         int ret;
2249
2250         if (list_empty(iova_copy))
2251                 return -EINVAL;
2252
2253         list_for_each_entry(d, &iommu->domain_list, next) {
2254                 list_for_each_entry(g, &d->group_list, next) {
2255                         ret = iommu_get_group_resv_regions(g->iommu_group,
2256                                                            &resv_regions);
2257                         if (ret)
2258                                 goto done;
2259                 }
2260         }
2261
2262         node = list_first_entry(iova_copy, struct vfio_iova, list);
2263         start = node->start;
2264         node = list_last_entry(iova_copy, struct vfio_iova, list);
2265         end = node->end;
2266
2267         /* purge the iova list and create new one */
2268         vfio_iommu_iova_free(iova_copy);
2269
2270         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2271         if (ret)
2272                 goto done;
2273
2274         /* Exclude current reserved regions from iova ranges */
2275         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2276 done:
2277         vfio_iommu_resv_free(&resv_regions);
2278         return ret;
2279 }
2280
2281 static void vfio_iommu_type1_detach_group(void *iommu_data,
2282                                           struct iommu_group *iommu_group)
2283 {
2284         struct vfio_iommu *iommu = iommu_data;
2285         struct vfio_domain *domain;
2286         struct vfio_group *group;
2287         bool update_dirty_scope = false;
2288         LIST_HEAD(iova_copy);
2289
2290         mutex_lock(&iommu->lock);
2291
2292         if (iommu->external_domain) {
2293                 group = find_iommu_group(iommu->external_domain, iommu_group);
2294                 if (group) {
2295                         update_dirty_scope = !group->pinned_page_dirty_scope;
2296                         list_del(&group->next);
2297                         kfree(group);
2298
2299                         if (list_empty(&iommu->external_domain->group_list)) {
2300                                 vfio_sanity_check_pfn_list(iommu);
2301
2302                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
2303                                         vfio_iommu_unmap_unpin_all(iommu);
2304
2305                                 kfree(iommu->external_domain);
2306                                 iommu->external_domain = NULL;
2307                         }
2308                         goto detach_group_done;
2309                 }
2310         }
2311
2312         /*
2313          * Get a copy of iova list. This will be used to update
2314          * and to replace the current one later. Please note that
2315          * we will leave the original list as it is if update fails.
2316          */
2317         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2318
2319         list_for_each_entry(domain, &iommu->domain_list, next) {
2320                 group = find_iommu_group(domain, iommu_group);
2321                 if (!group)
2322                         continue;
2323
2324                 vfio_iommu_detach_group(domain, group);
2325                 update_dirty_scope = !group->pinned_page_dirty_scope;
2326                 list_del(&group->next);
2327                 kfree(group);
2328                 /*
2329                  * Group ownership provides privilege, if the group list is
2330                  * empty, the domain goes away. If it's the last domain with
2331                  * iommu and external domain doesn't exist, then all the
2332                  * mappings go away too. If it's the last domain with iommu and
2333                  * external domain exist, update accounting
2334                  */
2335                 if (list_empty(&domain->group_list)) {
2336                         if (list_is_singular(&iommu->domain_list)) {
2337                                 if (!iommu->external_domain)
2338                                         vfio_iommu_unmap_unpin_all(iommu);
2339                                 else
2340                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2341                         }
2342                         iommu_domain_free(domain->domain);
2343                         list_del(&domain->next);
2344                         kfree(domain);
2345                         vfio_iommu_aper_expand(iommu, &iova_copy);
2346                         vfio_update_pgsize_bitmap(iommu);
2347                 }
2348                 break;
2349         }
2350
2351         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2352                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2353         else
2354                 vfio_iommu_iova_free(&iova_copy);
2355
2356 detach_group_done:
2357         /*
2358          * Removal of a group without dirty tracking may allow the iommu scope
2359          * to be promoted.
2360          */
2361         if (update_dirty_scope)
2362                 update_pinned_page_dirty_scope(iommu);
2363         mutex_unlock(&iommu->lock);
2364 }
2365
2366 static void *vfio_iommu_type1_open(unsigned long arg)
2367 {
2368         struct vfio_iommu *iommu;
2369
2370         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2371         if (!iommu)
2372                 return ERR_PTR(-ENOMEM);
2373
2374         switch (arg) {
2375         case VFIO_TYPE1_IOMMU:
2376                 break;
2377         case VFIO_TYPE1_NESTING_IOMMU:
2378                 iommu->nesting = true;
2379                 /* fall through */
2380         case VFIO_TYPE1v2_IOMMU:
2381                 iommu->v2 = true;
2382                 break;
2383         default:
2384                 kfree(iommu);
2385                 return ERR_PTR(-EINVAL);
2386         }
2387
2388         INIT_LIST_HEAD(&iommu->domain_list);
2389         INIT_LIST_HEAD(&iommu->iova_list);
2390         iommu->dma_list = RB_ROOT;
2391         iommu->dma_avail = dma_entry_limit;
2392         mutex_init(&iommu->lock);
2393         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2394
2395         return iommu;
2396 }
2397
2398 static void vfio_release_domain(struct vfio_domain *domain, bool external)
2399 {
2400         struct vfio_group *group, *group_tmp;
2401
2402         list_for_each_entry_safe(group, group_tmp,
2403                                  &domain->group_list, next) {
2404                 if (!external)
2405                         vfio_iommu_detach_group(domain, group);
2406                 list_del(&group->next);
2407                 kfree(group);
2408         }
2409
2410         if (!external)
2411                 iommu_domain_free(domain->domain);
2412 }
2413
2414 static void vfio_iommu_type1_release(void *iommu_data)
2415 {
2416         struct vfio_iommu *iommu = iommu_data;
2417         struct vfio_domain *domain, *domain_tmp;
2418
2419         if (iommu->external_domain) {
2420                 vfio_release_domain(iommu->external_domain, true);
2421                 vfio_sanity_check_pfn_list(iommu);
2422                 kfree(iommu->external_domain);
2423         }
2424
2425         vfio_iommu_unmap_unpin_all(iommu);
2426
2427         list_for_each_entry_safe(domain, domain_tmp,
2428                                  &iommu->domain_list, next) {
2429                 vfio_release_domain(domain, false);
2430                 list_del(&domain->next);
2431                 kfree(domain);
2432         }
2433
2434         vfio_iommu_iova_free(&iommu->iova_list);
2435
2436         kfree(iommu);
2437 }
2438
2439 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2440 {
2441         struct vfio_domain *domain;
2442         int ret = 1;
2443
2444         mutex_lock(&iommu->lock);
2445         list_for_each_entry(domain, &iommu->domain_list, next) {
2446                 if (!(domain->prot & IOMMU_CACHE)) {
2447                         ret = 0;
2448                         break;
2449                 }
2450         }
2451         mutex_unlock(&iommu->lock);
2452
2453         return ret;
2454 }
2455
2456 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2457                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2458                  size_t size)
2459 {
2460         struct vfio_info_cap_header *header;
2461         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2462
2463         header = vfio_info_cap_add(caps, size,
2464                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2465         if (IS_ERR(header))
2466                 return PTR_ERR(header);
2467
2468         iova_cap = container_of(header,
2469                                 struct vfio_iommu_type1_info_cap_iova_range,
2470                                 header);
2471         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2472         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2473                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2474         return 0;
2475 }
2476
2477 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2478                                       struct vfio_info_cap *caps)
2479 {
2480         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2481         struct vfio_iova *iova;
2482         size_t size;
2483         int iovas = 0, i = 0, ret;
2484
2485         list_for_each_entry(iova, &iommu->iova_list, list)
2486                 iovas++;
2487
2488         if (!iovas) {
2489                 /*
2490                  * Return 0 as a container with a single mdev device
2491                  * will have an empty list
2492                  */
2493                 return 0;
2494         }
2495
2496         size = sizeof(*cap_iovas) + (iovas * sizeof(*cap_iovas->iova_ranges));
2497
2498         cap_iovas = kzalloc(size, GFP_KERNEL);
2499         if (!cap_iovas)
2500                 return -ENOMEM;
2501
2502         cap_iovas->nr_iovas = iovas;
2503
2504         list_for_each_entry(iova, &iommu->iova_list, list) {
2505                 cap_iovas->iova_ranges[i].start = iova->start;
2506                 cap_iovas->iova_ranges[i].end = iova->end;
2507                 i++;
2508         }
2509
2510         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2511
2512         kfree(cap_iovas);
2513         return ret;
2514 }
2515
2516 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2517                                            struct vfio_info_cap *caps)
2518 {
2519         struct vfio_iommu_type1_info_cap_migration cap_mig;
2520
2521         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2522         cap_mig.header.version = 1;
2523
2524         cap_mig.flags = 0;
2525         /* support minimum pgsize */
2526         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2527         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2528
2529         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2530 }
2531
2532 static long vfio_iommu_type1_ioctl(void *iommu_data,
2533                                    unsigned int cmd, unsigned long arg)
2534 {
2535         struct vfio_iommu *iommu = iommu_data;
2536         unsigned long minsz;
2537
2538         if (cmd == VFIO_CHECK_EXTENSION) {
2539                 switch (arg) {
2540                 case VFIO_TYPE1_IOMMU:
2541                 case VFIO_TYPE1v2_IOMMU:
2542                 case VFIO_TYPE1_NESTING_IOMMU:
2543                         return 1;
2544                 case VFIO_DMA_CC_IOMMU:
2545                         if (!iommu)
2546                                 return 0;
2547                         return vfio_domains_have_iommu_cache(iommu);
2548                 default:
2549                         return 0;
2550                 }
2551         } else if (cmd == VFIO_IOMMU_GET_INFO) {
2552                 struct vfio_iommu_type1_info info;
2553                 struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2554                 unsigned long capsz;
2555                 int ret;
2556
2557                 minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2558
2559                 /* For backward compatibility, cannot require this */
2560                 capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2561
2562                 if (copy_from_user(&info, (void __user *)arg, minsz))
2563                         return -EFAULT;
2564
2565                 if (info.argsz < minsz)
2566                         return -EINVAL;
2567
2568                 if (info.argsz >= capsz) {
2569                         minsz = capsz;
2570                         info.cap_offset = 0; /* output, no-recopy necessary */
2571                 }
2572
2573                 mutex_lock(&iommu->lock);
2574                 info.flags = VFIO_IOMMU_INFO_PGSIZES;
2575
2576                 info.iova_pgsizes = iommu->pgsize_bitmap;
2577
2578                 ret = vfio_iommu_migration_build_caps(iommu, &caps);
2579
2580                 if (!ret)
2581                         ret = vfio_iommu_iova_build_caps(iommu, &caps);
2582
2583                 mutex_unlock(&iommu->lock);
2584
2585                 if (ret)
2586                         return ret;
2587
2588                 if (caps.size) {
2589                         info.flags |= VFIO_IOMMU_INFO_CAPS;
2590
2591                         if (info.argsz < sizeof(info) + caps.size) {
2592                                 info.argsz = sizeof(info) + caps.size;
2593                         } else {
2594                                 vfio_info_cap_shift(&caps, sizeof(info));
2595                                 if (copy_to_user((void __user *)arg +
2596                                                 sizeof(info), caps.buf,
2597                                                 caps.size)) {
2598                                         kfree(caps.buf);
2599                                         return -EFAULT;
2600                                 }
2601                                 info.cap_offset = sizeof(info);
2602                         }
2603
2604                         kfree(caps.buf);
2605                 }
2606
2607                 return copy_to_user((void __user *)arg, &info, minsz) ?
2608                         -EFAULT : 0;
2609
2610         } else if (cmd == VFIO_IOMMU_MAP_DMA) {
2611                 struct vfio_iommu_type1_dma_map map;
2612                 uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
2613                                 VFIO_DMA_MAP_FLAG_WRITE;
2614
2615                 minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2616
2617                 if (copy_from_user(&map, (void __user *)arg, minsz))
2618                         return -EFAULT;
2619
2620                 if (map.argsz < minsz || map.flags & ~mask)
2621                         return -EINVAL;
2622
2623                 return vfio_dma_do_map(iommu, &map);
2624
2625         } else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
2626                 struct vfio_iommu_type1_dma_unmap unmap;
2627                 struct vfio_bitmap bitmap = { 0 };
2628                 int ret;
2629
2630                 minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2631
2632                 if (copy_from_user(&unmap, (void __user *)arg, minsz))
2633                         return -EFAULT;
2634
2635                 if (unmap.argsz < minsz ||
2636                     unmap.flags & ~VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP)
2637                         return -EINVAL;
2638
2639                 if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2640                         unsigned long pgshift;
2641
2642                         if (unmap.argsz < (minsz + sizeof(bitmap)))
2643                                 return -EINVAL;
2644
2645                         if (copy_from_user(&bitmap,
2646                                            (void __user *)(arg + minsz),
2647                                            sizeof(bitmap)))
2648                                 return -EFAULT;
2649
2650                         if (!access_ok((void __user *)bitmap.data, bitmap.size))
2651                                 return -EINVAL;
2652
2653                         pgshift = __ffs(bitmap.pgsize);
2654                         ret = verify_bitmap_size(unmap.size >> pgshift,
2655                                                  bitmap.size);
2656                         if (ret)
2657                                 return ret;
2658                 }
2659
2660                 ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2661                 if (ret)
2662                         return ret;
2663
2664                 return copy_to_user((void __user *)arg, &unmap, minsz) ?
2665                         -EFAULT : 0;
2666         } else if (cmd == VFIO_IOMMU_DIRTY_PAGES) {
2667                 struct vfio_iommu_type1_dirty_bitmap dirty;
2668                 uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2669                                 VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2670                                 VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2671                 int ret = 0;
2672
2673                 if (!iommu->v2)
2674                         return -EACCES;
2675
2676                 minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap,
2677                                     flags);
2678
2679                 if (copy_from_user(&dirty, (void __user *)arg, minsz))
2680                         return -EFAULT;
2681
2682                 if (dirty.argsz < minsz || dirty.flags & ~mask)
2683                         return -EINVAL;
2684
2685                 /* only one flag should be set at a time */
2686                 if (__ffs(dirty.flags) != __fls(dirty.flags))
2687                         return -EINVAL;
2688
2689                 if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
2690                         size_t pgsize;
2691
2692                         mutex_lock(&iommu->lock);
2693                         pgsize = 1 << __ffs(iommu->pgsize_bitmap);
2694                         if (!iommu->dirty_page_tracking) {
2695                                 ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
2696                                 if (!ret)
2697                                         iommu->dirty_page_tracking = true;
2698                         }
2699                         mutex_unlock(&iommu->lock);
2700                         return ret;
2701                 } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
2702                         mutex_lock(&iommu->lock);
2703                         if (iommu->dirty_page_tracking) {
2704                                 iommu->dirty_page_tracking = false;
2705                                 vfio_dma_bitmap_free_all(iommu);
2706                         }
2707                         mutex_unlock(&iommu->lock);
2708                         return 0;
2709                 } else if (dirty.flags &
2710                                  VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
2711                         struct vfio_iommu_type1_dirty_bitmap_get range;
2712                         unsigned long pgshift;
2713                         size_t data_size = dirty.argsz - minsz;
2714                         size_t iommu_pgsize;
2715
2716                         if (!data_size || data_size < sizeof(range))
2717                                 return -EINVAL;
2718
2719                         if (copy_from_user(&range, (void __user *)(arg + minsz),
2720                                            sizeof(range)))
2721                                 return -EFAULT;
2722
2723                         if (range.iova + range.size < range.iova)
2724                                 return -EINVAL;
2725                         if (!access_ok((void __user *)range.bitmap.data,
2726                                        range.bitmap.size))
2727                                 return -EINVAL;
2728
2729                         pgshift = __ffs(range.bitmap.pgsize);
2730                         ret = verify_bitmap_size(range.size >> pgshift,
2731                                                  range.bitmap.size);
2732                         if (ret)
2733                                 return ret;
2734
2735                         mutex_lock(&iommu->lock);
2736
2737                         iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2738
2739                         /* allow only smallest supported pgsize */
2740                         if (range.bitmap.pgsize != iommu_pgsize) {
2741                                 ret = -EINVAL;
2742                                 goto out_unlock;
2743                         }
2744                         if (range.iova & (iommu_pgsize - 1)) {
2745                                 ret = -EINVAL;
2746                                 goto out_unlock;
2747                         }
2748                         if (!range.size || range.size & (iommu_pgsize - 1)) {
2749                                 ret = -EINVAL;
2750                                 goto out_unlock;
2751                         }
2752
2753                         if (iommu->dirty_page_tracking)
2754                                 ret = vfio_iova_dirty_bitmap(range.bitmap.data,
2755                                                 iommu, range.iova, range.size,
2756                                                 range.bitmap.pgsize);
2757                         else
2758                                 ret = -EINVAL;
2759 out_unlock:
2760                         mutex_unlock(&iommu->lock);
2761
2762                         return ret;
2763                 }
2764         }
2765
2766         return -ENOTTY;
2767 }
2768
2769 static int vfio_iommu_type1_register_notifier(void *iommu_data,
2770                                               unsigned long *events,
2771                                               struct notifier_block *nb)
2772 {
2773         struct vfio_iommu *iommu = iommu_data;
2774
2775         /* clear known events */
2776         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
2777
2778         /* refuse to register if still events remaining */
2779         if (*events)
2780                 return -EINVAL;
2781
2782         return blocking_notifier_chain_register(&iommu->notifier, nb);
2783 }
2784
2785 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
2786                                                 struct notifier_block *nb)
2787 {
2788         struct vfio_iommu *iommu = iommu_data;
2789
2790         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
2791 }
2792
2793 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
2794                                          dma_addr_t user_iova, void *data,
2795                                          size_t count, bool write,
2796                                          size_t *copied)
2797 {
2798         struct mm_struct *mm;
2799         unsigned long vaddr;
2800         struct vfio_dma *dma;
2801         bool kthread = current->mm == NULL;
2802         size_t offset;
2803
2804         *copied = 0;
2805
2806         dma = vfio_find_dma(iommu, user_iova, 1);
2807         if (!dma)
2808                 return -EINVAL;
2809
2810         if ((write && !(dma->prot & IOMMU_WRITE)) ||
2811                         !(dma->prot & IOMMU_READ))
2812                 return -EPERM;
2813
2814         mm = get_task_mm(dma->task);
2815
2816         if (!mm)
2817                 return -EPERM;
2818
2819         if (kthread)
2820                 use_mm(mm);
2821         else if (current->mm != mm)
2822                 goto out;
2823
2824         offset = user_iova - dma->iova;
2825
2826         if (count > dma->size - offset)
2827                 count = dma->size - offset;
2828
2829         vaddr = dma->vaddr + offset;
2830
2831         if (write) {
2832                 *copied = copy_to_user((void __user *)vaddr, data,
2833                                          count) ? 0 : count;
2834                 if (*copied && iommu->dirty_page_tracking) {
2835                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
2836                         /*
2837                          * Bitmap populated with the smallest supported page
2838                          * size
2839                          */
2840                         bitmap_set(dma->bitmap, offset >> pgshift,
2841                                    *copied >> pgshift);
2842                 }
2843         } else
2844                 *copied = copy_from_user(data, (void __user *)vaddr,
2845                                            count) ? 0 : count;
2846         if (kthread)
2847                 unuse_mm(mm);
2848 out:
2849         mmput(mm);
2850         return *copied ? 0 : -EFAULT;
2851 }
2852
2853 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
2854                                    void *data, size_t count, bool write)
2855 {
2856         struct vfio_iommu *iommu = iommu_data;
2857         int ret = 0;
2858         size_t done;
2859
2860         mutex_lock(&iommu->lock);
2861         while (count > 0) {
2862                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
2863                                                     count, write, &done);
2864                 if (ret)
2865                         break;
2866
2867                 count -= done;
2868                 data += done;
2869                 user_iova += done;
2870         }
2871
2872         mutex_unlock(&iommu->lock);
2873         return ret;
2874 }
2875
2876 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
2877         .name                   = "vfio-iommu-type1",
2878         .owner                  = THIS_MODULE,
2879         .open                   = vfio_iommu_type1_open,
2880         .release                = vfio_iommu_type1_release,
2881         .ioctl                  = vfio_iommu_type1_ioctl,
2882         .attach_group           = vfio_iommu_type1_attach_group,
2883         .detach_group           = vfio_iommu_type1_detach_group,
2884         .pin_pages              = vfio_iommu_type1_pin_pages,
2885         .unpin_pages            = vfio_iommu_type1_unpin_pages,
2886         .register_notifier      = vfio_iommu_type1_register_notifier,
2887         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
2888         .dma_rw                 = vfio_iommu_type1_dma_rw,
2889 };
2890
2891 static int __init vfio_iommu_type1_init(void)
2892 {
2893         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
2894 }
2895
2896 static void __exit vfio_iommu_type1_cleanup(void)
2897 {
2898         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
2899 }
2900
2901 module_init(vfio_iommu_type1_init);
2902 module_exit(vfio_iommu_type1_cleanup);
2903
2904 MODULE_VERSION(DRIVER_VERSION);
2905 MODULE_LICENSE("GPL v2");
2906 MODULE_AUTHOR(DRIVER_AUTHOR);
2907 MODULE_DESCRIPTION(DRIVER_DESC);