// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
 *
 * Kernel side components to support tools/testing/selftests/iommu
 */
#include <linux/slab.h>
#include <linux/iommu.h>
#include <linux/xarray.h>
#include <linux/file.h>
#include <linux/anon_inodes.h>
#include <linux/fault-inject.h>
#include <linux/platform_device.h>
#include <uapi/linux/iommufd.h>

#include "../iommu-priv.h"
#include "io_pagetable.h"
#include "iommufd_private.h"
#include "iommufd_test.h"

static DECLARE_FAULT_ATTR(fail_iommufd);
static struct dentry *dbgfs_root;
static struct platform_device *selftest_iommu_dev;

size_t iommufd_test_memory_limit = 65536;

enum {
        MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,

        /*
         * Like a real page table, alignment requires the low bits of the
         * address to be zero. xarray also requires the high bit to be zero,
         * so we store the pfns shifted. The upper bits are used for metadata.
         */
        MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,

        _MOCK_PFN_START = MOCK_PFN_MASK + 1,
        MOCK_PFN_START_IOVA = _MOCK_PFN_START,
        MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
};

/*
 * Syzkaller has trouble randomizing the correct iova to use since it is linked
 * to the map ioctl's output, which it has no idea about. So, simplify things.
 * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
 * value. This has a much smaller randomization space and syzkaller can hit it.
 */
static unsigned long iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
                                                u64 *iova)
{
        struct syz_layout {
                __u32 nth_area;
                __u32 offset;
        };
        struct syz_layout *syz = (void *)iova;
        unsigned int nth = syz->nth_area;
        struct iopt_area *area;

        down_read(&iopt->iova_rwsem);
        for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
             area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
                if (nth == 0) {
                        up_read(&iopt->iova_rwsem);
                        return iopt_area_iova(area) + syz->offset;
                }
                nth--;
        }
        up_read(&iopt->iova_rwsem);

        return 0;
}

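/*
 * Strip MOCK_FLAGS_ACCESS_SYZ and, when it was set, rewrite the encoded
 * nth_area/offset pair in *iova into a real IOVA inside the IOAS named by
 * ioas_id.
 */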
void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
                                   unsigned int ioas_id, u64 *iova, u32 *flags)
{
        struct iommufd_ioas *ioas;

        if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
                return;
        *flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;

        ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
        if (IS_ERR(ioas))
                return;
        *iova = iommufd_test_syz_conv_iova(&ioas->iopt, iova);
        iommufd_put_object(&ioas->obj);
}

struct mock_iommu_domain {
        struct iommu_domain domain;
        struct xarray pfns;
};

enum selftest_obj_type {
        TYPE_IDEV,
};

struct mock_dev {
        struct device dev;
};

struct selftest_obj {
        struct iommufd_object obj;
        enum selftest_obj_type type;

        union {
                struct {
                        struct iommufd_device *idev;
                        struct iommufd_ctx *ictx;
                        struct mock_dev *mock_dev;
                } idev;
        };
};

static int mock_domain_nop_attach(struct iommu_domain *domain,
                                  struct device *dev)
{
        return 0;
}

static const struct iommu_domain_ops mock_blocking_ops = {
        .attach_dev = mock_domain_nop_attach,
};

static struct iommu_domain mock_blocking_domain = {
        .type = IOMMU_DOMAIN_BLOCKED,
        .ops = &mock_blocking_ops,
};

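/*
 * Report a fixed register value so the selftests can verify that the
 * IOMMU_GET_HW_INFO path copies driver-specific data back to userspace
 * intact.
 */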
static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
{
        struct iommu_test_hw_info *info;

        info = kzalloc(sizeof(*info), GFP_KERNEL);
        if (!info)
                return ERR_PTR(-ENOMEM);

        info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
        *length = sizeof(*info);
        *type = IOMMU_HW_INFO_TYPE_SELFTEST;

        return info;
}

static struct iommu_domain *mock_domain_alloc_paging(struct device *dev)
{
        struct mock_iommu_domain *mock;

        mock = kzalloc(sizeof(*mock), GFP_KERNEL);
        if (!mock)
                return NULL;
        mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
        mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
        mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
        xa_init(&mock->pfns);
        return &mock->domain;
}

static void mock_domain_free(struct iommu_domain *domain)
{
        struct mock_iommu_domain *mock =
                container_of(domain, struct mock_iommu_domain, domain);

        WARN_ON(!xa_empty(&mock->pfns));
        kfree(mock);
}

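/*
 * The mock "page table" is an xarray keyed by IOVA page index. Each entry
 * holds the shifted pfn plus metadata: MOCK_PFN_START_IOVA tags the first
 * page of a map call and MOCK_PFN_LAST_IOVA the last, which lets
 * mock_domain_unmap_pages() check that unmaps line up with the maps.
 */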
static int mock_domain_map_pages(struct iommu_domain *domain,
                                 unsigned long iova, phys_addr_t paddr,
                                 size_t pgsize, size_t pgcount, int prot,
                                 gfp_t gfp, size_t *mapped)
{
        struct mock_iommu_domain *mock =
                container_of(domain, struct mock_iommu_domain, domain);
        unsigned long flags = MOCK_PFN_START_IOVA;
        unsigned long start_iova = iova;

        /*
         * xarray does not reliably work with fault injection because it does a
         * retry allocation, so put our own failure point.
         */
        if (iommufd_should_fail())
                return -ENOENT;

        WARN_ON(iova % MOCK_IO_PAGE_SIZE);
        WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
        for (; pgcount; pgcount--) {
                size_t cur;

                for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
                        void *old;

                        if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
                                flags = MOCK_PFN_LAST_IOVA;
                        old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
                                       xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
                                                   flags),
                                       gfp);
                        if (xa_is_err(old)) {
                                for (; start_iova != iova;
                                     start_iova += MOCK_IO_PAGE_SIZE)
                                        xa_erase(&mock->pfns,
                                                 start_iova /
                                                         MOCK_IO_PAGE_SIZE);
                                return xa_err(old);
                        }
                        WARN_ON(old);
                        iova += MOCK_IO_PAGE_SIZE;
                        paddr += MOCK_IO_PAGE_SIZE;
                        *mapped += MOCK_IO_PAGE_SIZE;
                        flags = 0;
                }
        }
        return 0;
}

static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
                                      unsigned long iova, size_t pgsize,
                                      size_t pgcount,
                                      struct iommu_iotlb_gather *iotlb_gather)
{
        struct mock_iommu_domain *mock =
                container_of(domain, struct mock_iommu_domain, domain);
        bool first = true;
        size_t ret = 0;
        void *ent;

        WARN_ON(iova % MOCK_IO_PAGE_SIZE);
        WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);

        for (; pgcount; pgcount--) {
                size_t cur;

                for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
                        ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
                        WARN_ON(!ent);
                        /*
                         * iommufd generates unmaps that must be a strict
                         * superset of the maps performed. So every starting
                         * IOVA should have been an IOVA passed to map, and
                         * the first IOVA must have been the first IOVA passed
                         * to map_pages.
                         */
                        if (first) {
                                WARN_ON(!(xa_to_value(ent) &
                                          MOCK_PFN_START_IOVA));
                                first = false;
                        }
                        if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
                                WARN_ON(!(xa_to_value(ent) &
                                          MOCK_PFN_LAST_IOVA));

                        iova += MOCK_IO_PAGE_SIZE;
                        ret += MOCK_IO_PAGE_SIZE;
                }
        }
        return ret;
}

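/* Recover the physical address by masking off the metadata bits. */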
static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
                                            dma_addr_t iova)
{
        struct mock_iommu_domain *mock =
                container_of(domain, struct mock_iommu_domain, domain);
        void *ent;

        WARN_ON(iova % MOCK_IO_PAGE_SIZE);
        ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
        WARN_ON(!ent);
        return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
}

static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
{
        return cap == IOMMU_CAP_CACHE_COHERENCY;
}

static struct iommu_device mock_iommu_device = {
};

static struct iommu_device *mock_probe_device(struct device *dev)
{
        return &mock_iommu_device;
}

static const struct iommu_ops mock_ops = {
        /*
         * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
         * because it is zero.
         */
        .default_domain = &mock_blocking_domain,
        .blocked_domain = &mock_blocking_domain,
        .owner = THIS_MODULE,
        .pgsize_bitmap = MOCK_IO_PAGE_SIZE,
        .hw_info = mock_domain_hw_info,
        .domain_alloc_paging = mock_domain_alloc_paging,
        .capable = mock_domain_capable,
        .device_group = generic_device_group,
        .probe_device = mock_probe_device,
        .default_domain_ops =
                &(struct iommu_domain_ops){
                        .free = mock_domain_free,
                        .attach_dev = mock_domain_nop_attach,
                        .map_pages = mock_domain_map_pages,
                        .unmap_pages = mock_domain_unmap_pages,
                        .iova_to_phys = mock_domain_iova_to_phys,
                },
};

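/*
 * Resolve mockpt_id to a hw_pagetable and prove the domain really is a mock
 * domain by comparing the domain ops pointer; container_of() alone cannot
 * check the type.
 */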
static inline struct iommufd_hw_pagetable *
get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
                 struct mock_iommu_domain **mock)
{
        struct iommufd_hw_pagetable *hwpt;
        struct iommufd_object *obj;

        obj = iommufd_get_object(ucmd->ictx, mockpt_id,
                                 IOMMUFD_OBJ_HW_PAGETABLE);
        if (IS_ERR(obj))
                return ERR_CAST(obj);
        hwpt = container_of(obj, struct iommufd_hw_pagetable, obj);
        if (hwpt->domain->ops != mock_ops.default_domain_ops) {
                iommufd_put_object(&hwpt->obj);
                return ERR_PTR(-EINVAL);
        }
        *mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
        return hwpt;
}

struct mock_bus_type {
        struct bus_type bus;
        struct notifier_block nb;
};

static struct mock_bus_type iommufd_mock_bus_type = {
        .bus = {
                .name = "iommufd_mock",
        },
};

static atomic_t mock_dev_num;

static void mock_dev_release(struct device *dev)
{
        struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);

        atomic_dec(&mock_dev_num);
        kfree(mdev);
}

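/*
 * Create a bare struct device on the iommufd_mock bus so it can be bound to
 * iommufd and attached to domains like a real device.
 */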
static struct mock_dev *mock_dev_create(void)
{
        struct mock_dev *mdev;
        int rc;

        mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
        if (!mdev)
                return ERR_PTR(-ENOMEM);

        device_initialize(&mdev->dev);
        mdev->dev.release = mock_dev_release;
        mdev->dev.bus = &iommufd_mock_bus_type.bus;

        rc = dev_set_name(&mdev->dev, "iommufd_mock%u",
                          atomic_inc_return(&mock_dev_num));
        if (rc)
                goto err_put;

        rc = device_add(&mdev->dev);
        if (rc)
                goto err_put;
        return mdev;

err_put:
        put_device(&mdev->dev);
        return ERR_PTR(rc);
}

static void mock_dev_destroy(struct mock_dev *mdev)
{
        device_unregister(&mdev->dev);
}

bool iommufd_selftest_is_mock_dev(struct device *dev)
{
        return dev->release == mock_dev_release;
}

/* Create an hw_pagetable with the mock domain so we can test the domain ops */
static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
                                    struct iommu_test_cmd *cmd)
{
        struct iommufd_device *idev;
        struct selftest_obj *sobj;
        u32 pt_id = cmd->id;
        u32 idev_id;
        int rc;

        sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        sobj->idev.ictx = ucmd->ictx;
        sobj->type = TYPE_IDEV;

        sobj->idev.mock_dev = mock_dev_create();
        if (IS_ERR(sobj->idev.mock_dev)) {
                rc = PTR_ERR(sobj->idev.mock_dev);
                goto out_sobj;
        }

        idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
                                   &idev_id);
        if (IS_ERR(idev)) {
                rc = PTR_ERR(idev);
                goto out_mdev;
        }
        sobj->idev.idev = idev;

        rc = iommufd_device_attach(idev, &pt_id);
        if (rc)
                goto out_unbind;

        /* Userspace must destroy the device_id to destroy the object */
        cmd->mock_domain.out_hwpt_id = pt_id;
        cmd->mock_domain.out_stdev_id = sobj->obj.id;
        cmd->mock_domain.out_idev_id = idev_id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_detach;
        iommufd_object_finalize(ucmd->ictx, &sobj->obj);
        return 0;

out_detach:
        iommufd_device_detach(idev);
out_unbind:
        iommufd_device_unbind(idev);
out_mdev:
        mock_dev_destroy(sobj->idev.mock_dev);
out_sobj:
        iommufd_object_abort(ucmd->ictx, &sobj->obj);
        return rc;
}

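/*
 * An illustrative sketch of how the selftests reach this op through the
 * iommufd char device (see tools/testing/selftests/iommu for the real
 * callers; the iommufd fd and ioas_id here are assumed to already exist):
 *
 *	struct iommu_test_cmd cmd = {
 *		.size = sizeof(cmd),
 *		.op = IOMMU_TEST_OP_MOCK_DOMAIN,
 *		.id = ioas_id,
 *	};
 *
 *	if (!ioctl(iommufd, IOMMU_TEST_CMD, &cmd))
 *		// cmd.mock_domain.out_stdev_id / out_hwpt_id / out_idev_id
 *		// now identify the created objects
 */
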
/* Replace the mock domain with a manually allocated hw_pagetable */
static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
                                            unsigned int device_id, u32 pt_id,
                                            struct iommu_test_cmd *cmd)
{
        struct iommufd_object *dev_obj;
        struct selftest_obj *sobj;
        int rc;

        /*
         * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
         * it doesn't race with detach, which is not allowed.
         */
        dev_obj =
                iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
        if (IS_ERR(dev_obj))
                return PTR_ERR(dev_obj);

        sobj = container_of(dev_obj, struct selftest_obj, obj);
        if (sobj->type != TYPE_IDEV) {
                rc = -EINVAL;
                goto out_dev_obj;
        }

        rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
        if (rc)
                goto out_dev_obj;

        cmd->mock_domain_replace.pt_id = pt_id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));

out_dev_obj:
        iommufd_put_object(dev_obj);
        return rc;
}

/* Add an additional reserved IOVA to the IOAS */
static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
                                     unsigned int mockpt_id,
                                     unsigned long start, size_t length)
{
        struct iommufd_ioas *ioas;
        int rc;

        ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
        if (IS_ERR(ioas))
                return PTR_ERR(ioas);
        down_write(&ioas->iopt.iova_rwsem);
        rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
        up_write(&ioas->iopt.iova_rwsem);
        iommufd_put_object(&ioas->obj);
        return rc;
}

/* Check that every pfn under each iova matches the pfn under a user VA */
static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
                                    unsigned int mockpt_id, unsigned long iova,
                                    size_t length, void __user *uptr)
{
        struct iommufd_hw_pagetable *hwpt;
        struct mock_iommu_domain *mock;
        uintptr_t end;
        int rc;

        if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
            (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
            check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
                return -EINVAL;

        hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
        if (IS_ERR(hwpt))
                return PTR_ERR(hwpt);

        for (; length; length -= MOCK_IO_PAGE_SIZE) {
                struct page *pages[1];
                unsigned long pfn;
                long npages;
                void *ent;

                npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
                                             pages);
                if (npages < 0) {
                        rc = npages;
                        goto out_put;
                }
                if (WARN_ON(npages != 1)) {
                        rc = -EFAULT;
                        goto out_put;
                }
                pfn = page_to_pfn(pages[0]);
                put_page(pages[0]);

                ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
                if (!ent ||
                    (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
                            pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
                        rc = -EINVAL;
                        goto out_put;
                }
                iova += MOCK_IO_PAGE_SIZE;
                uptr += MOCK_IO_PAGE_SIZE;
        }
        rc = 0;

out_put:
        iommufd_put_object(&hwpt->obj);
        return rc;
}

/* Check that the page ref count matches, to look for missing pin/unpins */
static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
                                      void __user *uptr, size_t length,
                                      unsigned int refs)
{
        uintptr_t end;

        if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
            check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
                return -EINVAL;

        for (; length; length -= PAGE_SIZE) {
                struct page *pages[1];
                long npages;

                npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
                if (npages < 0)
                        return npages;
                if (WARN_ON(npages != 1))
                        return -EFAULT;
                if (!PageCompound(pages[0])) {
                        unsigned int count;

                        count = page_ref_count(pages[0]);
                        if (count / GUP_PIN_COUNTING_BIAS != refs) {
                                put_page(pages[0]);
                                return -EIO;
                        }
                }
                put_page(pages[0]);
                uptr += PAGE_SIZE;
        }
        return 0;
}

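/*
 * A selftest "access" is exposed to userspace as an anonymous fd; every
 * pinned range is tracked as an item so the unmap callback and explicit
 * destroys can find and unpin it later.
 */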
struct selftest_access {
        struct iommufd_access *access;
        struct file *file;
        struct mutex lock;
        struct list_head items;
        unsigned int next_id;
        bool destroying;
};

struct selftest_access_item {
        struct list_head items_elm;
        unsigned long iova;
        size_t length;
        unsigned int id;
};

static const struct file_operations iommfd_test_staccess_fops;

static struct selftest_access *iommufd_access_get(int fd)
{
        struct file *file;

        file = fget(fd);
        if (!file)
                return ERR_PTR(-EBADFD);

        if (file->f_op != &iommfd_test_staccess_fops) {
                fput(file);
                return ERR_PTR(-EBADFD);
        }
        return file->private_data;
}

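/*
 * Called by the iommufd core when part of the IOAS is unmapped; every
 * pinned item overlapping the range must be unpinned before this returns.
 */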
static void iommufd_test_access_unmap(void *data, unsigned long iova,
                                      unsigned long length)
{
        unsigned long iova_last = iova + length - 1;
        struct selftest_access *staccess = data;
        struct selftest_access_item *item;
        struct selftest_access_item *tmp;

        mutex_lock(&staccess->lock);
        list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
                if (iova > item->iova + item->length - 1 ||
                    iova_last < item->iova)
                        continue;
                list_del(&item->items_elm);
                iommufd_access_unpin_pages(staccess->access, item->iova,
                                           item->length);
                kfree(item);
        }
        mutex_unlock(&staccess->lock);
}

static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
                                            unsigned int access_id,
                                            unsigned int item_id)
{
        struct selftest_access_item *item;
        struct selftest_access *staccess;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        mutex_lock(&staccess->lock);
        list_for_each_entry(item, &staccess->items, items_elm) {
                if (item->id == item_id) {
                        list_del(&item->items_elm);
                        iommufd_access_unpin_pages(staccess->access, item->iova,
                                                   item->length);
                        mutex_unlock(&staccess->lock);
                        kfree(item);
                        fput(staccess->file);
                        return 0;
                }
        }
        mutex_unlock(&staccess->lock);
        fput(staccess->file);
        return -ENOENT;
}

static int iommufd_test_staccess_release(struct inode *inode,
                                         struct file *filep)
{
        struct selftest_access *staccess = filep->private_data;

        if (staccess->access) {
                iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
                iommufd_access_destroy(staccess->access);
        }
        mutex_destroy(&staccess->lock);
        kfree(staccess);
        return 0;
}

static const struct iommufd_access_ops selftest_access_ops_pin = {
        .needs_pin_pages = 1,
        .unmap = iommufd_test_access_unmap,
};

static const struct iommufd_access_ops selftest_access_ops = {
        .unmap = iommufd_test_access_unmap,
};

static const struct file_operations iommfd_test_staccess_fops = {
        .release = iommufd_test_staccess_release,
};

static struct selftest_access *iommufd_test_alloc_access(void)
{
        struct selftest_access *staccess;
        struct file *filep;

        staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
        if (!staccess)
                return ERR_PTR(-ENOMEM);
        INIT_LIST_HEAD(&staccess->items);
        mutex_init(&staccess->lock);

        filep = anon_inode_getfile("[iommufd_test_staccess]",
                                   &iommfd_test_staccess_fops, staccess,
                                   O_RDWR);
        if (IS_ERR(filep)) {
                kfree(staccess);
                return ERR_CAST(filep);
        }
        staccess->file = filep;
        return staccess;
}

static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
                                      unsigned int ioas_id, unsigned int flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct selftest_access *staccess;
        struct iommufd_access *access;
        u32 id;
        int fdno;
        int rc;

        if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
                return -EOPNOTSUPP;

        staccess = iommufd_test_alloc_access();
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        fdno = get_unused_fd_flags(O_CLOEXEC);
        if (fdno < 0) {
                rc = -ENOMEM;
                goto out_free_staccess;
        }

        access = iommufd_access_create(
                ucmd->ictx,
                (flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
                        &selftest_access_ops_pin :
                        &selftest_access_ops,
                staccess, &id);
        if (IS_ERR(access)) {
                rc = PTR_ERR(access);
                goto out_put_fdno;
        }
        rc = iommufd_access_attach(access, ioas_id);
        if (rc)
                goto out_destroy;
        cmd->create_access.out_access_fd = fdno;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_destroy;

        staccess->access = access;
        fd_install(fdno, staccess->file);
        return 0;

out_destroy:
        iommufd_access_destroy(access);
out_put_fdno:
        put_unused_fd(fdno);
out_free_staccess:
        fput(staccess->file);
        return rc;
}

static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
                                            unsigned int access_id,
                                            unsigned int ioas_id)
{
        struct selftest_access *staccess;
        int rc;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        rc = iommufd_access_replace(staccess->access, ioas_id);
        fput(staccess->file);
        return rc;
}

/* Check that the pages in a page array match the pages in the user VA */
static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
                                    size_t npages)
{
        for (; npages; npages--) {
                struct page *tmp_pages[1];
                long rc;

                rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
                if (rc < 0)
                        return rc;
                if (WARN_ON(rc != 1))
                        return -EFAULT;
                put_page(tmp_pages[0]);
                if (tmp_pages[0] != *pages)
                        return -EBADE;
                pages++;
                uptr += PAGE_SIZE;
        }
        return 0;
}

static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
                                     unsigned int access_id, unsigned long iova,
                                     size_t length, void __user *uptr,
                                     u32 flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct selftest_access_item *item;
        struct selftest_access *staccess;
        struct page **pages;
        size_t npages;
        int rc;

        /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
        if (length > 16 * 1024 * 1024)
                return -ENOMEM;

        if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
                return -EOPNOTSUPP;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        if (staccess->access->ops != &selftest_access_ops_pin) {
                rc = -EOPNOTSUPP;
                goto out_put;
        }

        if (flags & MOCK_FLAGS_ACCESS_SYZ)
                iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
                                        &cmd->access_pages.iova);

        npages = (ALIGN(iova + length, PAGE_SIZE) -
                  ALIGN_DOWN(iova, PAGE_SIZE)) /
                 PAGE_SIZE;
        pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
        if (!pages) {
                rc = -ENOMEM;
                goto out_put;
        }

        /*
         * Drivers will need to think very carefully about this locking. The
         * core code can do multiple unmaps instantaneously after
         * iommufd_access_pin_pages() and *all* the unmaps must not return
         * until the range is unpinned. This simple implementation puts a
         * global lock around the pin, which may not suit drivers that want
         * this to be a performance path. Drivers that get this wrong will
         * trigger WARN_ON races and cause EDEADLOCK failures to userspace.
         */
        mutex_lock(&staccess->lock);
        rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
                                      flags & MOCK_FLAGS_ACCESS_WRITE);
        if (rc)
                goto out_unlock;

        /* For syzkaller allow uptr to be NULL to skip this check */
        if (uptr) {
                rc = iommufd_test_check_pages(
                        uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
                        npages);
                if (rc)
                        goto out_unaccess;
        }

        item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
        if (!item) {
                rc = -ENOMEM;
                goto out_unaccess;
        }

        item->iova = iova;
        item->length = length;
        item->id = staccess->next_id++;
        list_add_tail(&item->items_elm, &staccess->items);

        cmd->access_pages.out_access_pages_id = item->id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_free_item;
        goto out_unlock;

out_free_item:
        list_del(&item->items_elm);
        kfree(item);
out_unaccess:
        iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
        mutex_unlock(&staccess->lock);
        kvfree(pages);
out_put:
        fput(staccess->file);
        return rc;
}

static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
                                  unsigned int access_id, unsigned long iova,
                                  size_t length, void __user *ubuf,
                                  unsigned int flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct selftest_access *staccess;
        void *tmp;
        int rc;

        /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
        if (length > 16 * 1024 * 1024)
                return -ENOMEM;

        if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
                      MOCK_FLAGS_ACCESS_SYZ))
                return -EOPNOTSUPP;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
        if (!tmp) {
                rc = -ENOMEM;
                goto out_put;
        }

        if (flags & MOCK_ACCESS_RW_WRITE) {
                if (copy_from_user(tmp, ubuf, length)) {
                        rc = -EFAULT;
                        goto out_free;
                }
        }

        if (flags & MOCK_FLAGS_ACCESS_SYZ)
                iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
                                        &cmd->access_rw.iova);

        rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
        if (rc)
                goto out_free;
        if (!(flags & MOCK_ACCESS_RW_WRITE)) {
                if (copy_to_user(ubuf, tmp, length)) {
                        rc = -EFAULT;
                        goto out_free;
                }
        }

out_free:
        kvfree(tmp);
out_put:
        fput(staccess->file);
        return rc;
}

static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
              __IOMMUFD_ACCESS_RW_SLOW_PATH);

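/*
 * Destroy callback for IOMMUFD_OBJ_SELFTEST objects, run when userspace
 * destroys the stdev_id returned by iommufd_test_mock_domain().
 */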
void iommufd_selftest_destroy(struct iommufd_object *obj)
{
        struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);

        switch (sobj->type) {
        case TYPE_IDEV:
                iommufd_device_detach(sobj->idev.idev);
                iommufd_device_unbind(sobj->idev.idev);
                mock_dev_destroy(sobj->idev.mock_dev);
                break;
        }
}

int iommufd_test(struct iommufd_ucmd *ucmd)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;

        switch (cmd->op) {
        case IOMMU_TEST_OP_ADD_RESERVED:
                return iommufd_test_add_reserved(ucmd, cmd->id,
                                                 cmd->add_reserved.start,
                                                 cmd->add_reserved.length);
        case IOMMU_TEST_OP_MOCK_DOMAIN:
                return iommufd_test_mock_domain(ucmd, cmd);
        case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
                return iommufd_test_mock_domain_replace(
                        ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
        case IOMMU_TEST_OP_MD_CHECK_MAP:
                return iommufd_test_md_check_pa(
                        ucmd, cmd->id, cmd->check_map.iova,
                        cmd->check_map.length,
                        u64_to_user_ptr(cmd->check_map.uptr));
        case IOMMU_TEST_OP_MD_CHECK_REFS:
                return iommufd_test_md_check_refs(
                        ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
                        cmd->check_refs.length, cmd->check_refs.refs);
        case IOMMU_TEST_OP_CREATE_ACCESS:
                return iommufd_test_create_access(ucmd, cmd->id,
                                                  cmd->create_access.flags);
        case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
                return iommufd_test_access_replace_ioas(
                        ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
        case IOMMU_TEST_OP_ACCESS_PAGES:
                return iommufd_test_access_pages(
                        ucmd, cmd->id, cmd->access_pages.iova,
                        cmd->access_pages.length,
                        u64_to_user_ptr(cmd->access_pages.uptr),
                        cmd->access_pages.flags);
        case IOMMU_TEST_OP_ACCESS_RW:
                return iommufd_test_access_rw(
                        ucmd, cmd->id, cmd->access_rw.iova,
                        cmd->access_rw.length,
                        u64_to_user_ptr(cmd->access_rw.uptr),
                        cmd->access_rw.flags);
        case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
                return iommufd_test_access_item_destroy(
                        ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
        case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
                /* Protect _batch_init(), cannot be less than elmsz */
                if (cmd->memory_limit.limit <
                    sizeof(unsigned long) + sizeof(u32))
                        return -EINVAL;
                iommufd_test_memory_limit = cmd->memory_limit.limit;
                return 0;
        default:
                return -EOPNOTSUPP;
        }
}

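/*
 * Hooked into the standard fault-injection framework via fail_iommufd
 * below; for example (knob names per Documentation/fault-injection, paths
 * assume debugfs is mounted at /sys/kernel/debug):
 *
 *	echo 100 > /sys/kernel/debug/fail_iommufd/probability
 *	echo -1 > /sys/kernel/debug/fail_iommufd/times
 */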
bool iommufd_should_fail(void)
{
        return should_fail(&fail_iommufd, 1);
}

int __init iommufd_test_init(void)
{
        struct platform_device_info pdevinfo = {
                .name = "iommufd_selftest_iommu",
        };
        int rc;

        dbgfs_root =
                fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);

        selftest_iommu_dev = platform_device_register_full(&pdevinfo);
        if (IS_ERR(selftest_iommu_dev)) {
                rc = PTR_ERR(selftest_iommu_dev);
                goto err_dbgfs;
        }

        rc = bus_register(&iommufd_mock_bus_type.bus);
        if (rc)
                goto err_platform;

        rc = iommu_device_sysfs_add(&mock_iommu_device,
                                    &selftest_iommu_dev->dev, NULL, "%s",
                                    dev_name(&selftest_iommu_dev->dev));
        if (rc)
                goto err_bus;

        rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
                                       &iommufd_mock_bus_type.bus,
                                       &iommufd_mock_bus_type.nb);
        if (rc)
                goto err_sysfs;
        return 0;

err_sysfs:
        iommu_device_sysfs_remove(&mock_iommu_device);
err_bus:
        bus_unregister(&iommufd_mock_bus_type.bus);
err_platform:
        platform_device_unregister(selftest_iommu_dev);
err_dbgfs:
        debugfs_remove_recursive(dbgfs_root);
        return rc;
}

void iommufd_test_exit(void)
{
        iommu_device_sysfs_remove(&mock_iommu_device);
        iommu_device_unregister_bus(&mock_iommu_device,
                                    &iommufd_mock_bus_type.bus,
                                    &iommufd_mock_bus_type.nb);
        bus_unregister(&iommufd_mock_bus_type.bus);
        platform_device_unregister(selftest_iommu_dev);
        debugfs_remove_recursive(dbgfs_root);
}