1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Copyright 2013 Red Hat Inc.
4  *
5  * Authors: Jérôme Glisse <jglisse@redhat.com>
6  */
7 /*
8  * Refer to include/linux/hmm.h for information about heterogeneous memory
9  * management or HMM for short.
10  */
11 #include <linux/mm.h>
12 #include <linux/hmm.h>
13 #include <linux/init.h>
14 #include <linux/rmap.h>
15 #include <linux/swap.h>
16 #include <linux/slab.h>
17 #include <linux/sched.h>
18 #include <linux/mmzone.h>
19 #include <linux/pagemap.h>
20 #include <linux/swapops.h>
21 #include <linux/hugetlb.h>
22 #include <linux/memremap.h>
23 #include <linux/sched/mm.h>
24 #include <linux/jump_label.h>
25 #include <linux/dma-mapping.h>
26 #include <linux/mmu_notifier.h>
27 #include <linux/memory_hotplug.h>
28
29 #define PA_SECTION_SIZE (1UL << PA_SECTION_SHIFT)
30
31 #if IS_ENABLED(CONFIG_HMM_MIRROR)
32 static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
33
34 /**
35  * hmm_get_or_create - register HMM against an mm (HMM internal)
36  *
37  * @mm: mm struct to attach to
38  * Return: an HMM object, either by taking a reference on the existing
39  *          (per-process) object or by creating a new one.
40  *
41  * This is not intended to be used directly by device drivers. If mm already
42  * has an HMM struct then this takes a reference on it and returns it.
43  * Otherwise it allocates an HMM struct, initializes it, associates it with
44  * the mm, and returns it.
45  */
46 static struct hmm *hmm_get_or_create(struct mm_struct *mm)
47 {
48         struct hmm *hmm;
49
50         lockdep_assert_held_exclusive(&mm->mmap_sem);
51
52         /* Abuse the page_table_lock to also protect mm->hmm. */
53         spin_lock(&mm->page_table_lock);
54         hmm = mm->hmm;
55         if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
56                 goto out_unlock;
57         spin_unlock(&mm->page_table_lock);
58
59         hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
60         if (!hmm)
61                 return NULL;
62         init_waitqueue_head(&hmm->wq);
63         INIT_LIST_HEAD(&hmm->mirrors);
64         init_rwsem(&hmm->mirrors_sem);
65         hmm->mmu_notifier.ops = NULL;
66         INIT_LIST_HEAD(&hmm->ranges);
67         mutex_init(&hmm->lock);
68         kref_init(&hmm->kref);
69         hmm->notifiers = 0;
70         hmm->mm = mm;
71
72         hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
73         if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
74                 kfree(hmm);
75                 return NULL;
76         }
77
78         mmgrab(hmm->mm);
79
80         /*
81          * We hold the exclusive mmap_sem here so we know that mm->hmm is
82          * still NULL or 0 kref, and is safe to update.
83          */
84         spin_lock(&mm->page_table_lock);
85         mm->hmm = hmm;
86
87 out_unlock:
88         spin_unlock(&mm->page_table_lock);
89         return hmm;
90 }
91
92 static void hmm_free_rcu(struct rcu_head *rcu)
93 {
94         struct hmm *hmm = container_of(rcu, struct hmm, rcu);
95
96         mmdrop(hmm->mm);
97         kfree(hmm);
98 }
99
100 static void hmm_free(struct kref *kref)
101 {
102         struct hmm *hmm = container_of(kref, struct hmm, kref);
103
104         spin_lock(&hmm->mm->page_table_lock);
105         if (hmm->mm->hmm == hmm)
106                 hmm->mm->hmm = NULL;
107         spin_unlock(&hmm->mm->page_table_lock);
108
109         mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
110         mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
111 }
112
113 static inline void hmm_put(struct hmm *hmm)
114 {
115         kref_put(&hmm->kref, hmm_free);
116 }
117
118 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
119 {
120         struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
121         struct hmm_mirror *mirror;
122
123         /* Bail out if hmm is in the process of being freed */
124         if (!kref_get_unless_zero(&hmm->kref))
125                 return;
126
127         /*
128          * Since hmm_range_register() holds an mmget() reference, hmm_release()
129          * cannot run as long as a range exists.
130          */
131         WARN_ON(!list_empty_careful(&hmm->ranges));
132
133         down_write(&hmm->mirrors_sem);
134         mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
135                                           list);
136         while (mirror) {
137                 list_del_init(&mirror->list);
138                 if (mirror->ops->release) {
139                         /*
140                          * Drop mirrors_sem so the release callback can wait
141                          * on any pending work that might itself trigger a
142                          * mmu_notifier callback and thus would deadlock with
143                          * us.
144                          */
145                         up_write(&hmm->mirrors_sem);
146                         mirror->ops->release(mirror);
147                         down_write(&hmm->mirrors_sem);
148                 }
149                 mirror = list_first_entry_or_null(&hmm->mirrors,
150                                                   struct hmm_mirror, list);
151         }
152         up_write(&hmm->mirrors_sem);
153
154         hmm_put(hmm);
155 }
156
157 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
158                         const struct mmu_notifier_range *nrange)
159 {
160         struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
161         struct hmm_mirror *mirror;
162         struct hmm_update update;
163         struct hmm_range *range;
164         int ret = 0;
165
166         if (!kref_get_unless_zero(&hmm->kref))
167                 return 0;
168
169         update.start = nrange->start;
170         update.end = nrange->end;
171         update.event = HMM_UPDATE_INVALIDATE;
172         update.blockable = mmu_notifier_range_blockable(nrange);
173
174         if (mmu_notifier_range_blockable(nrange))
175                 mutex_lock(&hmm->lock);
176         else if (!mutex_trylock(&hmm->lock)) {
177                 ret = -EAGAIN;
178                 goto out;
179         }
180         hmm->notifiers++;
181         list_for_each_entry(range, &hmm->ranges, list) {
182                 if (update.end < range->start || update.start >= range->end)
183                         continue;
184
185                 range->valid = false;
186         }
187         mutex_unlock(&hmm->lock);
188
189         if (mmu_notifier_range_blockable(nrange))
190                 down_read(&hmm->mirrors_sem);
191         else if (!down_read_trylock(&hmm->mirrors_sem)) {
192                 ret = -EAGAIN;
193                 goto out;
194         }
195         list_for_each_entry(mirror, &hmm->mirrors, list) {
196                 int ret;
197
198                 ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
199                 if (!update.blockable && ret == -EAGAIN)
200                         break;
201         }
202         up_read(&hmm->mirrors_sem);
203
204 out:
205         hmm_put(hmm);
206         return ret;
207 }
208
209 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
210                         const struct mmu_notifier_range *nrange)
211 {
212         struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
213
214         if (!kref_get_unless_zero(&hmm->kref))
215                 return;
216
217         mutex_lock(&hmm->lock);
218         hmm->notifiers--;
219         if (!hmm->notifiers) {
220                 struct hmm_range *range;
221
222                 list_for_each_entry(range, &hmm->ranges, list) {
223                         if (range->valid)
224                                 continue;
225                         range->valid = true;
226                 }
227                 wake_up_all(&hmm->wq);
228         }
229         mutex_unlock(&hmm->lock);
230
231         hmm_put(hmm);
232 }
233
234 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
235         .release                = hmm_release,
236         .invalidate_range_start = hmm_invalidate_range_start,
237         .invalidate_range_end   = hmm_invalidate_range_end,
238 };
239
240 /*
241  * hmm_mirror_register() - register a mirror against an mm
242  *
243  * @mirror: new mirror struct to register
244  * @mm: mm to register against
245  * Return: 0 on success, -ENOMEM if no memory, -EINVAL if invalid arguments
246  *
247  * To start mirroring a process address space, the device driver must register
248  * an HMM mirror struct.
249  *
250  * THE mm->mmap_sem MUST BE HELD IN WRITE MODE!
251  */
252 int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
253 {
254         /* Sanity check */
255         if (!mm || !mirror || !mirror->ops)
256                 return -EINVAL;
257
258         mirror->hmm = hmm_get_or_create(mm);
259         if (!mirror->hmm)
260                 return -ENOMEM;
261
262         down_write(&mirror->hmm->mirrors_sem);
263         list_add(&mirror->list, &mirror->hmm->mirrors);
264         up_write(&mirror->hmm->mirrors_sem);
265
266         return 0;
267 }
268 EXPORT_SYMBOL(hmm_mirror_register);
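
/*
 * Illustrative usage sketch for hmm_mirror_register(), not an authoritative
 * example: the dummy_* driver structure and callback below are hypothetical,
 * only the HMM and mm calls are real. A driver embeds a struct hmm_mirror,
 * provides hmm_mirror_ops, and registers while holding mmap_sem for write:
 *
 *   static int dummy_sync_cpu_device_pagetables(struct hmm_mirror *mirror,
 *                                               const struct hmm_update *update)
 *   {
 *       // Invalidate the device page table for [update->start, update->end).
 *       return 0;
 *   }
 *
 *   static const struct hmm_mirror_ops dummy_mirror_ops = {
 *       .sync_cpu_device_pagetables = dummy_sync_cpu_device_pagetables,
 *   };
 *
 *   int dummy_mirror_mm(struct dummy_device *ddev, struct mm_struct *mm)
 *   {
 *       int ret;
 *
 *       ddev->mirror.ops = &dummy_mirror_ops;
 *       down_write(&mm->mmap_sem);
 *       ret = hmm_mirror_register(&ddev->mirror, mm);
 *       up_write(&mm->mmap_sem);
 *       return ret;
 *   }
 */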
269
270 /*
271  * hmm_mirror_unregister() - unregister a mirror
272  *
273  * @mirror: mirror struct to unregister
274  *
275  * Stop mirroring a process address space, and cleanup.
276  */
277 void hmm_mirror_unregister(struct hmm_mirror *mirror)
278 {
279         struct hmm *hmm = READ_ONCE(mirror->hmm);
280
281         if (hmm == NULL)
282                 return;
283
284         down_write(&hmm->mirrors_sem);
285         list_del_init(&mirror->list);
286         /* To protect us against double unregister ... */
287         mirror->hmm = NULL;
288         up_write(&hmm->mirrors_sem);
289
290         hmm_put(hmm);
291 }
292 EXPORT_SYMBOL(hmm_mirror_unregister);
293
294 struct hmm_vma_walk {
295         struct hmm_range        *range;
296         struct dev_pagemap      *pgmap;
297         unsigned long           last;
298         bool                    fault;
299         bool                    block;
300 };
301
302 static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
303                             bool write_fault, uint64_t *pfn)
304 {
305         unsigned int flags = FAULT_FLAG_REMOTE;
306         struct hmm_vma_walk *hmm_vma_walk = walk->private;
307         struct hmm_range *range = hmm_vma_walk->range;
308         struct vm_area_struct *vma = walk->vma;
309         vm_fault_t ret;
310
311         flags |= hmm_vma_walk->block ? 0 : FAULT_FLAG_ALLOW_RETRY;
312         flags |= write_fault ? FAULT_FLAG_WRITE : 0;
313         ret = handle_mm_fault(vma, addr, flags);
314         if (ret & VM_FAULT_RETRY)
315                 return -EAGAIN;
316         if (ret & VM_FAULT_ERROR) {
317                 *pfn = range->values[HMM_PFN_ERROR];
318                 return -EFAULT;
319         }
320
321         return -EBUSY;
322 }
323
324 static int hmm_pfns_bad(unsigned long addr,
325                         unsigned long end,
326                         struct mm_walk *walk)
327 {
328         struct hmm_vma_walk *hmm_vma_walk = walk->private;
329         struct hmm_range *range = hmm_vma_walk->range;
330         uint64_t *pfns = range->pfns;
331         unsigned long i;
332
333         i = (addr - range->start) >> PAGE_SHIFT;
334         for (; addr < end; addr += PAGE_SIZE, i++)
335                 pfns[i] = range->values[HMM_PFN_ERROR];
336
337         return 0;
338 }
339
340 /*
341  * hmm_vma_walk_hole() - handle a range lacking valid pmd or pte(s)
342  * @start: range virtual start address (inclusive)
343  * @end: range virtual end address (exclusive)
344  * @fault: should we fault or not?
345  * @write_fault: write fault?
346  * @walk: mm_walk structure
347  * Return: 0 on success, -EBUSY after page fault, or page fault error
348  *
349  * This function will be called whenever pmd_none() or pte_none() returns true,
350  * or whenever there is no page directory covering the virtual address range.
351  */
352 static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
353                               bool fault, bool write_fault,
354                               struct mm_walk *walk)
355 {
356         struct hmm_vma_walk *hmm_vma_walk = walk->private;
357         struct hmm_range *range = hmm_vma_walk->range;
358         uint64_t *pfns = range->pfns;
359         unsigned long i, page_size;
360
361         hmm_vma_walk->last = addr;
362         page_size = hmm_range_page_size(range);
363         i = (addr - range->start) >> range->page_shift;
364
365         for (; addr < end; addr += page_size, i++) {
366                 pfns[i] = range->values[HMM_PFN_NONE];
367                 if (fault || write_fault) {
368                         int ret;
369
370                         ret = hmm_vma_do_fault(walk, addr, write_fault,
371                                                &pfns[i]);
372                         if (ret != -EBUSY)
373                                 return ret;
374                 }
375         }
376
377         return (fault || write_fault) ? -EBUSY : 0;
378 }
379
380 static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
381                                       uint64_t pfns, uint64_t cpu_flags,
382                                       bool *fault, bool *write_fault)
383 {
384         struct hmm_range *range = hmm_vma_walk->range;
385
386         if (!hmm_vma_walk->fault)
387                 return;
388
389         /*
390          * So we not only consider the individual per-page request, we also
391          * consider the default flags requested for the range. The API can be
392          * used in two fashions. In the first one the HMM user coalesces
393          * multiple page faults into one request and sets flags per pfn for
394          * each of those faults. In the second one the HMM user wants to
395          * pre-fault a range with specific flags. For the latter it would be a
396          * waste to have the user pre-fill the pfn array with default flags
397          * values. (A short sketch of both fashions follows this function.)
398          */
399         pfns = (pfns & range->pfn_flags_mask) | range->default_flags;
400
401         /* We aren't asked to do anything ... */
402         if (!(pfns & range->flags[HMM_PFN_VALID]))
403                 return;
404         /* If this is device memory then only fault if explicitly requested */
405         if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
406                 /* Do we fault on device memory? */
407                 if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
408                         *write_fault = pfns & range->flags[HMM_PFN_WRITE];
409                         *fault = true;
410                 }
411                 return;
412         }
413
414         /* If CPU page table is not valid then we need to fault */
415         *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
416         /* Need to write fault? */
417         if ((pfns & range->flags[HMM_PFN_WRITE]) &&
418             !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
419                 *write_fault = true;
420                 *fault = true;
421         }
422 }
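
/*
 * Illustrative sketch (not authoritative) of the two fashions described in
 * the comment above; "range" is a driver-owned struct hmm_range and i is a
 * page index. First fashion, per-pfn flags with no range-wide defaults:
 *
 *   range.default_flags = 0;
 *   range.pfn_flags_mask = ~0ULL;
 *   range.pfns[i] = range.flags[HMM_PFN_VALID] | range.flags[HMM_PFN_WRITE];
 *
 * Second fashion, pre-faulting the whole range read/write without pre-filling
 * the pfn array:
 *
 *   range.default_flags = range.flags[HMM_PFN_VALID] |
 *                         range.flags[HMM_PFN_WRITE];
 *   range.pfn_flags_mask = 0;
 */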
423
424 static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
425                                  const uint64_t *pfns, unsigned long npages,
426                                  uint64_t cpu_flags, bool *fault,
427                                  bool *write_fault)
428 {
429         unsigned long i;
430
431         if (!hmm_vma_walk->fault) {
432                 *fault = *write_fault = false;
433                 return;
434         }
435
436         *fault = *write_fault = false;
437         for (i = 0; i < npages; ++i) {
438                 hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
439                                    fault, write_fault);
440                 if ((*write_fault))
441                         return;
442         }
443 }
444
445 static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
446                              struct mm_walk *walk)
447 {
448         struct hmm_vma_walk *hmm_vma_walk = walk->private;
449         struct hmm_range *range = hmm_vma_walk->range;
450         bool fault, write_fault;
451         unsigned long i, npages;
452         uint64_t *pfns;
453
454         i = (addr - range->start) >> PAGE_SHIFT;
455         npages = (end - addr) >> PAGE_SHIFT;
456         pfns = &range->pfns[i];
457         hmm_range_need_fault(hmm_vma_walk, pfns, npages,
458                              0, &fault, &write_fault);
459         return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
460 }
461
462 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
463 {
464         if (pmd_protnone(pmd))
465                 return 0;
466         return pmd_write(pmd) ? range->flags[HMM_PFN_VALID] |
467                                 range->flags[HMM_PFN_WRITE] :
468                                 range->flags[HMM_PFN_VALID];
469 }
470
471 static inline uint64_t pud_to_hmm_pfn_flags(struct hmm_range *range, pud_t pud)
472 {
473         if (!pud_present(pud))
474                 return 0;
475         return pud_write(pud) ? range->flags[HMM_PFN_VALID] |
476                                 range->flags[HMM_PFN_WRITE] :
477                                 range->flags[HMM_PFN_VALID];
478 }
479
480 static int hmm_vma_handle_pmd(struct mm_walk *walk,
481                               unsigned long addr,
482                               unsigned long end,
483                               uint64_t *pfns,
484                               pmd_t pmd)
485 {
486 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
487         struct hmm_vma_walk *hmm_vma_walk = walk->private;
488         struct hmm_range *range = hmm_vma_walk->range;
489         unsigned long pfn, npages, i;
490         bool fault, write_fault;
491         uint64_t cpu_flags;
492
493         npages = (end - addr) >> PAGE_SHIFT;
494         cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
495         hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
496                              &fault, &write_fault);
497
498         if (pmd_protnone(pmd) || fault || write_fault)
499                 return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
500
501         pfn = pmd_pfn(pmd) + pte_index(addr);
502         for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
503                 if (pmd_devmap(pmd)) {
504                         hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
505                                               hmm_vma_walk->pgmap);
506                         if (unlikely(!hmm_vma_walk->pgmap))
507                                 return -EBUSY;
508                 }
509                 pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
510         }
511         if (hmm_vma_walk->pgmap) {
512                 put_dev_pagemap(hmm_vma_walk->pgmap);
513                 hmm_vma_walk->pgmap = NULL;
514         }
515         hmm_vma_walk->last = end;
516         return 0;
517 #else
518         /* If THP is not enabled then we should never reach this code! */
519         return -EINVAL;
520 #endif
521 }
522
523 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
524 {
525         if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte))
526                 return 0;
527         return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
528                                 range->flags[HMM_PFN_WRITE] :
529                                 range->flags[HMM_PFN_VALID];
530 }
531
532 static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
533                               unsigned long end, pmd_t *pmdp, pte_t *ptep,
534                               uint64_t *pfn)
535 {
536         struct hmm_vma_walk *hmm_vma_walk = walk->private;
537         struct hmm_range *range = hmm_vma_walk->range;
538         struct vm_area_struct *vma = walk->vma;
539         bool fault, write_fault;
540         uint64_t cpu_flags;
541         pte_t pte = *ptep;
542         uint64_t orig_pfn = *pfn;
543
544         *pfn = range->values[HMM_PFN_NONE];
545         fault = write_fault = false;
546
547         if (pte_none(pte)) {
548                 hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
549                                    &fault, &write_fault);
550                 if (fault || write_fault)
551                         goto fault;
552                 return 0;
553         }
554
555         if (!pte_present(pte)) {
556                 swp_entry_t entry = pte_to_swp_entry(pte);
557
558                 if (!non_swap_entry(entry)) {
559                         if (fault || write_fault)
560                                 goto fault;
561                         return 0;
562                 }
563
564                 /*
565                  * This is a special swap entry: use device private entries, wait
566                  * on migration entries if faulting, and report anything else as error.
567                  */
568                 if (is_device_private_entry(entry)) {
569                         cpu_flags = range->flags[HMM_PFN_VALID] |
570                                 range->flags[HMM_PFN_DEVICE_PRIVATE];
571                         cpu_flags |= is_write_device_private_entry(entry) ?
572                                 range->flags[HMM_PFN_WRITE] : 0;
573                         hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
574                                            &fault, &write_fault);
575                         if (fault || write_fault)
576                                 goto fault;
577                         *pfn = hmm_device_entry_from_pfn(range,
578                                             swp_offset(entry));
579                         *pfn |= cpu_flags;
580                         return 0;
581                 }
582
583                 if (is_migration_entry(entry)) {
584                         if (fault || write_fault) {
585                                 pte_unmap(ptep);
586                                 hmm_vma_walk->last = addr;
587                                 migration_entry_wait(vma->vm_mm,
588                                                      pmdp, addr);
589                                 return -EBUSY;
590                         }
591                         return 0;
592                 }
593
594                 /* Report error for everything else */
595                 *pfn = range->values[HMM_PFN_ERROR];
596                 return -EFAULT;
597         } else {
598                 cpu_flags = pte_to_hmm_pfn_flags(range, pte);
599                 hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
600                                    &fault, &write_fault);
601         }
602
603         if (fault || write_fault)
604                 goto fault;
605
606         if (pte_devmap(pte)) {
607                 hmm_vma_walk->pgmap = get_dev_pagemap(pte_pfn(pte),
608                                               hmm_vma_walk->pgmap);
609                 if (unlikely(!hmm_vma_walk->pgmap))
610                         return -EBUSY;
611         } else if (IS_ENABLED(CONFIG_ARCH_HAS_PTE_SPECIAL) && pte_special(pte)) {
612                 *pfn = range->values[HMM_PFN_SPECIAL];
613                 return -EFAULT;
614         }
615
616         *pfn = hmm_device_entry_from_pfn(range, pte_pfn(pte)) | cpu_flags;
617         return 0;
618
619 fault:
620         if (hmm_vma_walk->pgmap) {
621                 put_dev_pagemap(hmm_vma_walk->pgmap);
622                 hmm_vma_walk->pgmap = NULL;
623         }
624         pte_unmap(ptep);
625         /* Fault any virtual address we were asked to fault */
626         return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
627 }
628
629 static int hmm_vma_walk_pmd(pmd_t *pmdp,
630                             unsigned long start,
631                             unsigned long end,
632                             struct mm_walk *walk)
633 {
634         struct hmm_vma_walk *hmm_vma_walk = walk->private;
635         struct hmm_range *range = hmm_vma_walk->range;
636         struct vm_area_struct *vma = walk->vma;
637         uint64_t *pfns = range->pfns;
638         unsigned long addr = start, i;
639         pte_t *ptep;
640         pmd_t pmd;
641
642
643 again:
644         pmd = READ_ONCE(*pmdp);
645         if (pmd_none(pmd))
646                 return hmm_vma_walk_hole(start, end, walk);
647
648         if (pmd_huge(pmd) && (range->vma->vm_flags & VM_HUGETLB))
649                 return hmm_pfns_bad(start, end, walk);
650
651         if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
652                 bool fault, write_fault;
653                 unsigned long npages;
654                 uint64_t *pfns;
655
656                 i = (addr - range->start) >> PAGE_SHIFT;
657                 npages = (end - addr) >> PAGE_SHIFT;
658                 pfns = &range->pfns[i];
659
660                 hmm_range_need_fault(hmm_vma_walk, pfns, npages,
661                                      0, &fault, &write_fault);
662                 if (fault || write_fault) {
663                         hmm_vma_walk->last = addr;
664                         pmd_migration_entry_wait(vma->vm_mm, pmdp);
665                         return -EBUSY;
666                 }
667                 return 0;
668         } else if (!pmd_present(pmd))
669                 return hmm_pfns_bad(start, end, walk);
670
671         if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
672                 /*
673                  * No need to take the pmd_lock here; even if some other thread
674                  * is splitting the huge pmd, we will get that event through the
675                  * mmu_notifier callback.
676                  *
677                  * So just read the pmd value and check again that it is a
678                  * transparent huge or device mapping entry, then compute the
679                  * corresponding pfn values.
680                  */
681                 pmd = pmd_read_atomic(pmdp);
682                 barrier();
683                 if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
684                         goto again;
685
686                 i = (addr - range->start) >> PAGE_SHIFT;
687                 return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
688         }
689
690         /*
691          * We have handled all the valid cases above, ie either none, migration,
692          * huge or transparent huge. At this point it is either a valid pmd
693          * entry pointing to a pte directory or a bad pmd that will not
694          * recover.
695          */
696         if (pmd_bad(pmd))
697                 return hmm_pfns_bad(start, end, walk);
698
699         ptep = pte_offset_map(pmdp, addr);
700         i = (addr - range->start) >> PAGE_SHIFT;
701         for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
702                 int r;
703
704                 r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
705                 if (r) {
706                         /* hmm_vma_handle_pte() already unmapped the pte directory */
707                         hmm_vma_walk->last = addr;
708                         return r;
709                 }
710         }
711         if (hmm_vma_walk->pgmap) {
712                 /*
713                  * We do put_dev_pagemap() here and not in hmm_vma_handle_pte()
714                  * so that we can leverage the get_dev_pagemap() optimization which
715                  * will not re-take a reference on a pgmap if we already have
716                  * one.
717                  */
718                 put_dev_pagemap(hmm_vma_walk->pgmap);
719                 hmm_vma_walk->pgmap = NULL;
720         }
721         pte_unmap(ptep - 1);
722
723         hmm_vma_walk->last = addr;
724         return 0;
725 }
726
727 static int hmm_vma_walk_pud(pud_t *pudp,
728                             unsigned long start,
729                             unsigned long end,
730                             struct mm_walk *walk)
731 {
732         struct hmm_vma_walk *hmm_vma_walk = walk->private;
733         struct hmm_range *range = hmm_vma_walk->range;
734         unsigned long addr = start, next;
735         pmd_t *pmdp;
736         pud_t pud;
737         int ret;
738
739 again:
740         pud = READ_ONCE(*pudp);
741         if (pud_none(pud))
742                 return hmm_vma_walk_hole(start, end, walk);
743
744         if (pud_huge(pud) && pud_devmap(pud)) {
745                 unsigned long i, npages, pfn;
746                 uint64_t *pfns, cpu_flags;
747                 bool fault, write_fault;
748
749                 if (!pud_present(pud))
750                         return hmm_vma_walk_hole(start, end, walk);
751
752                 i = (addr - range->start) >> PAGE_SHIFT;
753                 npages = (end - addr) >> PAGE_SHIFT;
754                 pfns = &range->pfns[i];
755
756                 cpu_flags = pud_to_hmm_pfn_flags(range, pud);
757                 hmm_range_need_fault(hmm_vma_walk, pfns, npages,
758                                      cpu_flags, &fault, &write_fault);
759                 if (fault || write_fault)
760                         return hmm_vma_walk_hole_(addr, end, fault,
761                                                 write_fault, walk);
762
763                 pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
764                 for (i = 0; i < npages; ++i, ++pfn) {
765                         hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
766                                               hmm_vma_walk->pgmap);
767                         if (unlikely(!hmm_vma_walk->pgmap))
768                                 return -EBUSY;
769                         pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
770                                   cpu_flags;
771                 }
772                 if (hmm_vma_walk->pgmap) {
773                         put_dev_pagemap(hmm_vma_walk->pgmap);
774                         hmm_vma_walk->pgmap = NULL;
775                 }
776                 hmm_vma_walk->last = end;
777                 return 0;
778         }
779
780         split_huge_pud(walk->vma, pudp, addr);
781         if (pud_none(*pudp))
782                 goto again;
783
784         pmdp = pmd_offset(pudp, addr);
785         do {
786                 next = pmd_addr_end(addr, end);
787                 ret = hmm_vma_walk_pmd(pmdp, addr, next, walk);
788                 if (ret)
789                         return ret;
790         } while (pmdp++, addr = next, addr != end);
791
792         return 0;
793 }
794
795 static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
796                                       unsigned long start, unsigned long end,
797                                       struct mm_walk *walk)
798 {
799 #ifdef CONFIG_HUGETLB_PAGE
800         unsigned long addr = start, i, pfn, mask, size, pfn_inc;
801         struct hmm_vma_walk *hmm_vma_walk = walk->private;
802         struct hmm_range *range = hmm_vma_walk->range;
803         struct vm_area_struct *vma = walk->vma;
804         struct hstate *h = hstate_vma(vma);
805         uint64_t orig_pfn, cpu_flags;
806         bool fault, write_fault;
807         spinlock_t *ptl;
808         pte_t entry;
809         int ret = 0;
810
811         size = 1UL << huge_page_shift(h);
812         mask = size - 1;
813         if (range->page_shift != PAGE_SHIFT) {
814                 /* Make sure we are looking at a full page. */
815                 if (start & mask)
816                         return -EINVAL;
817                 if (end < (start + size))
818                         return -EINVAL;
819                 pfn_inc = size >> PAGE_SHIFT;
820         } else {
821                 pfn_inc = 1;
822                 size = PAGE_SIZE;
823         }
824
825
826         ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
827         entry = huge_ptep_get(pte);
828
829         i = (start - range->start) >> range->page_shift;
830         orig_pfn = range->pfns[i];
831         range->pfns[i] = range->values[HMM_PFN_NONE];
832         cpu_flags = pte_to_hmm_pfn_flags(range, entry);
833         fault = write_fault = false;
834         hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
835                            &fault, &write_fault);
836         if (fault || write_fault) {
837                 ret = -ENOENT;
838                 goto unlock;
839         }
840
841         pfn = pte_pfn(entry) + ((start & mask) >> range->page_shift);
842         for (; addr < end; addr += size, i++, pfn += pfn_inc)
843                 range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
844                                  cpu_flags;
845         hmm_vma_walk->last = end;
846
847 unlock:
848         spin_unlock(ptl);
849
850         if (ret == -ENOENT)
851                 return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
852
853         return ret;
854 #else /* CONFIG_HUGETLB_PAGE */
855         return -EINVAL;
856 #endif
857 }
858
859 static void hmm_pfns_clear(struct hmm_range *range,
860                            uint64_t *pfns,
861                            unsigned long addr,
862                            unsigned long end)
863 {
864         for (; addr < end; addr += PAGE_SIZE, pfns++)
865                 *pfns = range->values[HMM_PFN_NONE];
866 }
867
868 /*
869  * hmm_range_register() - start tracking changes to the CPU page table over a range
870  * @range: range
871  * @mirror: the mirror struct of the address space to track
872  * @start: start virtual address (inclusive)
873  * @end: end virtual address (exclusive)
874  * @page_shift: expected page shift for the range
875  * Return: 0 on success, -EFAULT if the address space is no longer valid
876  *
877  * Track updates to the CPU page table; see include/linux/hmm.h for details.
878  */
879 int hmm_range_register(struct hmm_range *range,
880                        struct hmm_mirror *mirror,
881                        unsigned long start,
882                        unsigned long end,
883                        unsigned page_shift)
884 {
885         unsigned long mask = ((1UL << page_shift) - 1UL);
886         struct hmm *hmm = mirror->hmm;
887
888         range->valid = false;
889         range->hmm = NULL;
890
891         if ((start & mask) || (end & mask))
892                 return -EINVAL;
893         if (start >= end)
894                 return -EINVAL;
895
896         range->page_shift = page_shift;
897         range->start = start;
898         range->end = end;
899
900         /* Prevent hmm_release() from running while the range is valid */
901         if (!mmget_not_zero(hmm->mm))
902                 return -EFAULT;
903
904         /* Initialize range to track CPU page table updates. */
905         mutex_lock(&hmm->lock);
906
907         range->hmm = hmm;
908         kref_get(&hmm->kref);
909         list_add(&range->list, &hmm->ranges);
910
911         /*
912          * If there are any concurrent notifiers we have to wait for them for
913          * the range to be valid (see hmm_range_wait_until_valid()).
914          */
915         if (!hmm->notifiers)
916                 range->valid = true;
917         mutex_unlock(&hmm->lock);
918
919         return 0;
920 }
921 EXPORT_SYMBOL(hmm_range_register);
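
/*
 * Illustrative sketch (not authoritative) of the register/wait/retry protocol
 * around hmm_range_register(). It assumes the hmm_range_wait_until_valid(),
 * hmm_range_valid() and HMM_RANGE_DEFAULT_TIMEOUT helpers declared in
 * include/linux/hmm.h; "mirror", "mm", "ddev" and the dummy_device_*lock()
 * helpers are hypothetical driver context, and range.pfns/flags/values/
 * default_flags/pfn_flags_mask are assumed to be set up by the driver:
 *
 *   ret = hmm_range_register(&range, mirror, start, end, PAGE_SHIFT);
 *   if (ret)
 *       return ret;
 *
 * again:
 *   if (!hmm_range_wait_until_valid(&range, HMM_RANGE_DEFAULT_TIMEOUT)) {
 *       ret = -EBUSY;
 *       goto out_unregister;
 *   }
 *
 *   down_read(&mm->mmap_sem);
 *   ret = hmm_range_snapshot(&range);
 *   if (ret < 0) {
 *       up_read(&mm->mmap_sem);
 *       if (ret == -EAGAIN)
 *           goto again;
 *       goto out_unregister;
 *   }
 *
 *   // Take the driver lock that serializes device page table updates and
 *   // make sure no invalidation raced with the snapshot before using it.
 *   dummy_device_lock(ddev);
 *   if (!hmm_range_valid(&range)) {
 *       dummy_device_unlock(ddev);
 *       up_read(&mm->mmap_sem);
 *       goto again;
 *   }
 *   // ... program the device page table from range.pfns[] ...
 *   dummy_device_unlock(ddev);
 *   up_read(&mm->mmap_sem);
 *
 * out_unregister:
 *   hmm_range_unregister(&range);
 *   return ret;
 */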
922
923 /*
924  * hmm_range_unregister() - stop tracking change to CPU page table over a range
925  * @range: range
926  *
927  * Range struct is used to track updates to the CPU page table after a call to
928  * hmm_range_register(). See include/linux/hmm.h for how to use it.
929  */
930 void hmm_range_unregister(struct hmm_range *range)
931 {
932         struct hmm *hmm = range->hmm;
933
934         /* Sanity check: this really should not happen. */
935         if (hmm == NULL || range->end <= range->start)
936                 return;
937
938         mutex_lock(&hmm->lock);
939         list_del_init(&range->list);
940         mutex_unlock(&hmm->lock);
941
942         /* Drop reference taken by hmm_range_register() */
943         range->valid = false;
944         mmput(hmm->mm);
945         hmm_put(hmm);
946         range->hmm = NULL;
947 }
948 EXPORT_SYMBOL(hmm_range_unregister);
949
950 /*
951  * hmm_range_snapshot() - snapshot CPU page table for a range
952  * @range: range
953  * Return: number of valid pages in range->pfns[] (from the range start
954  *          address) on success. On error: -EINVAL for an invalid argument,
955  *          -ENOMEM when out of memory, -EPERM for an invalid permission (for
956  *          instance asking for write when the range is read-only), -EAGAIN
957  *          if you need to retry, or -EFAULT if the range is invalid (ie
958  *          either there is no valid vma or it is illegal to access it).
959  *
960  * This snapshots the CPU page table for a range of virtual addresses.
961  * Snapshot validity is tracked by the range struct; see include/linux/hmm.h.
962  */
963 long hmm_range_snapshot(struct hmm_range *range)
964 {
965         const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
966         unsigned long start = range->start, end;
967         struct hmm_vma_walk hmm_vma_walk;
968         struct hmm *hmm = range->hmm;
969         struct vm_area_struct *vma;
970         struct mm_walk mm_walk;
971
972         lockdep_assert_held(&hmm->mm->mmap_sem);
973         do {
974                 /* If the range is no longer valid, force a retry. */
975                 if (!range->valid)
976                         return -EAGAIN;
977
978                 vma = find_vma(hmm->mm, start);
979                 if (vma == NULL || (vma->vm_flags & device_vma))
980                         return -EFAULT;
981
982                 if (is_vm_hugetlb_page(vma)) {
983                         if (huge_page_shift(hstate_vma(vma)) !=
984                                     range->page_shift &&
985                             range->page_shift != PAGE_SHIFT)
986                                 return -EINVAL;
987                 } else {
988                         if (range->page_shift != PAGE_SHIFT)
989                                 return -EINVAL;
990                 }
991
992                 if (!(vma->vm_flags & VM_READ)) {
993                         /*
994                          * If the vma does not allow read access, then assume that
995                          * it does not allow write access either. HMM does not
996                          * support architectures that allow write without read.
997                          */
998                         hmm_pfns_clear(range, range->pfns,
999                                 range->start, range->end);
1000                         return -EPERM;
1001                 }
1002
1003                 range->vma = vma;
1004                 hmm_vma_walk.pgmap = NULL;
1005                 hmm_vma_walk.last = start;
1006                 hmm_vma_walk.fault = false;
1007                 hmm_vma_walk.range = range;
1008                 mm_walk.private = &hmm_vma_walk;
1009                 end = min(range->end, vma->vm_end);
1010
1011                 mm_walk.vma = vma;
1012                 mm_walk.mm = vma->vm_mm;
1013                 mm_walk.pte_entry = NULL;
1014                 mm_walk.test_walk = NULL;
1015                 mm_walk.hugetlb_entry = NULL;
1016                 mm_walk.pud_entry = hmm_vma_walk_pud;
1017                 mm_walk.pmd_entry = hmm_vma_walk_pmd;
1018                 mm_walk.pte_hole = hmm_vma_walk_hole;
1019                 mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
1020
1021                 walk_page_range(start, end, &mm_walk);
1022                 start = end;
1023         } while (start < range->end);
1024
1025         return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1026 }
1027 EXPORT_SYMBOL(hmm_range_snapshot);
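
/*
 * Illustrative sketch (not authoritative): after a successful snapshot the
 * driver decodes range->pfns[] using the flags/values tables it provided and
 * the hmm_device_entry_to_page() helper from include/linux/hmm.h:
 *
 *   unsigned long i, npages = (range.end - range.start) >> PAGE_SHIFT;
 *
 *   for (i = 0; i < npages; i++) {
 *       struct page *page;
 *
 *       if (!(range.pfns[i] & range.flags[HMM_PFN_VALID]))
 *           continue;   // hole, non-present, or error entry
 *       page = hmm_device_entry_to_page(&range, range.pfns[i]);
 *       // ... map "page" into the device, honoring HMM_PFN_WRITE ...
 *   }
 */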
1028
1029 /*
1030  * hmm_range_fault() - try to fault some address in a virtual address range
1031  * @range: range being faulted
1032  * @block: allow blocking on fault (if true it sleeps and does not drop mmap_sem)
1033  * Return: number of valid pages in range->pfns[] (counted from the range
1034  *          start address). This may be zero. If the return value is
1035  *          negative, then one of the following error codes is returned:
1036  *
1037  *           -EINVAL: Invalid arguments, or the mm or virtual address is in
1038  *                    an invalid vma (for instance a device file vma).
1039  *           -ENOMEM: Out of memory.
1040  *           -EPERM:  Invalid permission (for instance asking for write on
1041  *                    a read-only range).
1042  *           -EAGAIN: If you need to retry. In this case mmap_sem has been
1043  *                    dropped and must be taken again before retrying.
1044  *           -EBUSY:  If the range is being invalidated and you should wait
1045  *                    for the invalidation to finish.
1046  *           -EFAULT: Invalid range, ie either there is no valid vma covering
1047  *                    part of the range or it is illegal to access the
1048  *                    requested addresses.
1049  *
1050  * This is similar to a regular CPU page fault except that it will not trigger
1051  * any memory migration if the memory being faulted is not accessible by CPUs
1052  * and caller does not ask for migration.
1053  *
1054  * On error, for one virtual address in the range, the function will mark the
1055  * corresponding HMM pfn entry with an error flag.
1056  */
1057 long hmm_range_fault(struct hmm_range *range, bool block)
1058 {
1059         const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
1060         unsigned long start = range->start, end;
1061         struct hmm_vma_walk hmm_vma_walk;
1062         struct hmm *hmm = range->hmm;
1063         struct vm_area_struct *vma;
1064         struct mm_walk mm_walk;
1065         int ret;
1066
1067         lockdep_assert_held(&hmm->mm->mmap_sem);
1068
1069         do {
1070                 /* If the range is no longer valid, force a retry. */
1071                 if (!range->valid) {
1072                         up_read(&hmm->mm->mmap_sem);
1073                         return -EAGAIN;
1074                 }
1075
1076                 vma = find_vma(hmm->mm, start);
1077                 if (vma == NULL || (vma->vm_flags & device_vma))
1078                         return -EFAULT;
1079
1080                 if (is_vm_hugetlb_page(vma)) {
1081                         if (huge_page_shift(hstate_vma(vma)) !=
1082                             range->page_shift &&
1083                             range->page_shift != PAGE_SHIFT)
1084                                 return -EINVAL;
1085                 } else {
1086                         if (range->page_shift != PAGE_SHIFT)
1087                                 return -EINVAL;
1088                 }
1089
1090                 if (!(vma->vm_flags & VM_READ)) {
1091                         /*
1092                          * If the vma does not allow read access, then assume that
1093                          * it does not allow write access either. HMM does not
1094                          * support architectures that allow write without read.
1095                          */
1096                         hmm_pfns_clear(range, range->pfns,
1097                                 range->start, range->end);
1098                         return -EPERM;
1099                 }
1100
1101                 range->vma = vma;
1102                 hmm_vma_walk.pgmap = NULL;
1103                 hmm_vma_walk.last = start;
1104                 hmm_vma_walk.fault = true;
1105                 hmm_vma_walk.block = block;
1106                 hmm_vma_walk.range = range;
1107                 mm_walk.private = &hmm_vma_walk;
1108                 end = min(range->end, vma->vm_end);
1109
1110                 mm_walk.vma = vma;
1111                 mm_walk.mm = vma->vm_mm;
1112                 mm_walk.pte_entry = NULL;
1113                 mm_walk.test_walk = NULL;
1114                 mm_walk.hugetlb_entry = NULL;
1115                 mm_walk.pud_entry = hmm_vma_walk_pud;
1116                 mm_walk.pmd_entry = hmm_vma_walk_pmd;
1117                 mm_walk.pte_hole = hmm_vma_walk_hole;
1118                 mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
1119
1120                 do {
1121                         ret = walk_page_range(start, end, &mm_walk);
1122                         start = hmm_vma_walk.last;
1123
1124                         /* Keep trying while the range is valid. */
1125                 } while (ret == -EBUSY && range->valid);
1126
1127                 if (ret) {
1128                         unsigned long i;
1129
1130                         i = (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1131                         hmm_pfns_clear(range, &range->pfns[i],
1132                                 hmm_vma_walk.last, range->end);
1133                         return ret;
1134                 }
1135                 start = end;
1136
1137         } while (start < range->end);
1138
1139         return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1140 }
1141 EXPORT_SYMBOL(hmm_range_fault);
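
/*
 * Illustrative sketch (not authoritative) of a blocking hmm_range_fault()
 * call, following the same registration protocol as hmm_range_snapshot()
 * above. Note that a -EAGAIN return means mmap_sem has already been dropped
 * by the fault path:
 *
 *   down_read(&mm->mmap_sem);
 *   ret = hmm_range_fault(&range, true);
 *   if (ret < 0) {
 *       if (ret != -EAGAIN)
 *           up_read(&mm->mmap_sem);
 *       return ret;     // or wait for the range to become valid and retry
 *   }
 *   // range.pfns[] now describes the faulted pages.
 *   up_read(&mm->mmap_sem);
 */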
1142
1143 /**
1144  * hmm_range_dma_map() - hmm_range_fault() and dma map pages all in one.
1145  * @range: range being faulted
1146  * @device: device against which to dma map pages
1147  * @daddrs: dma addresses of the mapped pages
1148  * @block: allow blocking on fault (if true it sleeps and does not drop mmap_sem)
1149  * Return: number of pages mapped on success, -EAGAIN if mmap_sem has been
1150  *          dropped and you need to try again, some other error value otherwise
1151  *
1152  * Note same usage pattern as hmm_range_fault().
1153  */
1154 long hmm_range_dma_map(struct hmm_range *range,
1155                        struct device *device,
1156                        dma_addr_t *daddrs,
1157                        bool block)
1158 {
1159         unsigned long i, npages, mapped;
1160         long ret;
1161
1162         ret = hmm_range_fault(range, block);
1163         if (ret <= 0)
1164                 return ret ? ret : -EBUSY;
1165
1166         npages = (range->end - range->start) >> PAGE_SHIFT;
1167         for (i = 0, mapped = 0; i < npages; ++i) {
1168                 enum dma_data_direction dir = DMA_TO_DEVICE;
1169                 struct page *page;
1170
1171                 /*
1172                  * FIXME: the DMA API should provide an invalid DMA address value
1173                  * instead of a function to test the dma address value. This would
1174                  * remove a lot of dumb code duplicated across many architectures.
1175                  *
1176                  * For now setting it to 0 here is good enough as the pfns[] value
1177                  * is what is used to check what is valid and what isn't.
1178                  */
1179                 daddrs[i] = 0;
1180
1181                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1182                 if (page == NULL)
1183                         continue;
1184
1185                 /* Check if range is being invalidated */
1186                 if (!range->valid) {
1187                         ret = -EBUSY;
1188                         goto unmap;
1189                 }
1190
1191                 /* If it is read and write then map bi-directionally. */
1192                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE])
1193                         dir = DMA_BIDIRECTIONAL;
1194
1195                 daddrs[i] = dma_map_page(device, page, 0, PAGE_SIZE, dir);
1196                 if (dma_mapping_error(device, daddrs[i])) {
1197                         ret = -EFAULT;
1198                         goto unmap;
1199                 }
1200
1201                 mapped++;
1202         }
1203
1204         return mapped;
1205
1206 unmap:
1207         for (npages = i, i = 0; (i < npages) && mapped; ++i) {
1208                 enum dma_data_direction dir = DMA_TO_DEVICE;
1209                 struct page *page;
1210
1211                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1212                 if (page == NULL)
1213                         continue;
1214
1215                 if (dma_mapping_error(device, daddrs[i]))
1216                         continue;
1217
1218                 /* If it was read and write then it was mapped bi-directionally. */
1219                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE])
1220                         dir = DMA_BIDIRECTIONAL;
1221
1222                 dma_unmap_page(device, daddrs[i], PAGE_SIZE, dir);
1223                 mapped--;
1224         }
1225
1226         return ret;
1227 }
1228 EXPORT_SYMBOL(hmm_range_dma_map);
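
/*
 * Illustrative sketch (not authoritative): the daddrs[] array is driver
 * allocated, one dma_addr_t per page of the range, and is later handed back
 * to hmm_range_dma_unmap(). As with hmm_range_fault(), -EAGAIN means
 * mmap_sem has already been dropped:
 *
 *   npages = (range.end - range.start) >> PAGE_SHIFT;
 *   daddrs = kvmalloc_array(npages, sizeof(*daddrs), GFP_KERNEL);
 *   if (!daddrs)
 *       return -ENOMEM;
 *
 *   down_read(&mm->mmap_sem);
 *   mapped = hmm_range_dma_map(&range, dev, daddrs, true);
 *   if (mapped < 0) {
 *       if (mapped != -EAGAIN)
 *           up_read(&mm->mmap_sem);
 *       return mapped;  // or retry, as in the hmm_range_fault() sketch
 *   }
 *   up_read(&mm->mmap_sem);
 *   // ... device uses the dma addresses; once it is done and invalidation
 *   // callbacks are quiesced:
 *   hmm_range_dma_unmap(&range, NULL, dev, daddrs, true);
 */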
1229
1230 /**
1231  * hmm_range_dma_unmap() - unmap a range that was mapped with hmm_range_dma_map()
1232  * @range: range being unmapped
1233  * @vma: the vma against which the range was mapped (optional)
1234  * @device: device against which the dma map was done
1235  * @daddrs: dma addresses of the mapped pages
1236  * @dirty: dirty pages if they had the write flag set
1237  * Return: number of pages unmapped on success, -EINVAL otherwise
1238  *
1239  * Note that the caller MUST abide by the mmu notifier (or use an HMM
1240  * mirror) and by the sync_cpu_device_pagetables() callback so that it is
1241  * safe here to call set_page_dirty(). The caller must also take appropriate
1242  * locks to prevent concurrent invalidation callbacks from making progress.
1243  */
1244 long hmm_range_dma_unmap(struct hmm_range *range,
1245                          struct vm_area_struct *vma,
1246                          struct device *device,
1247                          dma_addr_t *daddrs,
1248                          bool dirty)
1249 {
1250         unsigned long i, npages;
1251         long cpages = 0;
1252
1253         /* Sanity check. */
1254         if (range->end <= range->start)
1255                 return -EINVAL;
1256         if (!daddrs)
1257                 return -EINVAL;
1258         if (!range->pfns)
1259                 return -EINVAL;
1260
1261         npages = (range->end - range->start) >> PAGE_SHIFT;
1262         for (i = 0; i < npages; ++i) {
1263                 enum dma_data_direction dir = DMA_TO_DEVICE;
1264                 struct page *page;
1265
1266                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1267                 if (page == NULL)
1268                         continue;
1269
1270                 /* If it was read and write then it was mapped bi-directionally. */
1271                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE]) {
1272                         dir = DMA_BIDIRECTIONAL;
1273
1274                         /*
1275                          * See comments in function description on why it is
1276                          * safe here to call set_page_dirty()
1277                          */
1278                         if (dirty)
1279                                 set_page_dirty(page);
1280                 }
1281
1282                 /* Unmap and clear pfns/dma address */
1283                 dma_unmap_page(device, daddrs[i], PAGE_SIZE, dir);
1284                 range->pfns[i] = range->values[HMM_PFN_NONE];
1285                 /* FIXME see comments in hmm_range_dma_map() */
1286                 daddrs[i] = 0;
1287                 cpages++;
1288         }
1289
1290         return cpages;
1291 }
1292 EXPORT_SYMBOL(hmm_range_dma_unmap);
1293 #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
1294
1295
1296 #if IS_ENABLED(CONFIG_DEVICE_PRIVATE) ||  IS_ENABLED(CONFIG_DEVICE_PUBLIC)
1297 struct page *hmm_vma_alloc_locked_page(struct vm_area_struct *vma,
1298                                        unsigned long addr)
1299 {
1300         struct page *page;
1301
1302         page = alloc_page_vma(GFP_HIGHUSER, vma, addr);
1303         if (!page)
1304                 return NULL;
1305         lock_page(page);
1306         return page;
1307 }
1308 EXPORT_SYMBOL(hmm_vma_alloc_locked_page);
1309
1310
1311 static void hmm_devmem_ref_release(struct percpu_ref *ref)
1312 {
1313         struct hmm_devmem *devmem;
1314
1315         devmem = container_of(ref, struct hmm_devmem, ref);
1316         complete(&devmem->completion);
1317 }
1318
1319 static void hmm_devmem_ref_exit(void *data)
1320 {
1321         struct percpu_ref *ref = data;
1322         struct hmm_devmem *devmem;
1323
1324         devmem = container_of(ref, struct hmm_devmem, ref);
1325         wait_for_completion(&devmem->completion);
1326         percpu_ref_exit(ref);
1327 }
1328
1329 static void hmm_devmem_ref_kill(struct percpu_ref *ref)
1330 {
1331         percpu_ref_kill(ref);
1332 }
1333
1334 static vm_fault_t hmm_devmem_fault(struct vm_area_struct *vma,
1335                             unsigned long addr,
1336                             const struct page *page,
1337                             unsigned int flags,
1338                             pmd_t *pmdp)
1339 {
1340         struct hmm_devmem *devmem = page->pgmap->data;
1341
1342         return devmem->ops->fault(devmem, vma, addr, page, flags, pmdp);
1343 }
1344
1345 static void hmm_devmem_free(struct page *page, void *data)
1346 {
1347         struct hmm_devmem *devmem = data;
1348
1349         page->mapping = NULL;
1350
1351         devmem->ops->free(devmem, page);
1352 }
1353
1354 /*
1355  * hmm_devmem_add() - hotplug ZONE_DEVICE memory for device memory
1356  *
1357  * @ops: memory event device driver callbacks (see struct hmm_devmem_ops)
1358  * @device: device struct to bind the resource to
1359  * @size: size in bytes of the device memory to add
1360  * Return: pointer to a new hmm_devmem struct on success, ERR_PTR otherwise
1361  *
1362  * This function first finds an empty range of physical address big enough to
1363  * contain the new resource, and then hotplugs it as ZONE_DEVICE memory, which
1364  * in turn allocates struct pages. It does not do anything beyond that; all
1365  * events affecting the memory will go through the various callbacks provided
1366  * by hmm_devmem_ops struct.
1367  *
1368  * A device driver should call this function during device initialization and
1369  * is then responsible for memory management. HMM only provides helpers.
1370  */
1371 struct hmm_devmem *hmm_devmem_add(const struct hmm_devmem_ops *ops,
1372                                   struct device *device,
1373                                   unsigned long size)
1374 {
1375         struct hmm_devmem *devmem;
1376         resource_size_t addr;
1377         void *result;
1378         int ret;
1379
1380         dev_pagemap_get_ops();
1381
1382         devmem = devm_kzalloc(device, sizeof(*devmem), GFP_KERNEL);
1383         if (!devmem)
1384                 return ERR_PTR(-ENOMEM);
1385
1386         init_completion(&devmem->completion);
1387         devmem->pfn_first = -1UL;
1388         devmem->pfn_last = -1UL;
1389         devmem->resource = NULL;
1390         devmem->device = device;
1391         devmem->ops = ops;
1392
1393         ret = percpu_ref_init(&devmem->ref, &hmm_devmem_ref_release,
1394                               0, GFP_KERNEL);
1395         if (ret)
1396                 return ERR_PTR(ret);
1397
1398         ret = devm_add_action_or_reset(device, hmm_devmem_ref_exit, &devmem->ref);
1399         if (ret)
1400                 return ERR_PTR(ret);
1401
1402         size = ALIGN(size, PA_SECTION_SIZE);
1403         addr = min((unsigned long)iomem_resource.end,
1404                    (1UL << MAX_PHYSMEM_BITS) - 1);
1405         addr = addr - size + 1UL;
1406
1407         /*
1408          * FIXME add a new helper to quickly walk resource tree and find free
1409          * range
1410          *
1411          * FIXME what about ioport_resource resource ?
1412          */
1413         for (; addr > size && addr >= iomem_resource.start; addr -= size) {
1414                 ret = region_intersects(addr, size, 0, IORES_DESC_NONE);
1415                 if (ret != REGION_DISJOINT)
1416                         continue;
1417
1418                 devmem->resource = devm_request_mem_region(device, addr, size,
1419                                                            dev_name(device));
1420                 if (!devmem->resource)
1421                         return ERR_PTR(-ENOMEM);
1422                 break;
1423         }
1424         if (!devmem->resource)
1425                 return ERR_PTR(-ERANGE);
1426
1427         devmem->resource->desc = IORES_DESC_DEVICE_PRIVATE_MEMORY;
1428         devmem->pfn_first = devmem->resource->start >> PAGE_SHIFT;
1429         devmem->pfn_last = devmem->pfn_first +
1430                            (resource_size(devmem->resource) >> PAGE_SHIFT);
1431         devmem->page_fault = hmm_devmem_fault;
1432
1433         devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
1434         devmem->pagemap.res = *devmem->resource;
1435         devmem->pagemap.page_free = hmm_devmem_free;
1436         devmem->pagemap.altmap_valid = false;
1437         devmem->pagemap.ref = &devmem->ref;
1438         devmem->pagemap.data = devmem;
1439         devmem->pagemap.kill = hmm_devmem_ref_kill;
1440
1441         result = devm_memremap_pages(devmem->device, &devmem->pagemap);
1442         if (IS_ERR(result))
1443                 return result;
1444         return devmem;
1445 }
1446 EXPORT_SYMBOL_GPL(hmm_devmem_add);
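
/*
 * Illustrative sketch (not authoritative) of hotplugging device private
 * memory with hmm_devmem_add(); the dummy_* callbacks, pdev and the SZ_1G
 * size are hypothetical driver choices:
 *
 *   static void dummy_devmem_free(struct hmm_devmem *devmem, struct page *page)
 *   {
 *       // Give the backing device page back to the driver's allocator.
 *   }
 *
 *   static vm_fault_t dummy_devmem_fault(struct hmm_devmem *devmem,
 *                                        struct vm_area_struct *vma,
 *                                        unsigned long addr,
 *                                        const struct page *page,
 *                                        unsigned int flags, pmd_t *pmdp)
 *   {
 *       // Migrate the page back to system memory, or fail the CPU fault.
 *       return VM_FAULT_SIGBUS;
 *   }
 *
 *   static const struct hmm_devmem_ops dummy_devmem_ops = {
 *       .free  = dummy_devmem_free,
 *       .fault = dummy_devmem_fault,
 *   };
 *
 *   devmem = hmm_devmem_add(&dummy_devmem_ops, &pdev->dev, SZ_1G);
 *   if (IS_ERR(devmem))
 *       return PTR_ERR(devmem);
 *   // Device pages now span [devmem->pfn_first, devmem->pfn_last).
 */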
1447
1448 struct hmm_devmem *hmm_devmem_add_resource(const struct hmm_devmem_ops *ops,
1449                                            struct device *device,
1450                                            struct resource *res)
1451 {
1452         struct hmm_devmem *devmem;
1453         void *result;
1454         int ret;
1455
1456         if (res->desc != IORES_DESC_DEVICE_PUBLIC_MEMORY)
1457                 return ERR_PTR(-EINVAL);
1458
1459         dev_pagemap_get_ops();
1460
1461         devmem = devm_kzalloc(device, sizeof(*devmem), GFP_KERNEL);
1462         if (!devmem)
1463                 return ERR_PTR(-ENOMEM);
1464
1465         init_completion(&devmem->completion);
1466         devmem->pfn_first = -1UL;
1467         devmem->pfn_last = -1UL;
1468         devmem->resource = res;
1469         devmem->device = device;
1470         devmem->ops = ops;
1471
1472         ret = percpu_ref_init(&devmem->ref, &hmm_devmem_ref_release,
1473                               0, GFP_KERNEL);
1474         if (ret)
1475                 return ERR_PTR(ret);
1476
1477         ret = devm_add_action_or_reset(device, hmm_devmem_ref_exit,
1478                         &devmem->ref);
1479         if (ret)
1480                 return ERR_PTR(ret);
1481
1482         devmem->pfn_first = devmem->resource->start >> PAGE_SHIFT;
1483         devmem->pfn_last = devmem->pfn_first +
1484                            (resource_size(devmem->resource) >> PAGE_SHIFT);
1485         devmem->page_fault = hmm_devmem_fault;
1486
1487         devmem->pagemap.type = MEMORY_DEVICE_PUBLIC;
1488         devmem->pagemap.res = *devmem->resource;
1489         devmem->pagemap.page_free = hmm_devmem_free;
1490         devmem->pagemap.altmap_valid = false;
1491         devmem->pagemap.ref = &devmem->ref;
1492         devmem->pagemap.data = devmem;
1493         devmem->pagemap.kill = hmm_devmem_ref_kill;
1494
1495         result = devm_memremap_pages(devmem->device, &devmem->pagemap);
1496         if (IS_ERR(result))
1497                 return result;
1498         return devmem;
1499 }
1500 EXPORT_SYMBOL_GPL(hmm_devmem_add_resource);
1501
1502 /*
1503  * A device driver that wants to handle multiple devices' memory through a
1504  * single fake device can use hmm_device to do so. This is purely a helper
1505  * and it is not needed to make use of any HMM functionality.
1506  */
1507 #define HMM_DEVICE_MAX 256
1508
1509 static DECLARE_BITMAP(hmm_device_mask, HMM_DEVICE_MAX);
1510 static DEFINE_SPINLOCK(hmm_device_lock);
1511 static struct class *hmm_device_class;
1512 static dev_t hmm_device_devt;
1513
1514 static void hmm_device_release(struct device *device)
1515 {
1516         struct hmm_device *hmm_device;
1517
1518         hmm_device = container_of(device, struct hmm_device, device);
1519         spin_lock(&hmm_device_lock);
1520         clear_bit(hmm_device->minor, hmm_device_mask);
1521         spin_unlock(&hmm_device_lock);
1522
1523         kfree(hmm_device);
1524 }
1525
1526 struct hmm_device *hmm_device_new(void *drvdata)
1527 {
1528         struct hmm_device *hmm_device;
1529
1530         hmm_device = kzalloc(sizeof(*hmm_device), GFP_KERNEL);
1531         if (!hmm_device)
1532                 return ERR_PTR(-ENOMEM);
1533
1534         spin_lock(&hmm_device_lock);
1535         hmm_device->minor = find_first_zero_bit(hmm_device_mask, HMM_DEVICE_MAX);
1536         if (hmm_device->minor >= HMM_DEVICE_MAX) {
1537                 spin_unlock(&hmm_device_lock);
1538                 kfree(hmm_device);
1539                 return ERR_PTR(-EBUSY);
1540         }
1541         set_bit(hmm_device->minor, hmm_device_mask);
1542         spin_unlock(&hmm_device_lock);
1543
1544         dev_set_name(&hmm_device->device, "hmm_device%d", hmm_device->minor);
1545         hmm_device->device.devt = MKDEV(MAJOR(hmm_device_devt),
1546                                         hmm_device->minor);
1547         hmm_device->device.release = hmm_device_release;
1548         dev_set_drvdata(&hmm_device->device, drvdata);
1549         hmm_device->device.class = hmm_device_class;
1550         device_initialize(&hmm_device->device);
1551
1552         return hmm_device;
1553 }
1554 EXPORT_SYMBOL(hmm_device_new);
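
/*
 * Illustrative sketch (not authoritative): a driver that wants a single fake
 * struct device to stand for several devices' memory can do the following
 * ("my_driver_data" is hypothetical):
 *
 *   struct hmm_device *hdev;
 *
 *   hdev = hmm_device_new(my_driver_data);
 *   if (IS_ERR(hdev))
 *       return PTR_ERR(hdev);
 *   // ... use &hdev->device, for instance as the device passed to
 *   // hmm_devmem_add(), then drop the reference when done:
 *   hmm_device_put(hdev);
 */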
1555
1556 void hmm_device_put(struct hmm_device *hmm_device)
1557 {
1558         put_device(&hmm_device->device);
1559 }
1560 EXPORT_SYMBOL(hmm_device_put);
1561
1562 static int __init hmm_init(void)
1563 {
1564         int ret;
1565
1566         ret = alloc_chrdev_region(&hmm_device_devt, 0,
1567                                   HMM_DEVICE_MAX,
1568                                   "hmm_device");
1569         if (ret)
1570                 return ret;
1571
1572         hmm_device_class = class_create(THIS_MODULE, "hmm_device");
1573         if (IS_ERR(hmm_device_class)) {
1574                 unregister_chrdev_region(hmm_device_devt, HMM_DEVICE_MAX);
1575                 return PTR_ERR(hmm_device_class);
1576         }
1577         return 0;
1578 }
1579
1580 device_initcall(hmm_init);
1581 #endif /* CONFIG_DEVICE_PRIVATE || CONFIG_DEVICE_PUBLIC */