mm/hmm: clean up some coding style and comments
mm/hmm.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Copyright 2013 Red Hat Inc.
4  *
5  * Authors: Jérôme Glisse <jglisse@redhat.com>
6  */
7 /*
8  * Refer to include/linux/hmm.h for information about heterogeneous memory
9  * management or HMM for short.
10  */
11 #include <linux/mm.h>
12 #include <linux/hmm.h>
13 #include <linux/init.h>
14 #include <linux/rmap.h>
15 #include <linux/swap.h>
16 #include <linux/slab.h>
17 #include <linux/sched.h>
18 #include <linux/mmzone.h>
19 #include <linux/pagemap.h>
20 #include <linux/swapops.h>
21 #include <linux/hugetlb.h>
22 #include <linux/memremap.h>
23 #include <linux/jump_label.h>
24 #include <linux/dma-mapping.h>
25 #include <linux/mmu_notifier.h>
26 #include <linux/memory_hotplug.h>
27
28 #define PA_SECTION_SIZE (1UL << PA_SECTION_SHIFT)
29
30 #if IS_ENABLED(CONFIG_HMM_MIRROR)
31 static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
32
33 static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
34 {
35         struct hmm *hmm = READ_ONCE(mm->hmm);
36
37         if (hmm && kref_get_unless_zero(&hmm->kref))
38                 return hmm;
39
40         return NULL;
41 }
42
43 /**
44  * hmm_get_or_create() - register HMM against an mm (HMM internal)
45  *
46  * @mm: mm struct to attach to
47  * Return: an HMM object, either by referencing the existing
48  *         (per-process) object, or by creating a new one.
49  *
50  * This is not intended to be used directly by device drivers. If mm already
51  * has an HMM struct then it gets a reference on it and returns it. Otherwise
52  * it allocates an HMM struct, initializes it, associates it with the mm and
53  * returns it.
54  */
55 static struct hmm *hmm_get_or_create(struct mm_struct *mm)
56 {
57         struct hmm *hmm = mm_get_hmm(mm);
58         bool cleanup = false;
59
60         if (hmm)
61                 return hmm;
62
63         hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
64         if (!hmm)
65                 return NULL;
66         init_waitqueue_head(&hmm->wq);
67         INIT_LIST_HEAD(&hmm->mirrors);
68         init_rwsem(&hmm->mirrors_sem);
69         hmm->mmu_notifier.ops = NULL;
70         INIT_LIST_HEAD(&hmm->ranges);
71         mutex_init(&hmm->lock);
72         kref_init(&hmm->kref);
73         hmm->notifiers = 0;
74         hmm->dead = false;
75         hmm->mm = mm;
76
77         spin_lock(&mm->page_table_lock);
78         if (!mm->hmm)
79                 mm->hmm = hmm;
80         else
81                 cleanup = true;
82         spin_unlock(&mm->page_table_lock);
83
84         if (cleanup)
85                 goto error;
86
87         /*
88          * We should only get here if we hold the mmap_sem in write mode, i.e.
89          * on registration of the first mirror through hmm_mirror_register().
90          */
91         hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
92         if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
93                 goto error_mm;
94
95         return hmm;
96
97 error_mm:
98         spin_lock(&mm->page_table_lock);
99         if (mm->hmm == hmm)
100                 mm->hmm = NULL;
101         spin_unlock(&mm->page_table_lock);
102 error:
103         kfree(hmm);
104         return NULL;
105 }
106
107 static void hmm_free(struct kref *kref)
108 {
109         struct hmm *hmm = container_of(kref, struct hmm, kref);
110         struct mm_struct *mm = hmm->mm;
111
112         mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
113
114         spin_lock(&mm->page_table_lock);
115         if (mm->hmm == hmm)
116                 mm->hmm = NULL;
117         spin_unlock(&mm->page_table_lock);
118
119         kfree(hmm);
120 }
121
122 static inline void hmm_put(struct hmm *hmm)
123 {
124         kref_put(&hmm->kref, hmm_free);
125 }
126
127 void hmm_mm_destroy(struct mm_struct *mm)
128 {
129         struct hmm *hmm;
130
131         spin_lock(&mm->page_table_lock);
132         hmm = mm_get_hmm(mm);
133         mm->hmm = NULL;
134         if (hmm) {
135                 hmm->mm = NULL;
136                 hmm->dead = true;
137                 spin_unlock(&mm->page_table_lock);
138                 hmm_put(hmm);
139                 return;
140         }
141
142         spin_unlock(&mm->page_table_lock);
143 }
144
145 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
146 {
147         struct hmm *hmm = mm_get_hmm(mm);
148         struct hmm_mirror *mirror;
149         struct hmm_range *range;
150
151         /* Report this HMM as dying. */
152         hmm->dead = true;
153
154         /* Wake-up everyone waiting on any range. */
155         mutex_lock(&hmm->lock);
156         list_for_each_entry(range, &hmm->ranges, list)
157                 range->valid = false;
158         wake_up_all(&hmm->wq);
159         mutex_unlock(&hmm->lock);
160
161         down_write(&hmm->mirrors_sem);
162         mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
163                                           list);
164         while (mirror) {
165                 list_del_init(&mirror->list);
166                 if (mirror->ops->release) {
167                         /*
168                          * Drop mirrors_sem so the release callback can wait
169                          * on any pending work that might itself trigger a
170                          * mmu_notifier callback and thus would deadlock with
171                          * us.
172                          */
173                         up_write(&hmm->mirrors_sem);
174                         mirror->ops->release(mirror);
175                         down_write(&hmm->mirrors_sem);
176                 }
177                 mirror = list_first_entry_or_null(&hmm->mirrors,
178                                                   struct hmm_mirror, list);
179         }
180         up_write(&hmm->mirrors_sem);
181
182         hmm_put(hmm);
183 }
184
185 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
186                         const struct mmu_notifier_range *nrange)
187 {
188         struct hmm *hmm = mm_get_hmm(nrange->mm);
189         struct hmm_mirror *mirror;
190         struct hmm_update update;
191         struct hmm_range *range;
192         int ret = 0;
193
194         VM_BUG_ON(!hmm);
195
196         update.start = nrange->start;
197         update.end = nrange->end;
198         update.event = HMM_UPDATE_INVALIDATE;
199         update.blockable = mmu_notifier_range_blockable(nrange);
200
201         if (mmu_notifier_range_blockable(nrange))
202                 mutex_lock(&hmm->lock);
203         else if (!mutex_trylock(&hmm->lock)) {
204                 ret = -EAGAIN;
205                 goto out;
206         }
207         hmm->notifiers++;
208         list_for_each_entry(range, &hmm->ranges, list) {
209                 if (update.end < range->start || update.start >= range->end)
210                         continue;
211
212                 range->valid = false;
213         }
214         mutex_unlock(&hmm->lock);
215
216         if (mmu_notifier_range_blockable(nrange))
217                 down_read(&hmm->mirrors_sem);
218         else if (!down_read_trylock(&hmm->mirrors_sem)) {
219                 ret = -EAGAIN;
220                 goto out;
221         }
222         list_for_each_entry(mirror, &hmm->mirrors, list) {
223                 int ret;
224
225                 ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
226                 if (!update.blockable && ret == -EAGAIN)
227                         break;
228         }
229         up_read(&hmm->mirrors_sem);
230
231 out:
232         hmm_put(hmm);
233         return ret;
234 }
235
236 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
237                         const struct mmu_notifier_range *nrange)
238 {
239         struct hmm *hmm = mm_get_hmm(nrange->mm);
240
241         VM_BUG_ON(!hmm);
242
243         mutex_lock(&hmm->lock);
244         hmm->notifiers--;
245         if (!hmm->notifiers) {
246                 struct hmm_range *range;
247
248                 list_for_each_entry(range, &hmm->ranges, list) {
249                         if (range->valid)
250                                 continue;
251                         range->valid = true;
252                 }
253                 wake_up_all(&hmm->wq);
254         }
255         mutex_unlock(&hmm->lock);
256
257         hmm_put(hmm);
258 }
259
260 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
261         .release                = hmm_release,
262         .invalidate_range_start = hmm_invalidate_range_start,
263         .invalidate_range_end   = hmm_invalidate_range_end,
264 };
265
266 /**
267  * hmm_mirror_register() - register a mirror against an mm
268  *
269  * @mirror: new mirror struct to register
270  * @mm: mm to register against
271  * Return: 0 on success, -ENOMEM if no memory, -EINVAL if invalid arguments
272  *
273  * To start mirroring a process address space, the device driver must register
274  * an HMM mirror struct.
275  *
276  * THE mm->mmap_sem MUST BE HELD IN WRITE MODE!
277  */
278 int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
279 {
280         /* Sanity check */
281         if (!mm || !mirror || !mirror->ops)
282                 return -EINVAL;
283
284         mirror->hmm = hmm_get_or_create(mm);
285         if (!mirror->hmm)
286                 return -ENOMEM;
287
288         down_write(&mirror->hmm->mirrors_sem);
289         list_add(&mirror->list, &mirror->hmm->mirrors);
290         up_write(&mirror->hmm->mirrors_sem);
291
292         return 0;
293 }
294 EXPORT_SYMBOL(hmm_mirror_register);
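
/*
 * A minimal usage sketch (not part of this file): a device driver is expected
 * to embed a struct hmm_mirror in its own per-mm structure, provide the
 * sync_cpu_device_pagetables() and release() callbacks and register while
 * holding mmap_sem in write mode. Every my_* name below is hypothetical.
 *
 *   static int my_sync_cpu_device_pagetables(struct hmm_mirror *mirror,
 *                                            const struct hmm_update *update)
 *   {
 *           struct my_mm *mmm = container_of(mirror, struct my_mm, mirror);
 *
 *           if (!update->blockable && my_would_block(mmm))
 *                   return -EAGAIN;
 *           my_invalidate_device_range(mmm, update->start, update->end);
 *           return 0;
 *   }
 *
 *   static void my_release(struct hmm_mirror *mirror)
 *   {
 *           ... tear down the device page tables for this mm ...
 *   }
 *
 *   static const struct hmm_mirror_ops my_mirror_ops = {
 *           .sync_cpu_device_pagetables = my_sync_cpu_device_pagetables,
 *           .release                    = my_release,
 *   };
 *
 *   Then, with current->mm->mmap_sem held in write mode:
 *
 *           mmm->mirror.ops = &my_mirror_ops;
 *           ret = hmm_mirror_register(&mmm->mirror, current->mm);
 */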
295
296 /**
297  * hmm_mirror_unregister() - unregister a mirror
298  *
299  * @mirror: mirror struct to unregister
300  *
301  * Stop mirroring a process address space, and cleanup.
302  */
303 void hmm_mirror_unregister(struct hmm_mirror *mirror)
304 {
305         struct hmm *hmm = READ_ONCE(mirror->hmm);
306
307         if (hmm == NULL)
308                 return;
309
310         down_write(&hmm->mirrors_sem);
311         list_del_init(&mirror->list);
312         /* To protect us against double unregister ... */
313         mirror->hmm = NULL;
314         up_write(&hmm->mirrors_sem);
315
316         hmm_put(hmm);
317 }
318 EXPORT_SYMBOL(hmm_mirror_unregister);
319
320 struct hmm_vma_walk {
321         struct hmm_range        *range;
322         struct dev_pagemap      *pgmap;
323         unsigned long           last;
324         bool                    fault;
325         bool                    block;
326 };
327
328 static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
329                             bool write_fault, uint64_t *pfn)
330 {
331         unsigned int flags = FAULT_FLAG_REMOTE;
332         struct hmm_vma_walk *hmm_vma_walk = walk->private;
333         struct hmm_range *range = hmm_vma_walk->range;
334         struct vm_area_struct *vma = walk->vma;
335         vm_fault_t ret;
336
337         flags |= hmm_vma_walk->block ? 0 : FAULT_FLAG_ALLOW_RETRY;
338         flags |= write_fault ? FAULT_FLAG_WRITE : 0;
339         ret = handle_mm_fault(vma, addr, flags);
340         if (ret & VM_FAULT_RETRY)
341                 return -EAGAIN;
342         if (ret & VM_FAULT_ERROR) {
343                 *pfn = range->values[HMM_PFN_ERROR];
344                 return -EFAULT;
345         }
346
347         return -EBUSY;
348 }
349
350 static int hmm_pfns_bad(unsigned long addr,
351                         unsigned long end,
352                         struct mm_walk *walk)
353 {
354         struct hmm_vma_walk *hmm_vma_walk = walk->private;
355         struct hmm_range *range = hmm_vma_walk->range;
356         uint64_t *pfns = range->pfns;
357         unsigned long i;
358
359         i = (addr - range->start) >> PAGE_SHIFT;
360         for (; addr < end; addr += PAGE_SIZE, i++)
361                 pfns[i] = range->values[HMM_PFN_ERROR];
362
363         return 0;
364 }
365
366 /*
367  * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
368  * @addr: range virtual start address (inclusive)
369  * @end: range virtual end address (exclusive)
370  * @fault: should we fault or not?
371  * @write_fault: write fault?
372  * @walk: mm_walk structure
373  * Return: 0 on success, -EBUSY after page fault, or page fault error
374  *
375  * This function will be called whenever pmd_none() or pte_none() returns true,
376  * or whenever there is no page directory covering the virtual address range.
377  */
378 static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
379                               bool fault, bool write_fault,
380                               struct mm_walk *walk)
381 {
382         struct hmm_vma_walk *hmm_vma_walk = walk->private;
383         struct hmm_range *range = hmm_vma_walk->range;
384         uint64_t *pfns = range->pfns;
385         unsigned long i, page_size;
386
387         hmm_vma_walk->last = addr;
388         page_size = hmm_range_page_size(range);
389         i = (addr - range->start) >> range->page_shift;
390
391         for (; addr < end; addr += page_size, i++) {
392                 pfns[i] = range->values[HMM_PFN_NONE];
393                 if (fault || write_fault) {
394                         int ret;
395
396                         ret = hmm_vma_do_fault(walk, addr, write_fault,
397                                                &pfns[i]);
398                         if (ret != -EBUSY)
399                                 return ret;
400                 }
401         }
402
403         return (fault || write_fault) ? -EBUSY : 0;
404 }
405
406 static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
407                                       uint64_t pfns, uint64_t cpu_flags,
408                                       bool *fault, bool *write_fault)
409 {
410         struct hmm_range *range = hmm_vma_walk->range;
411
412         if (!hmm_vma_walk->fault)
413                 return;
414
415         /*
416          * So we not only consider the individual per page request, we also
417          * consider the default flags requested for the range. The API can be
418          * used in 2 ways. The first one, where the HMM user coalesces multiple
419          * page faults into one request and sets flags per pfn for those
420          * faults. The second one, where the HMM user wants to pre-fault a
421          * range with specific flags. For the latter one it is a waste to have
422          * the user pre-fill the pfn array with a default flags value (a short
423          * sketch of both usages follows this function).
424          */
425         pfns = (pfns & range->pfn_flags_mask) | range->default_flags;
426
427         /* We aren't asked to do anything ... */
428         if (!(pfns & range->flags[HMM_PFN_VALID]))
429                 return;
430         /* If this is device memory then only fault if explicitly requested */
431         if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
432                 /* Do we fault on device memory? */
433                 if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
434                         *write_fault = pfns & range->flags[HMM_PFN_WRITE];
435                         *fault = true;
436                 }
437                 return;
438         }
439
440         /* If CPU page table is not valid then we need to fault */
441         *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
442         /* Do we need to write fault? */
443         if ((pfns & range->flags[HMM_PFN_WRITE]) &&
444             !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
445                 *write_fault = true;
446                 *fault = true;
447         }
448 }
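
/*
 * A short, hedged sketch of the two usages described in hmm_pte_need_fault()
 * above. The default_flags and pfn_flags_mask fields and the flags table
 * indices (HMM_PFN_VALID, HMM_PFN_WRITE) belong to struct hmm_range;
 * everything else is hypothetical driver code.
 *
 *   Pre-fault a whole range read/write without pre-filling range->pfns[]:
 *
 *           range->default_flags = range->flags[HMM_PFN_VALID] |
 *                                  range->flags[HMM_PFN_WRITE];
 *           range->pfn_flags_mask = 0;
 *
 *   Or request faults on a per-page basis, keeping whatever flags the driver
 *   wrote into each range->pfns[] entry:
 *
 *           range->default_flags = 0;
 *           range->pfn_flags_mask = -1UL;
 *           range->pfns[i] = range->flags[HMM_PFN_VALID];
 */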
449
450 static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
451                                  const uint64_t *pfns, unsigned long npages,
452                                  uint64_t cpu_flags, bool *fault,
453                                  bool *write_fault)
454 {
455         unsigned long i;
456
457         if (!hmm_vma_walk->fault) {
458                 *fault = *write_fault = false;
459                 return;
460         }
461
462         *fault = *write_fault = false;
463         for (i = 0; i < npages; ++i) {
464                 hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
465                                    fault, write_fault);
466                 if ((*write_fault))
467                         return;
468         }
469 }
470
471 static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
472                              struct mm_walk *walk)
473 {
474         struct hmm_vma_walk *hmm_vma_walk = walk->private;
475         struct hmm_range *range = hmm_vma_walk->range;
476         bool fault, write_fault;
477         unsigned long i, npages;
478         uint64_t *pfns;
479
480         i = (addr - range->start) >> PAGE_SHIFT;
481         npages = (end - addr) >> PAGE_SHIFT;
482         pfns = &range->pfns[i];
483         hmm_range_need_fault(hmm_vma_walk, pfns, npages,
484                              0, &fault, &write_fault);
485         return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
486 }
487
488 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
489 {
490         if (pmd_protnone(pmd))
491                 return 0;
492         return pmd_write(pmd) ? range->flags[HMM_PFN_VALID] |
493                                 range->flags[HMM_PFN_WRITE] :
494                                 range->flags[HMM_PFN_VALID];
495 }
496
497 static inline uint64_t pud_to_hmm_pfn_flags(struct hmm_range *range, pud_t pud)
498 {
499         if (!pud_present(pud))
500                 return 0;
501         return pud_write(pud) ? range->flags[HMM_PFN_VALID] |
502                                 range->flags[HMM_PFN_WRITE] :
503                                 range->flags[HMM_PFN_VALID];
504 }
505
506 static int hmm_vma_handle_pmd(struct mm_walk *walk,
507                               unsigned long addr,
508                               unsigned long end,
509                               uint64_t *pfns,
510                               pmd_t pmd)
511 {
512 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
513         struct hmm_vma_walk *hmm_vma_walk = walk->private;
514         struct hmm_range *range = hmm_vma_walk->range;
515         unsigned long pfn, npages, i;
516         bool fault, write_fault;
517         uint64_t cpu_flags;
518
519         npages = (end - addr) >> PAGE_SHIFT;
520         cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
521         hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
522                              &fault, &write_fault);
523
524         if (pmd_protnone(pmd) || fault || write_fault)
525                 return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
526
527         pfn = pmd_pfn(pmd) + pte_index(addr);
528         for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
529                 if (pmd_devmap(pmd)) {
530                         hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
531                                               hmm_vma_walk->pgmap);
532                         if (unlikely(!hmm_vma_walk->pgmap))
533                                 return -EBUSY;
534                 }
535                 pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
536         }
537         if (hmm_vma_walk->pgmap) {
538                 put_dev_pagemap(hmm_vma_walk->pgmap);
539                 hmm_vma_walk->pgmap = NULL;
540         }
541         hmm_vma_walk->last = end;
542         return 0;
543 #else
544         /* If THP is not enabled then we should never reach this code! */
545         return -EINVAL;
546 #endif
547 }
548
549 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
550 {
551         if (pte_none(pte) || !pte_present(pte))
552                 return 0;
553         return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
554                                 range->flags[HMM_PFN_WRITE] :
555                                 range->flags[HMM_PFN_VALID];
556 }
557
558 static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
559                               unsigned long end, pmd_t *pmdp, pte_t *ptep,
560                               uint64_t *pfn)
561 {
562         struct hmm_vma_walk *hmm_vma_walk = walk->private;
563         struct hmm_range *range = hmm_vma_walk->range;
564         struct vm_area_struct *vma = walk->vma;
565         bool fault, write_fault;
566         uint64_t cpu_flags;
567         pte_t pte = *ptep;
568         uint64_t orig_pfn = *pfn;
569
570         *pfn = range->values[HMM_PFN_NONE];
571         fault = write_fault = false;
572
573         if (pte_none(pte)) {
574                 hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
575                                    &fault, &write_fault);
576                 if (fault || write_fault)
577                         goto fault;
578                 return 0;
579         }
580
581         if (!pte_present(pte)) {
582                 swp_entry_t entry = pte_to_swp_entry(pte);
583
584                 if (!non_swap_entry(entry)) {
                            hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
                                               &fault, &write_fault);
585                         if (fault || write_fault)
586                                 goto fault;
587                         return 0;
588                 }
589
590                 /*
591                  * This is a special swap entry, ignore migration, use
592                  * device private memory and report anything else as an error.
593                  */
594                 if (is_device_private_entry(entry)) {
595                         cpu_flags = range->flags[HMM_PFN_VALID] |
596                                 range->flags[HMM_PFN_DEVICE_PRIVATE];
597                         cpu_flags |= is_write_device_private_entry(entry) ?
598                                 range->flags[HMM_PFN_WRITE] : 0;
599                         hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
600                                            &fault, &write_fault);
601                         if (fault || write_fault)
602                                 goto fault;
603                         *pfn = hmm_device_entry_from_pfn(range,
604                                             swp_offset(entry));
605                         *pfn |= cpu_flags;
606                         return 0;
607                 }
608
609                 if (is_migration_entry(entry)) {
610                         if (fault || write_fault) {
611                                 pte_unmap(ptep);
612                                 hmm_vma_walk->last = addr;
613                                 migration_entry_wait(vma->vm_mm,
614                                                      pmdp, addr);
615                                 return -EBUSY;
616                         }
617                         return 0;
618                 }
619
620                 /* Report error for everything else */
621                 *pfn = range->values[HMM_PFN_ERROR];
622                 return -EFAULT;
623         } else {
624                 cpu_flags = pte_to_hmm_pfn_flags(range, pte);
625                 hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
626                                    &fault, &write_fault);
627         }
628
629         if (fault || write_fault)
630                 goto fault;
631
632         if (pte_devmap(pte)) {
633                 hmm_vma_walk->pgmap = get_dev_pagemap(pte_pfn(pte),
634                                               hmm_vma_walk->pgmap);
635                 if (unlikely(!hmm_vma_walk->pgmap))
636                         return -EBUSY;
637         } else if (IS_ENABLED(CONFIG_ARCH_HAS_PTE_SPECIAL) && pte_special(pte)) {
638                 *pfn = range->values[HMM_PFN_SPECIAL];
639                 return -EFAULT;
640         }
641
642         *pfn = hmm_device_entry_from_pfn(range, pte_pfn(pte)) | cpu_flags;
643         return 0;
644
645 fault:
646         if (hmm_vma_walk->pgmap) {
647                 put_dev_pagemap(hmm_vma_walk->pgmap);
648                 hmm_vma_walk->pgmap = NULL;
649         }
650         pte_unmap(ptep);
651         /* Fault any virtual address we were asked to fault */
652         return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
653 }
654
655 static int hmm_vma_walk_pmd(pmd_t *pmdp,
656                             unsigned long start,
657                             unsigned long end,
658                             struct mm_walk *walk)
659 {
660         struct hmm_vma_walk *hmm_vma_walk = walk->private;
661         struct hmm_range *range = hmm_vma_walk->range;
662         struct vm_area_struct *vma = walk->vma;
663         uint64_t *pfns = range->pfns;
664         unsigned long addr = start, i;
665         pte_t *ptep;
666         pmd_t pmd;
667
669 again:
670         pmd = READ_ONCE(*pmdp);
671         if (pmd_none(pmd))
672                 return hmm_vma_walk_hole(start, end, walk);
673
674         if (pmd_huge(pmd) && (range->vma->vm_flags & VM_HUGETLB))
675                 return hmm_pfns_bad(start, end, walk);
676
677         if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
678                 bool fault, write_fault;
679                 unsigned long npages;
680                 uint64_t *pfns;
681
682                 i = (addr - range->start) >> PAGE_SHIFT;
683                 npages = (end - addr) >> PAGE_SHIFT;
684                 pfns = &range->pfns[i];
685
686                 hmm_range_need_fault(hmm_vma_walk, pfns, npages,
687                                      0, &fault, &write_fault);
688                 if (fault || write_fault) {
689                         hmm_vma_walk->last = addr;
690                         pmd_migration_entry_wait(vma->vm_mm, pmdp);
691                         return -EBUSY;
692                 }
693                 return 0;
694         } else if (!pmd_present(pmd))
695                 return hmm_pfns_bad(start, end, walk);
696
697         if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
698                 /*
699          * No need to take the pmd lock here, even if some other thread
700          * is splitting the huge pmd we will get that event through the
701          * mmu_notifier callback.
702          *
703          * So just read the pmd value and check again if it is a transparent
704          * huge or device mapping one and compute the corresponding pfn
705          * values.
706                  */
707                 pmd = pmd_read_atomic(pmdp);
708                 barrier();
709                 if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
710                         goto again;
711
712                 i = (addr - range->start) >> PAGE_SHIFT;
713                 return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
714         }
715
716         /*
717          * We have handled all the valid cases above, i.e. either none,
718          * migration, huge or transparent huge. At this point either it is a
719          * valid pmd entry pointing to a pte directory or it is a bad pmd
720          * that will not recover.
721          */
722         if (pmd_bad(pmd))
723                 return hmm_pfns_bad(start, end, walk);
724
725         ptep = pte_offset_map(pmdp, addr);
726         i = (addr - range->start) >> PAGE_SHIFT;
727         for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
728                 int r;
729
730                 r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
731                 if (r) {
732                         /* hmm_vma_handle_pte() did unmap the pte directory */
733                         hmm_vma_walk->last = addr;
734                         return r;
735                 }
736         }
737         if (hmm_vma_walk->pgmap) {
738                 /*
739                  * We do put_dev_pagemap() here and not in hmm_vma_handle_pte()
740                  * so that we can leverage get_dev_pagemap() optimization which
741                  * will not re-take a reference on a pgmap if we already have
742                  * one.
743                  */
744                 put_dev_pagemap(hmm_vma_walk->pgmap);
745                 hmm_vma_walk->pgmap = NULL;
746         }
747         pte_unmap(ptep - 1);
748
749         hmm_vma_walk->last = addr;
750         return 0;
751 }
752
753 static int hmm_vma_walk_pud(pud_t *pudp,
754                             unsigned long start,
755                             unsigned long end,
756                             struct mm_walk *walk)
757 {
758         struct hmm_vma_walk *hmm_vma_walk = walk->private;
759         struct hmm_range *range = hmm_vma_walk->range;
760         unsigned long addr = start, next;
761         pmd_t *pmdp;
762         pud_t pud;
763         int ret;
764
765 again:
766         pud = READ_ONCE(*pudp);
767         if (pud_none(pud))
768                 return hmm_vma_walk_hole(start, end, walk);
769
770         if (pud_huge(pud) && pud_devmap(pud)) {
771                 unsigned long i, npages, pfn;
772                 uint64_t *pfns, cpu_flags;
773                 bool fault, write_fault;
774
775                 if (!pud_present(pud))
776                         return hmm_vma_walk_hole(start, end, walk);
777
778                 i = (addr - range->start) >> PAGE_SHIFT;
779                 npages = (end - addr) >> PAGE_SHIFT;
780                 pfns = &range->pfns[i];
781
782                 cpu_flags = pud_to_hmm_pfn_flags(range, pud);
783                 hmm_range_need_fault(hmm_vma_walk, pfns, npages,
784                                      cpu_flags, &fault, &write_fault);
785                 if (fault || write_fault)
786                         return hmm_vma_walk_hole_(addr, end, fault,
787                                                 write_fault, walk);
788
789                 pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
790                 for (i = 0; i < npages; ++i, ++pfn) {
791                         hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
792                                               hmm_vma_walk->pgmap);
793                         if (unlikely(!hmm_vma_walk->pgmap))
794                                 return -EBUSY;
795                         pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
796                                   cpu_flags;
797                 }
798                 if (hmm_vma_walk->pgmap) {
799                         put_dev_pagemap(hmm_vma_walk->pgmap);
800                         hmm_vma_walk->pgmap = NULL;
801                 }
802                 hmm_vma_walk->last = end;
803                 return 0;
804         }
805
806         split_huge_pud(walk->vma, pudp, addr);
807         if (pud_none(*pudp))
808                 goto again;
809
810         pmdp = pmd_offset(pudp, addr);
811         do {
812                 next = pmd_addr_end(addr, end);
813                 ret = hmm_vma_walk_pmd(pmdp, addr, next, walk);
814                 if (ret)
815                         return ret;
816         } while (pmdp++, addr = next, addr != end);
817
818         return 0;
819 }
820
821 static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
822                                       unsigned long start, unsigned long end,
823                                       struct mm_walk *walk)
824 {
825 #ifdef CONFIG_HUGETLB_PAGE
826         unsigned long addr = start, i, pfn, mask, size, pfn_inc;
827         struct hmm_vma_walk *hmm_vma_walk = walk->private;
828         struct hmm_range *range = hmm_vma_walk->range;
829         struct vm_area_struct *vma = walk->vma;
830         struct hstate *h = hstate_vma(vma);
831         uint64_t orig_pfn, cpu_flags;
832         bool fault, write_fault;
833         spinlock_t *ptl;
834         pte_t entry;
835         int ret = 0;
836
837         size = 1UL << huge_page_shift(h);
838         mask = size - 1;
839         if (range->page_shift != PAGE_SHIFT) {
840                 /* Make sure we are looking at a full page. */
841                 if (start & mask)
842                         return -EINVAL;
843                 if (end < (start + size))
844                         return -EINVAL;
845                 pfn_inc = size >> PAGE_SHIFT;
846         } else {
847                 pfn_inc = 1;
848                 size = PAGE_SIZE;
849         }
850
852         ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
853         entry = huge_ptep_get(pte);
854
855         i = (start - range->start) >> range->page_shift;
856         orig_pfn = range->pfns[i];
857         range->pfns[i] = range->values[HMM_PFN_NONE];
858         cpu_flags = pte_to_hmm_pfn_flags(range, entry);
859         fault = write_fault = false;
860         hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
861                            &fault, &write_fault);
862         if (fault || write_fault) {
863                 ret = -ENOENT;
864                 goto unlock;
865         }
866
867         pfn = pte_pfn(entry) + ((start & mask) >> range->page_shift);
868         for (; addr < end; addr += size, i++, pfn += pfn_inc)
869                 range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
870                                  cpu_flags;
871         hmm_vma_walk->last = end;
872
873 unlock:
874         spin_unlock(ptl);
875
876         if (ret == -ENOENT)
877                 return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
878
879         return ret;
880 #else /* CONFIG_HUGETLB_PAGE */
881         return -EINVAL;
882 #endif
883 }
884
885 static void hmm_pfns_clear(struct hmm_range *range,
886                            uint64_t *pfns,
887                            unsigned long addr,
888                            unsigned long end)
889 {
890         for (; addr < end; addr += PAGE_SIZE, pfns++)
891                 *pfns = range->values[HMM_PFN_NONE];
892 }
893
894 /**
895  * hmm_range_register() - start tracking changes to CPU page table over a range
896  * @range: range
897  * @mm: the mm struct for the range of virtual addresses
898  * @start: start virtual address (inclusive)
899  * @end: end virtual address (exclusive)
900  * @page_shift: expected page shift for the range
901  * Return: 0 on success, -EFAULT if the address space is no longer valid
902  *
903  * Track updates to the CPU page table, see include/linux/hmm.h.
904  */
905 int hmm_range_register(struct hmm_range *range,
906                        struct mm_struct *mm,
907                        unsigned long start,
908                        unsigned long end,
909                        unsigned page_shift)
910 {
911         unsigned long mask = ((1UL << page_shift) - 1UL);
912         struct hmm *hmm;
913
914         range->valid = false;
915         range->hmm = NULL;
916
917         if ((start & mask) || (end & mask))
918                 return -EINVAL;
919         if (start >= end)
920                 return -EINVAL;
921
922         range->page_shift = page_shift;
923         range->start = start;
924         range->end = end;
925
926         hmm = hmm_get_or_create(mm);
927         if (!hmm)
928                 return -EFAULT;
929
930         /* Check if hmm_mm_destroy() was called. */
931         if (hmm->mm == NULL || hmm->dead) {
932                 hmm_put(hmm);
933                 return -EFAULT;
934         }
935
936         /* Initialize range to track CPU page table updates. */
937         mutex_lock(&hmm->lock);
938
939         range->hmm = hmm;
940         list_add_rcu(&range->list, &hmm->ranges);
941
942         /*
943          * If there are any concurrent notifiers we have to wait for them to
944          * finish before the range can be valid (see hmm_range_wait_until_valid()).
945          */
946         if (!hmm->notifiers)
947                 range->valid = true;
948         mutex_unlock(&hmm->lock);
949
950         return 0;
951 }
952 EXPORT_SYMBOL(hmm_range_register);
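
/*
 * A hedged sketch of registering a range (a fuller retry loop is sketched
 * after hmm_range_snapshot() below). The flags/values tables and the pfns[]
 * array are provided by the driver; hmm_range_wait_until_valid() is the
 * helper declared in include/linux/hmm.h and the timeout value used here is
 * only an example (assumed to be in milliseconds):
 *
 *   range.pfns = pfns;
 *   range.flags = my_hmm_range_flags;
 *   range.values = my_hmm_range_values;
 *   ret = hmm_range_register(&range, mm, start, end, PAGE_SHIFT);
 *   if (ret)
 *           return ret;
 *   if (!hmm_range_wait_until_valid(&range, 1000)) {
 *           hmm_range_unregister(&range);
 *           ... the mm is dead or an invalidation kept running, bail or retry ...
 *   }
 */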
953
954 /**
955  * hmm_range_unregister() - stop tracking changes to CPU page table over a range
956  * @range: range
957  *
958  * Range struct is used to track updates to the CPU page table after a call to
959  * hmm_range_register(). See include/linux/hmm.h for how to use it.
960  */
961 void hmm_range_unregister(struct hmm_range *range)
962 {
963         struct hmm *hmm = range->hmm;
964
965         /* Sanity check, this really should not happen. */
966         if (hmm == NULL || range->end <= range->start)
967                 return;
968
969         mutex_lock(&hmm->lock);
970         list_del_rcu(&range->list);
971         mutex_unlock(&hmm->lock);
972
973         /* Drop reference taken by hmm_range_register() */
974         range->valid = false;
975         hmm_put(hmm);
976         range->hmm = NULL;
977 }
978 EXPORT_SYMBOL(hmm_range_unregister);
979
980 /**
981  * hmm_range_snapshot() - snapshot CPU page table for a range
982  * @range: range
983  * Return: number of valid pages in range->pfns[] (from range start address)
984  *          on success, -EINVAL if invalid argument, -ENOMEM if out of
985  *          memory, -EPERM if invalid permission (for instance asking for
986  *          write and range is read only), -EAGAIN if you need to retry,
987  *          -EFAULT if invalid (ie either no valid vma or it is illegal to
988  *          access that range).
989  *
990  * This snapshots the CPU page table for a range of virtual addresses. Snapshot
991  * validity is tracked by the range struct. See include/linux/hmm.h for usage.
992  */
993 long hmm_range_snapshot(struct hmm_range *range)
994 {
995         const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
996         unsigned long start = range->start, end;
997         struct hmm_vma_walk hmm_vma_walk;
998         struct hmm *hmm = range->hmm;
999         struct vm_area_struct *vma;
1000         struct mm_walk mm_walk;
1001
1002         /* Check if hmm_mm_destroy() was called. */
1003         if (hmm->mm == NULL || hmm->dead)
1004                 return -EFAULT;
1005
1006         do {
1007                 /* If range is no longer valid force retry. */
1008                 if (!range->valid)
1009                         return -EAGAIN;
1010
1011                 vma = find_vma(hmm->mm, start);
1012                 if (vma == NULL || (vma->vm_flags & device_vma))
1013                         return -EFAULT;
1014
1015                 if (is_vm_hugetlb_page(vma)) {
1016                         if (huge_page_shift(hstate_vma(vma)) !=
1017                                     range->page_shift &&
1018                             range->page_shift != PAGE_SHIFT)
1019                                 return -EINVAL;
1020                 } else {
1021                         if (range->page_shift != PAGE_SHIFT)
1022                                 return -EINVAL;
1023                 }
1024
1025                 if (!(vma->vm_flags & VM_READ)) {
1026                         /*
1027                          * If the vma does not allow read access, then assume that
1028                          * it does not allow write access either. HMM does not
1029                          * support architectures that allow write without read.
1030                          */
1031                         hmm_pfns_clear(range, range->pfns,
1032                                 range->start, range->end);
1033                         return -EPERM;
1034                 }
1035
1036                 range->vma = vma;
1037                 hmm_vma_walk.pgmap = NULL;
1038                 hmm_vma_walk.last = start;
1039                 hmm_vma_walk.fault = false;
1040                 hmm_vma_walk.range = range;
1041                 mm_walk.private = &hmm_vma_walk;
1042                 end = min(range->end, vma->vm_end);
1043
1044                 mm_walk.vma = vma;
1045                 mm_walk.mm = vma->vm_mm;
1046                 mm_walk.pte_entry = NULL;
1047                 mm_walk.test_walk = NULL;
1049                 mm_walk.pud_entry = hmm_vma_walk_pud;
1050                 mm_walk.pmd_entry = hmm_vma_walk_pmd;
1051                 mm_walk.pte_hole = hmm_vma_walk_hole;
1052                 mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
1053
1054                 walk_page_range(start, end, &mm_walk);
1055                 start = end;
1056         } while (start < range->end);
1057
1058         return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1059 }
1060 EXPORT_SYMBOL(hmm_range_snapshot);
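
/*
 * A hedged usage sketch for hmm_range_snapshot(), modeled on the example
 * referenced in include/linux/hmm.h. The driver-side lock names are
 * hypothetical; the important points are re-checking range.valid under the
 * driver lock and retrying on -EAGAIN:
 *
 *   again:
 *           down_read(&mm->mmap_sem);
 *           ret = hmm_range_snapshot(&range);
 *           if (ret < 0) {
 *                   up_read(&mm->mmap_sem);
 *                   if (ret == -EAGAIN) {
 *                           ... wait for the range to become valid again ...
 *                           goto again;
 *                   }
 *                   return ret;
 *           }
 *           take_driver_page_table_lock();
 *           if (!range.valid) {
 *                   release_driver_page_table_lock();
 *                   up_read(&mm->mmap_sem);
 *                   goto again;
 *           }
 *           ... use range.pfns[] to update the device page tables ...
 *           release_driver_page_table_lock();
 *           up_read(&mm->mmap_sem);
 */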
1061
1062 /**
1063  * hmm_range_fault() - try to fault some address in a virtual address range
1064  * @range: range being faulted
1065  * @block: allow blocking on fault (if true it sleeps and does not drop mmap_sem)
1066  * Return: number of valid pages in range->pfns[] (from range start
1067  *          address). This may be zero. If the return value is negative,
1068  *          then one of the following values may be returned:
1069  *
1070  *           -EINVAL: Invalid arguments or mm or virtual address is in an
1071  *                    invalid vma (for instance device file vma).
1072  *           -ENOMEM: Out of memory.
1073  *           -EPERM:  Invalid permission (for instance asking for write and
1074  *                    range is read only).
1075  *           -EAGAIN: If you need to retry and mmap_sem was dropped. This can
1076  *                    only happen if the block argument is false.
1077  *           -EBUSY:  If the range is being invalidated and you should wait
1078  *                    for invalidation to finish.
1079  *           -EFAULT: Invalid (ie either no valid vma or it is illegal to
1080  *                    access that range), number of valid pages in
1081  *                    range->pfns[] (from range start address).
1082  *
1083  * This is similar to a regular CPU page fault except that it will not trigger
1084  * any memory migration if the memory being faulted is not accessible by CPUs
1085  * and the caller does not ask for migration.
1086  *
1087  * On error, for one virtual address in the range, the function will mark the
1088  * corresponding HMM pfn entry with an error flag.
1089  */
1090 long hmm_range_fault(struct hmm_range *range, bool block)
1091 {
1092         const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
1093         unsigned long start = range->start, end;
1094         struct hmm_vma_walk hmm_vma_walk;
1095         struct hmm *hmm = range->hmm;
1096         struct vm_area_struct *vma;
1097         struct mm_walk mm_walk;
1098         int ret;
1099
1100         /* Check if hmm_mm_destroy() was called. */
1101         if (hmm->mm == NULL || hmm->dead)
1102                 return -EFAULT;
1103
1104         do {
1105                 /* If range is no longer valid force retry. */
1106                 if (!range->valid) {
1107                         up_read(&hmm->mm->mmap_sem);
1108                         return -EAGAIN;
1109                 }
1110
1111                 vma = find_vma(hmm->mm, start);
1112                 if (vma == NULL || (vma->vm_flags & device_vma))
1113                         return -EFAULT;
1114
1115                 if (is_vm_hugetlb_page(vma)) {
1116                         if (huge_page_shift(hstate_vma(vma)) !=
1117                             range->page_shift &&
1118                             range->page_shift != PAGE_SHIFT)
1119                                 return -EINVAL;
1120                 } else {
1121                         if (range->page_shift != PAGE_SHIFT)
1122                                 return -EINVAL;
1123                 }
1124
1125                 if (!(vma->vm_flags & VM_READ)) {
1126                         /*
1127                          * If the vma does not allow read access, then assume that
1128                          * it does not allow write access either. HMM does not
1129                          * support architectures that allow write without read.
1130                          */
1131                         hmm_pfns_clear(range, range->pfns,
1132                                 range->start, range->end);
1133                         return -EPERM;
1134                 }
1135
1136                 range->vma = vma;
1137                 hmm_vma_walk.pgmap = NULL;
1138                 hmm_vma_walk.last = start;
1139                 hmm_vma_walk.fault = true;
1140                 hmm_vma_walk.block = block;
1141                 hmm_vma_walk.range = range;
1142                 mm_walk.private = &hmm_vma_walk;
1143                 end = min(range->end, vma->vm_end);
1144
1145                 mm_walk.vma = vma;
1146                 mm_walk.mm = vma->vm_mm;
1147                 mm_walk.pte_entry = NULL;
1148                 mm_walk.test_walk = NULL;
1150                 mm_walk.pud_entry = hmm_vma_walk_pud;
1151                 mm_walk.pmd_entry = hmm_vma_walk_pmd;
1152                 mm_walk.pte_hole = hmm_vma_walk_hole;
1153                 mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
1154
1155                 do {
1156                         ret = walk_page_range(start, end, &mm_walk);
1157                         start = hmm_vma_walk.last;
1158
1159                         /* Keep trying while the range is valid. */
1160                 } while (ret == -EBUSY && range->valid);
1161
1162                 if (ret) {
1163                         unsigned long i;
1164
1165                         i = (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1166                         hmm_pfns_clear(range, &range->pfns[i],
1167                                 hmm_vma_walk.last, range->end);
1168                         return ret;
1169                 }
1170                 start = end;
1171
1172         } while (start < range->end);
1173
1174         return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1175 }
1176 EXPORT_SYMBOL(hmm_range_fault);
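
/*
 * A hedged sketch for a faulting user (everything except the hmm_* calls is
 * hypothetical). Note that when hmm_range_fault() returns -EAGAIN it has
 * already dropped mmap_sem itself (see the !range->valid check above), so the
 * caller must not drop it a second time:
 *
 *   down_read(&mm->mmap_sem);
 *   ret = hmm_range_fault(&range, true);
 *   if (ret < 0) {
 *           if (ret != -EAGAIN)
 *                   up_read(&mm->mmap_sem);
 *           hmm_range_unregister(&range);
 *           return ret;
 *   }
 *   ... translate range.pfns[] with hmm_device_entry_to_page() while
 *       range.valid still holds, then drop mmap_sem ...
 */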
1177
1178 /**
1179  * hmm_range_dma_map() - hmm_range_fault() and dma map pages all in one.
1180  * @range: range being faulted
1181  * @device: device to dma map pages to
1182  * @daddrs: dma addresses of mapped pages
1183  * @block: allow blocking on fault (if true it sleeps and does not drop mmap_sem)
1184  * Return: number of pages mapped on success, -EAGAIN if mmap_sem has been
1185  *          dropped and you need to try again, some other error value otherwise
1186  *
1187  * Note same usage pattern as hmm_range_fault().
1188  */
1189 long hmm_range_dma_map(struct hmm_range *range,
1190                        struct device *device,
1191                        dma_addr_t *daddrs,
1192                        bool block)
1193 {
1194         unsigned long i, npages, mapped;
1195         long ret;
1196
1197         ret = hmm_range_fault(range, block);
1198         if (ret <= 0)
1199                 return ret ? ret : -EBUSY;
1200
1201         npages = (range->end - range->start) >> PAGE_SHIFT;
1202         for (i = 0, mapped = 0; i < npages; ++i) {
1203                 enum dma_data_direction dir = DMA_TO_DEVICE;
1204                 struct page *page;
1205
1206                 /*
1207                  * FIXME need to update DMA API to provide invalid DMA address
1208                  * value instead of a function to test dma address value. This
1209                  * would remove a lot of dumb code duplicated across many archs.
1210                  *
1211                  * For now setting it to 0 here is good enough as the pfns[]
1212                  * value is what is used to check what is valid and what isn't.
1213                  */
1214                 daddrs[i] = 0;
1215
1216                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1217                 if (page == NULL)
1218                         continue;
1219
1220                 /* Check if range is being invalidated */
1221                 if (!range->valid) {
1222                         ret = -EBUSY;
1223                         goto unmap;
1224                 }
1225
1226                 /* If it is read and write then map bi-directional. */
1227                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE])
1228                         dir = DMA_BIDIRECTIONAL;
1229
1230                 daddrs[i] = dma_map_page(device, page, 0, PAGE_SIZE, dir);
1231                 if (dma_mapping_error(device, daddrs[i])) {
1232                         ret = -EFAULT;
1233                         goto unmap;
1234                 }
1235
1236                 mapped++;
1237         }
1238
1239         return mapped;
1240
1241 unmap:
1242         for (npages = i, i = 0; (i < npages) && mapped; ++i) {
1243                 enum dma_data_direction dir = DMA_TO_DEVICE;
1244                 struct page *page;
1245
1246                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1247                 if (page == NULL)
1248                         continue;
1249
1250                 if (dma_mapping_error(device, daddrs[i]))
1251                         continue;
1252
1253                 /* If it is read and write then map bi-directional. */
1254                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE])
1255                         dir = DMA_BIDIRECTIONAL;
1256
1257                 dma_unmap_page(device, daddrs[i], PAGE_SIZE, dir);
1258                 mapped--;
1259         }
1260
1261         return ret;
1262 }
1263 EXPORT_SYMBOL(hmm_range_dma_map);
1264
1265 /**
1266  * hmm_range_dma_unmap() - unmap a range that was mapped with hmm_range_dma_map()
1267  * @range: range being unmapped
1268  * @vma: the vma against which the range is mapped (optional)
1269  * @device: device against which dma map was done
1270  * @daddrs: dma addresses of mapped pages
1271  * @dirty: dirty page if it had the write flag set
1272  * Return: number of pages unmapped on success, -EINVAL otherwise
1273  *
1274  * Note that the caller MUST abide by the mmu notifier or use an HMM mirror
1275  * and abide by the sync_cpu_device_pagetables() callback so that it is safe
1276  * here to call set_page_dirty(). The caller must also take locks that prevent
1277  * concurrent mmu notifier or sync_cpu_device_pagetables() calls from progressing.
1278  */
1279 long hmm_range_dma_unmap(struct hmm_range *range,
1280                          struct vm_area_struct *vma,
1281                          struct device *device,
1282                          dma_addr_t *daddrs,
1283                          bool dirty)
1284 {
1285         unsigned long i, npages;
1286         long cpages = 0;
1287
1288         /* Sanity check. */
1289         if (range->end <= range->start)
1290                 return -EINVAL;
1291         if (!daddrs)
1292                 return -EINVAL;
1293         if (!range->pfns)
1294                 return -EINVAL;
1295
1296         npages = (range->end - range->start) >> PAGE_SHIFT;
1297         for (i = 0; i < npages; ++i) {
1298                 enum dma_data_direction dir = DMA_TO_DEVICE;
1299                 struct page *page;
1300
1301                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1302                 if (page == NULL)
1303                         continue;
1304
1305                 /* If it is read and write then map bi-directional. */
1306                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE]) {
1307                         dir = DMA_BIDIRECTIONAL;
1308
1309                         /*
1310                          * See the comments in the function description on why it
1311                          * is safe here to call set_page_dirty().
1312                          */
1313                         if (dirty)
1314                                 set_page_dirty(page);
1315                 }
1316
1317                 /* Unmap and clear pfns/dma address */
1318                 dma_unmap_page(device, daddrs[i], PAGE_SIZE, dir);
1319                 range->pfns[i] = range->values[HMM_PFN_NONE];
1320                 /* FIXME see comments in hmm_range_dma_map() */
1321                 daddrs[i] = 0;
1322                 cpages++;
1323         }
1324
1325         return cpages;
1326 }
1327 EXPORT_SYMBOL(hmm_range_dma_unmap);
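
/*
 * A hedged sketch tying hmm_range_dma_map() and hmm_range_dma_unmap()
 * together; allocation sizes, error paths and the dev pointer are
 * hypothetical driver details:
 *
 *   npages = (range.end - range.start) >> PAGE_SHIFT;
 *   daddrs = kcalloc(npages, sizeof(*daddrs), GFP_KERNEL);
 *
 *   down_read(&mm->mmap_sem);
 *   mapped = hmm_range_dma_map(&range, dev, daddrs, true);
 *   up_read(&mm->mmap_sem);
 *   if (mapped < 0)
 *           goto out;
 *
 *   ... program the device with daddrs[] and let it run; once the device is
 *       done, and while holding the locks that serialize against
 *       sync_cpu_device_pagetables() (see the comment above) ...
 *
 *   hmm_range_dma_unmap(&range, NULL, dev, daddrs, true);
 */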
1328 #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
1329
1330
1331 #if IS_ENABLED(CONFIG_DEVICE_PRIVATE) ||  IS_ENABLED(CONFIG_DEVICE_PUBLIC)
1332 struct page *hmm_vma_alloc_locked_page(struct vm_area_struct *vma,
1333                                        unsigned long addr)
1334 {
1335         struct page *page;
1336
1337         page = alloc_page_vma(GFP_HIGHUSER, vma, addr);
1338         if (!page)
1339                 return NULL;
1340         lock_page(page);
1341         return page;
1342 }
1343 EXPORT_SYMBOL(hmm_vma_alloc_locked_page);
1344
1346 static void hmm_devmem_ref_release(struct percpu_ref *ref)
1347 {
1348         struct hmm_devmem *devmem;
1349
1350         devmem = container_of(ref, struct hmm_devmem, ref);
1351         complete(&devmem->completion);
1352 }
1353
1354 static void hmm_devmem_ref_exit(void *data)
1355 {
1356         struct percpu_ref *ref = data;
1357         struct hmm_devmem *devmem;
1358
1359         devmem = container_of(ref, struct hmm_devmem, ref);
1360         wait_for_completion(&devmem->completion);
1361         percpu_ref_exit(ref);
1362 }
1363
1364 static void hmm_devmem_ref_kill(struct percpu_ref *ref)
1365 {
1366         percpu_ref_kill(ref);
1367 }
1368
1369 static vm_fault_t hmm_devmem_fault(struct vm_area_struct *vma,
1370                             unsigned long addr,
1371                             const struct page *page,
1372                             unsigned int flags,
1373                             pmd_t *pmdp)
1374 {
1375         struct hmm_devmem *devmem = page->pgmap->data;
1376
1377         return devmem->ops->fault(devmem, vma, addr, page, flags, pmdp);
1378 }
1379
1380 static void hmm_devmem_free(struct page *page, void *data)
1381 {
1382         struct hmm_devmem *devmem = data;
1383
1384         page->mapping = NULL;
1385
1386         devmem->ops->free(devmem, page);
1387 }
1388
1389 /**
1390  * hmm_devmem_add() - hotplug ZONE_DEVICE memory for device memory
1391  *
1392  * @ops: memory event device driver callback (see struct hmm_devmem_ops)
1393  * @device: device struct to bind the resource to
1394  * @size: size in bytes of the device memory to add
1395  * Return: pointer to new hmm_devmem struct, ERR_PTR otherwise
1396  *
1397  * This function first finds an empty range of physical addresses big enough to
1398  * contain the new resource, and then hotplugs it as ZONE_DEVICE memory, which
1399  * in turn allocates struct pages. It does not do anything beyond that; all
1400  * events affecting the memory will go through the various callbacks provided
1401  * by hmm_devmem_ops struct.
1402  *
1403  * The device driver should call this function during device initialization and
1404  * is then responsible for the memory management. HMM only provides helpers.
1405  */
1406 struct hmm_devmem *hmm_devmem_add(const struct hmm_devmem_ops *ops,
1407                                   struct device *device,
1408                                   unsigned long size)
1409 {
1410         struct hmm_devmem *devmem;
1411         resource_size_t addr;
1412         void *result;
1413         int ret;
1414
1415         dev_pagemap_get_ops();
1416
1417         devmem = devm_kzalloc(device, sizeof(*devmem), GFP_KERNEL);
1418         if (!devmem)
1419                 return ERR_PTR(-ENOMEM);
1420
1421         init_completion(&devmem->completion);
1422         devmem->pfn_first = -1UL;
1423         devmem->pfn_last = -1UL;
1424         devmem->resource = NULL;
1425         devmem->device = device;
1426         devmem->ops = ops;
1427
1428         ret = percpu_ref_init(&devmem->ref, &hmm_devmem_ref_release,
1429                               0, GFP_KERNEL);
1430         if (ret)
1431                 return ERR_PTR(ret);
1432
1433         ret = devm_add_action_or_reset(device, hmm_devmem_ref_exit, &devmem->ref);
1434         if (ret)
1435                 return ERR_PTR(ret);
1436
1437         size = ALIGN(size, PA_SECTION_SIZE);
1438         addr = min((unsigned long)iomem_resource.end,
1439                    (1UL << MAX_PHYSMEM_BITS) - 1);
1440         addr = addr - size + 1UL;
1441
1442         /*
1443          * FIXME add a new helper to quickly walk the resource tree and find a
1444          * free range
1445          *
1446          * FIXME what about the ioport_resource resource?
1447          */
1448         for (; addr > size && addr >= iomem_resource.start; addr -= size) {
1449                 ret = region_intersects(addr, size, 0, IORES_DESC_NONE);
1450                 if (ret != REGION_DISJOINT)
1451                         continue;
1452
1453                 devmem->resource = devm_request_mem_region(device, addr, size,
1454                                                            dev_name(device));
1455                 if (!devmem->resource)
1456                         return ERR_PTR(-ENOMEM);
1457                 break;
1458         }
1459         if (!devmem->resource)
1460                 return ERR_PTR(-ERANGE);
1461
1462         devmem->resource->desc = IORES_DESC_DEVICE_PRIVATE_MEMORY;
1463         devmem->pfn_first = devmem->resource->start >> PAGE_SHIFT;
1464         devmem->pfn_last = devmem->pfn_first +
1465                            (resource_size(devmem->resource) >> PAGE_SHIFT);
1466         devmem->page_fault = hmm_devmem_fault;
1467
1468         devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
1469         devmem->pagemap.res = *devmem->resource;
1470         devmem->pagemap.page_free = hmm_devmem_free;
1471         devmem->pagemap.altmap_valid = false;
1472         devmem->pagemap.ref = &devmem->ref;
1473         devmem->pagemap.data = devmem;
1474         devmem->pagemap.kill = hmm_devmem_ref_kill;
1475
1476         result = devm_memremap_pages(devmem->device, &devmem->pagemap);
1477         if (IS_ERR(result))
1478                 return result;
1479         return devmem;
1480 }
1481 EXPORT_SYMBOL_GPL(hmm_devmem_add);
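
/*
 * A hedged sketch of a driver hotplugging device private memory. The ops
 * layout (.fault and .free) and the callback signatures are taken from the
 * calls above; the bodies, the size and all my_* names are hypothetical:
 *
 *   static vm_fault_t my_devmem_fault(struct hmm_devmem *devmem,
 *                                     struct vm_area_struct *vma,
 *                                     unsigned long addr,
 *                                     const struct page *page,
 *                                     unsigned int flags, pmd_t *pmdp)
 *   {
 *           ... migrate the data backing @page back to system memory ...
 *   }
 *
 *   static void my_devmem_free(struct hmm_devmem *devmem, struct page *page)
 *   {
 *           ... give the device memory backing @page back to the driver's
 *               allocator ...
 *   }
 *
 *   static const struct hmm_devmem_ops my_devmem_ops = {
 *           .fault = my_devmem_fault,
 *           .free  = my_devmem_free,
 *   };
 *
 *   devmem = hmm_devmem_add(&my_devmem_ops, &pdev->dev, SZ_1G);
 *   if (IS_ERR(devmem))
 *           return PTR_ERR(devmem);
 */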
1482
1483 struct hmm_devmem *hmm_devmem_add_resource(const struct hmm_devmem_ops *ops,
1484                                            struct device *device,
1485                                            struct resource *res)
1486 {
1487         struct hmm_devmem *devmem;
1488         void *result;
1489         int ret;
1490
1491         if (res->desc != IORES_DESC_DEVICE_PUBLIC_MEMORY)
1492                 return ERR_PTR(-EINVAL);
1493
1494         dev_pagemap_get_ops();
1495
1496         devmem = devm_kzalloc(device, sizeof(*devmem), GFP_KERNEL);
1497         if (!devmem)
1498                 return ERR_PTR(-ENOMEM);
1499
1500         init_completion(&devmem->completion);
1501         devmem->pfn_first = -1UL;
1502         devmem->pfn_last = -1UL;
1503         devmem->resource = res;
1504         devmem->device = device;
1505         devmem->ops = ops;
1506
1507         ret = percpu_ref_init(&devmem->ref, &hmm_devmem_ref_release,
1508                               0, GFP_KERNEL);
1509         if (ret)
1510                 return ERR_PTR(ret);
1511
1512         ret = devm_add_action_or_reset(device, hmm_devmem_ref_exit,
1513                         &devmem->ref);
1514         if (ret)
1515                 return ERR_PTR(ret);
1516
1517         devmem->pfn_first = devmem->resource->start >> PAGE_SHIFT;
1518         devmem->pfn_last = devmem->pfn_first +
1519                            (resource_size(devmem->resource) >> PAGE_SHIFT);
1520         devmem->page_fault = hmm_devmem_fault;
1521
1522         devmem->pagemap.type = MEMORY_DEVICE_PUBLIC;
1523         devmem->pagemap.res = *devmem->resource;
1524         devmem->pagemap.page_free = hmm_devmem_free;
1525         devmem->pagemap.altmap_valid = false;
1526         devmem->pagemap.ref = &devmem->ref;
1527         devmem->pagemap.data = devmem;
1528         devmem->pagemap.kill = hmm_devmem_ref_kill;
1529
1530         result = devm_memremap_pages(devmem->device, &devmem->pagemap);
1531         if (IS_ERR(result))
1532                 return result;
1533         return devmem;
1534 }
1535 EXPORT_SYMBOL_GPL(hmm_devmem_add_resource);
1536
1537 /*
1538  * A device driver that wants to handle memory from multiple devices through a
1539  * single fake device can use hmm_device to do so. This is purely a helper and
1540  * it is not needed in order to make use of any HMM functionality.
1541  */
1542 #define HMM_DEVICE_MAX 256
1543
1544 static DECLARE_BITMAP(hmm_device_mask, HMM_DEVICE_MAX);
1545 static DEFINE_SPINLOCK(hmm_device_lock);
1546 static struct class *hmm_device_class;
1547 static dev_t hmm_device_devt;
1548
1549 static void hmm_device_release(struct device *device)
1550 {
1551         struct hmm_device *hmm_device;
1552
1553         hmm_device = container_of(device, struct hmm_device, device);
1554         spin_lock(&hmm_device_lock);
1555         clear_bit(hmm_device->minor, hmm_device_mask);
1556         spin_unlock(&hmm_device_lock);
1557
1558         kfree(hmm_device);
1559 }
1560
1561 struct hmm_device *hmm_device_new(void *drvdata)
1562 {
1563         struct hmm_device *hmm_device;
1564
1565         hmm_device = kzalloc(sizeof(*hmm_device), GFP_KERNEL);
1566         if (!hmm_device)
1567                 return ERR_PTR(-ENOMEM);
1568
1569         spin_lock(&hmm_device_lock);
1570         hmm_device->minor = find_first_zero_bit(hmm_device_mask, HMM_DEVICE_MAX);
1571         if (hmm_device->minor >= HMM_DEVICE_MAX) {
1572                 spin_unlock(&hmm_device_lock);
1573                 kfree(hmm_device);
1574                 return ERR_PTR(-EBUSY);
1575         }
1576         set_bit(hmm_device->minor, hmm_device_mask);
1577         spin_unlock(&hmm_device_lock);
1578
1579         dev_set_name(&hmm_device->device, "hmm_device%d", hmm_device->minor);
1580         hmm_device->device.devt = MKDEV(MAJOR(hmm_device_devt),
1581                                         hmm_device->minor);
1582         hmm_device->device.release = hmm_device_release;
1583         dev_set_drvdata(&hmm_device->device, drvdata);
1584         hmm_device->device.class = hmm_device_class;
1585         device_initialize(&hmm_device->device);
1586
1587         return hmm_device;
1588 }
1589 EXPORT_SYMBOL(hmm_device_new);
1590
1591 void hmm_device_put(struct hmm_device *hmm_device)
1592 {
1593         put_device(&hmm_device->device);
1594 }
1595 EXPORT_SYMBOL(hmm_device_put);
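
/*
 * A hedged sketch of the fake device helper; my_drvdata and the cleanup path
 * are hypothetical:
 *
 *   hmm_device = hmm_device_new(my_drvdata);
 *   if (IS_ERR(hmm_device))
 *           return PTR_ERR(hmm_device);
 *
 *   ... use &hmm_device->device as the struct device to hang device memory
 *       (for instance hmm_devmem_add()) or other devm resources on ...
 *
 *   hmm_device_put(hmm_device);
 */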
1596
1597 static int __init hmm_init(void)
1598 {
1599         int ret;
1600
1601         ret = alloc_chrdev_region(&hmm_device_devt, 0,
1602                                   HMM_DEVICE_MAX,
1603                                   "hmm_device");
1604         if (ret)
1605                 return ret;
1606
1607         hmm_device_class = class_create(THIS_MODULE, "hmm_device");
1608         if (IS_ERR(hmm_device_class)) {
1609                 unregister_chrdev_region(hmm_device_devt, HMM_DEVICE_MAX);
1610                 return PTR_ERR(hmm_device_class);
1611         }
1612         return 0;
1613 }
1614
1615 device_initcall(hmm_init);
1616 #endif /* CONFIG_DEVICE_PRIVATE || CONFIG_DEVICE_PUBLIC */