mm/hmm: check the device private page owner in hmm_range_fault()
index 72e5a6d9a41756c93697cb3160634dbcaed25f5f..a491d9aaafe45d91f0138b8152f44cb2c50e429a 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
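Taken together, the hunks below consolidate the fault path into a single hmm_vma_fault() helper, stop using the HMM_PFN_DEVICE_PRIVATE flag to drive fault decisions, and only report device private entries whose pgmap owner matches the caller's range->dev_private_owner. A hedged driver-side sketch of how the two owner fields are paired (mydev, my_pgmap and nid are hypothetical; pgmap->type, pgmap->owner, memremap_pages() and range->dev_private_owner are the interfaces this series relies on):

    /* When creating the device private memory: */
    my_pgmap->type  = MEMORY_DEVICE_PRIVATE;
    my_pgmap->owner = mydev;                /* ends up in page->pgmap->owner */
    addr = memremap_pages(my_pgmap, nid);

    /* Later, when snapshotting CPU page tables: */
    range.dev_private_owner = mydev;        /* only mydev's private pages are
                                             * reported as device entries */

Device private pages owned by anyone else are no longer reported as device memory at all; they fall through to the ordinary swap-entry handling further down.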
@@ -33,38 +33,6 @@ struct hmm_vma_walk {
        unsigned int            flags;
 };
 
-static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
-                           bool write_fault, uint64_t *pfn)
-{
-       unsigned int flags = FAULT_FLAG_REMOTE;
-       struct hmm_vma_walk *hmm_vma_walk = walk->private;
-       struct hmm_range *range = hmm_vma_walk->range;
-       struct vm_area_struct *vma = walk->vma;
-       vm_fault_t ret;
-
-       if (!vma)
-               goto err;
-
-       if (hmm_vma_walk->flags & HMM_FAULT_ALLOW_RETRY)
-               flags |= FAULT_FLAG_ALLOW_RETRY;
-       if (write_fault)
-               flags |= FAULT_FLAG_WRITE;
-
-       ret = handle_mm_fault(vma, addr, flags);
-       if (ret & VM_FAULT_RETRY) {
-               /* Note, handle_mm_fault did up_read(&mm->mmap_sem)) */
-               return -EAGAIN;
-       }
-       if (ret & VM_FAULT_ERROR)
-               goto err;
-
-       return -EBUSY;
-
-err:
-       *pfn = range->values[HMM_PFN_ERROR];
-       return -EFAULT;
-}
-
 static int hmm_pfns_fill(unsigned long addr, unsigned long end,
                struct hmm_range *range, enum hmm_pfn_value_e value)
 {
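The removal above gets rid of the one-page-at-a-time fault helper. Its replacement, hmm_vma_fault() in the next hunk, loops over the whole range itself and no longer honours HMM_FAULT_ALLOW_RETRY, which is why the -EAGAIN return disappears from the function documentation in the final hunk. A sketch of the per-page flags that remain (taken from the new code below):

    unsigned int fault_flags = FAULT_FLAG_REMOTE;   /* faulting on behalf of another
                                                     * process/device */
    if (write_fault)
            fault_flags |= FAULT_FLAG_WRITE;
    /* FAULT_FLAG_ALLOW_RETRY is deliberately not set, so handle_mm_fault()
     * never drops mmap_sem behind the walker's back and never returns
     * VM_FAULT_RETRY. */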
@@ -79,45 +47,49 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 }
 
 /*
- * hmm_vma_walk_hole_() - handle a range lacking valid pmd or pte(s)
+ * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
  * @addr: range virtual start address (inclusive)
  * @end: range virtual end address (exclusive)
  * @fault: should we fault or not ?
  * @write_fault: write fault ?
  * @walk: mm_walk structure
- * Return: 0 on success, -EBUSY after page fault, or page fault error
+ * Return: -EBUSY after page fault, or page fault error
  *
  * This function will be called whenever pmd_none() or pte_none() returns true,
  * or whenever there is no page directory covering the virtual address range.
  */
-static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
+static int hmm_vma_fault(unsigned long addr, unsigned long end,
                              bool fault, bool write_fault,
                              struct mm_walk *walk)
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
+       struct vm_area_struct *vma = walk->vma;
        uint64_t *pfns = range->pfns;
-       unsigned long i;
+       unsigned long i = (addr - range->start) >> PAGE_SHIFT;
+       unsigned int fault_flags = FAULT_FLAG_REMOTE;
 
+       WARN_ON_ONCE(!fault && !write_fault);
        hmm_vma_walk->last = addr;
-       i = (addr - range->start) >> PAGE_SHIFT;
 
-       if (write_fault && walk->vma && !(walk->vma->vm_flags & VM_WRITE))
-               return -EPERM;
-
-       for (; addr < end; addr += PAGE_SIZE, i++) {
-               pfns[i] = range->values[HMM_PFN_NONE];
-               if (fault || write_fault) {
-                       int ret;
+       if (!vma)
+               goto out_error;
 
-                       ret = hmm_vma_do_fault(walk, addr, write_fault,
-                                              &pfns[i]);
-                       if (ret != -EBUSY)
-                               return ret;
-               }
+       if (write_fault) {
+               if (!(vma->vm_flags & VM_WRITE))
+                       return -EPERM;
+               fault_flags |= FAULT_FLAG_WRITE;
        }
 
-       return (fault || write_fault) ? -EBUSY : 0;
+       for (; addr < end; addr += PAGE_SIZE, i++)
+               if (handle_mm_fault(vma, addr, fault_flags) & VM_FAULT_ERROR)
+                       goto out_error;
+
+       return -EBUSY;
+
+out_error:
+       pfns[i] = range->values[HMM_PFN_ERROR];
+       return -EFAULT;
 }
 
 static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
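Note that hmm_vma_fault() never returns 0: a missing vma or a handle_mm_fault() error records HMM_PFN_ERROR and returns -EFAULT, a write fault on a vma without VM_WRITE returns -EPERM, and the success case returns -EBUSY so the caller revalidates its notifier and walks the range again. A hedged sketch of the caller side, the usual mmu_interval_notifier retry loop (driver code, abbreviated; exact signatures vary between kernel versions):

    long ret;

    do {
            range.notifier_seq = mmu_interval_read_begin(range.notifier);
            down_read(&mm->mmap_sem);
            ret = hmm_range_fault(&range, 0);
            up_read(&mm->mmap_sem);
    } while (ret == -EBUSY);        /* faulted or invalidated: walk again */

    if (ret < 0)
            return ret;             /* -EPERM, -EFAULT, -ENOMEM, ... */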
@@ -144,15 +116,6 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
        /* We aren't ask to do anything ... */
        if (!(pfns & range->flags[HMM_PFN_VALID]))
                return;
-       /* If this is device memory then only fault if explicitly requested */
-       if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
-               /* Do we fault on device memory ? */
-               if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
-                       *write_fault = pfns & range->flags[HMM_PFN_WRITE];
-                       *fault = true;
-               }
-               return;
-       }
 
        /* If CPU page table is not valid then we need to fault */
        *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
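With the device private special case gone, hmm_pte_need_fault() bases its decision purely on what the caller requested versus what the CPU page table already provides; device private entries are dealt with entirely in hmm_vma_handle_pte() below. A simplified plain-C model of the remaining policy (the write-fault half is reconstructed from the surrounding kernel code and may differ in detail):

    #include <stdbool.h>
    #include <stdint.h>

    /* req: the caller's request bits for this page, cpu: flags derived from
     * the CPU pte, valid/write: the range->flags[] encodings. */
    static void need_fault(uint64_t req, uint64_t cpu, uint64_t valid,
                           uint64_t write, bool *fault, bool *write_fault)
    {
            if (!(req & valid))
                    return;                         /* nothing requested for this page */
            *fault = !(cpu & valid);                /* requested but not mapped */
            if ((req & write) && !(cpu & write))    /* writable requested, read-only mapped */
                    *write_fault = *fault = true;
    }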
@@ -199,7 +162,10 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
        pfns = &range->pfns[i];
        hmm_range_need_fault(hmm_vma_walk, pfns, npages,
                             0, &fault, &write_fault);
-       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       if (fault || write_fault)
+               return hmm_vma_fault(addr, end, fault, write_fault, walk);
+       hmm_vma_walk->last = addr;
+       return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
 }
 
 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
@@ -226,8 +192,8 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
        hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
                             &fault, &write_fault);
 
-       if (pmd_protnone(pmd) || fault || write_fault)
-               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       if (fault || write_fault)
+               return hmm_vma_fault(addr, end, fault, write_fault, walk);
 
        pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
        for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
@@ -252,6 +218,14 @@ int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
                unsigned long end, uint64_t *pfns, pmd_t pmd);
 #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
 
+static inline bool hmm_is_device_private_entry(struct hmm_range *range,
+               swp_entry_t entry)
+{
+       return is_device_private_entry(entry) &&
+               device_private_entry_to_page(entry)->pgmap->owner ==
+               range->dev_private_owner;
+}
+
 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
 {
        if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte))
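hmm_is_device_private_entry() is the check the commit title refers to: a device private swap entry only counts if the owner recorded in its dev_pagemap matches range->dev_private_owner, so one driver's snapshot can no longer contain opaque device entries that belong to another driver. A hedged sketch of how the owning driver consumes such an entry afterwards (hmm_device_entry_to_page() and is_device_private_page() are existing helpers of this kernel era; my_device_page_to_dma() is hypothetical):

    struct page *page = hmm_device_entry_to_page(range, range->pfns[i]);

    if (page && is_device_private_page(page))
            dma = my_device_page_to_dma(mydev, page);   /* driver-specific decode */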
@@ -286,70 +260,68 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
        if (!pte_present(pte)) {
                swp_entry_t entry = pte_to_swp_entry(pte);
 
-               if (!non_swap_entry(entry)) {
-                       cpu_flags = pte_to_hmm_pfn_flags(range, pte);
-                       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                                          &fault, &write_fault);
-                       if (fault || write_fault)
-                               goto fault;
-                       return 0;
-               }
-
                /*
-                * This is a special swap entry, ignore migration, use
-                * device and report anything else as error.
+                * Never fault in device private pages, but just report
+                * the PFN even if not present.
                 */
-               if (is_device_private_entry(entry)) {
-                       cpu_flags = range->flags[HMM_PFN_VALID] |
-                               range->flags[HMM_PFN_DEVICE_PRIVATE];
-                       cpu_flags |= is_write_device_private_entry(entry) ?
-                               range->flags[HMM_PFN_WRITE] : 0;
-                       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                                          &fault, &write_fault);
-                       if (fault || write_fault)
-                               goto fault;
+               if (hmm_is_device_private_entry(range, entry)) {
                        *pfn = hmm_device_entry_from_pfn(range,
                                            swp_offset(entry));
-                       *pfn |= cpu_flags;
+                       *pfn |= range->flags[HMM_PFN_VALID];
+                       if (is_write_device_private_entry(entry))
+                               *pfn |= range->flags[HMM_PFN_WRITE];
                        return 0;
                }
 
-               if (is_migration_entry(entry)) {
-                       if (fault || write_fault) {
-                               pte_unmap(ptep);
-                               hmm_vma_walk->last = addr;
-                               migration_entry_wait(walk->mm, pmdp, addr);
-                               return -EBUSY;
-                       }
+               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0, &fault,
+                                  &write_fault);
+               if (!fault && !write_fault)
                        return 0;
+
+               if (!non_swap_entry(entry))
+                       goto fault;
+
+               if (is_migration_entry(entry)) {
+                       pte_unmap(ptep);
+                       hmm_vma_walk->last = addr;
+                       migration_entry_wait(walk->mm, pmdp, addr);
+                       return -EBUSY;
                }
 
                /* Report error for everything else */
+               pte_unmap(ptep);
                *pfn = range->values[HMM_PFN_ERROR];
                return -EFAULT;
-       } else {
-               cpu_flags = pte_to_hmm_pfn_flags(range, pte);
-               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
-                                  &fault, &write_fault);
        }
 
+       cpu_flags = pte_to_hmm_pfn_flags(range, pte);
+       hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags, &fault,
+                          &write_fault);
        if (fault || write_fault)
                goto fault;
 
        if (pte_devmap(pte)) {
                hmm_vma_walk->pgmap = get_dev_pagemap(pte_pfn(pte),
                                              hmm_vma_walk->pgmap);
-               if (unlikely(!hmm_vma_walk->pgmap))
+               if (unlikely(!hmm_vma_walk->pgmap)) {
+                       pte_unmap(ptep);
                        return -EBUSY;
-       } else if (IS_ENABLED(CONFIG_ARCH_HAS_PTE_SPECIAL) && pte_special(pte)) {
-               if (!is_zero_pfn(pte_pfn(pte))) {
-                       *pfn = range->values[HMM_PFN_SPECIAL];
+               }
+       }
+
+       /*
+        * Since each architecture defines a struct page for the zero page, just
+        * fall through and treat it like a normal page.
+        */
+       if (pte_special(pte) && !is_zero_pfn(pte_pfn(pte))) {
+               hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0, &fault,
+                                  &write_fault);
+               if (fault || write_fault) {
+                       pte_unmap(ptep);
                        return -EFAULT;
                }
-               /*
-                * Since each architecture defines a struct page for the zero
-                * page, just fall through and treat it like a normal page.
-                */
+               *pfn = range->values[HMM_PFN_SPECIAL];
+               return 0;
        }
 
        *pfn = hmm_device_entry_from_pfn(range, pte_pfn(pte)) | cpu_flags;
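The rewritten non-present path above has only two early exits: "this is one of our own device private pages, report it as-is" and "the caller did not ask for a fault". Everything else either faults, waits for a migration to finish, or fails. A plain-C model of that decision order (a sketch, not kernel code):

    #include <stdbool.h>

    enum pte_action { REPORT_DEVICE, SKIP, FAULT_IN, WAIT_MIGRATION, FAIL };

    static enum pte_action nonpresent_action(bool owned_device_private,
                                             bool need_fault, bool plain_swap,
                                             bool migration_entry)
    {
            if (owned_device_private)
                    return REPORT_DEVICE;   /* never faulted, pfn reported directly */
            if (!need_fault)
                    return SKIP;            /* snapshot mode: leave the entry alone */
            if (plain_swap)
                    return FAULT_IN;        /* ordinary swap: hmm_vma_fault() below */
            if (migration_entry)
                    return WAIT_MIGRATION;  /* migration_entry_wait(), return -EBUSY */
            return FAIL;                    /* HMM_PFN_ERROR, return -EFAULT */
    }

The present-pte side changes in the same spirit: special ptes other than the zero page are only reported as HMM_PFN_SPECIAL when no fault was requested, and the pte is now unmapped on every error return.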
@@ -362,7 +334,7 @@ fault:
        }
        pte_unmap(ptep);
        /* Fault any virtual address we were asked to fault */
-       return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
+       return hmm_vma_fault(addr, end, fault, write_fault, walk);
 }
 
 static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -372,8 +344,10 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
 {
        struct hmm_vma_walk *hmm_vma_walk = walk->private;
        struct hmm_range *range = hmm_vma_walk->range;
-       uint64_t *pfns = range->pfns;
-       unsigned long addr = start, i;
+       uint64_t *pfns = &range->pfns[(start - range->start) >> PAGE_SHIFT];
+       unsigned long npages = (end - start) >> PAGE_SHIFT;
+       unsigned long addr = start;
+       bool fault, write_fault;
        pte_t *ptep;
        pmd_t pmd;
 
@@ -383,14 +357,6 @@ again:
                return hmm_vma_walk_hole(start, end, -1, walk);
 
        if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
-               bool fault, write_fault;
-               unsigned long npages;
-               uint64_t *pfns;
-
-               i = (addr - range->start) >> PAGE_SHIFT;
-               npages = (end - addr) >> PAGE_SHIFT;
-               pfns = &range->pfns[i];
-
                hmm_range_need_fault(hmm_vma_walk, pfns, npages,
                                     0, &fault, &write_fault);
                if (fault || write_fault) {
@@ -398,9 +364,16 @@ again:
                        pmd_migration_entry_wait(walk->mm, pmdp);
                        return -EBUSY;
                }
-               return 0;
-       } else if (!pmd_present(pmd))
+               return hmm_pfns_fill(start, end, range, HMM_PFN_NONE);
+       }
+
+       if (!pmd_present(pmd)) {
+               hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0, &fault,
+                                    &write_fault);
+               if (fault || write_fault)
+                       return -EFAULT;
                return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+       }
 
        if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
                /*
@@ -417,8 +390,7 @@ again:
                if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
                        goto again;
 
-               i = (addr - range->start) >> PAGE_SHIFT;
-               return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
+               return hmm_vma_handle_pmd(walk, addr, end, pfns, pmd);
        }
 
        /*
@@ -427,17 +399,21 @@ again:
         * entry pointing to pte directory or it is a bad pmd that will not
         * recover.
         */
-       if (pmd_bad(pmd))
+       if (pmd_bad(pmd)) {
+               hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0, &fault,
+                                    &write_fault);
+               if (fault || write_fault)
+                       return -EFAULT;
                return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
+       }
 
        ptep = pte_offset_map(pmdp, addr);
-       i = (addr - range->start) >> PAGE_SHIFT;
-       for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
+       for (; addr < end; addr += PAGE_SIZE, ptep++, pfns++) {
                int r;
 
-               r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
+               r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, pfns);
                if (r) {
-                       /* hmm_vma_handle_pte() did unmap pte directory */
+                       /* hmm_vma_handle_pte() did pte_unmap() */
                        hmm_vma_walk->last = addr;
                        return r;
                }
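The preceding hunks give the pmd states the walker cannot resolve a uniform shape: a pmd migration entry either waits and returns -EBUSY (when a fault was requested) or is snapshotted as HMM_PFN_NONE, while !pmd_present and pmd_bad become a hard -EFAULT when a fault was requested and HMM_PFN_ERROR otherwise. A sketch of that last pattern (handle_bad_pmd is a made-up name; hmm_pfns_fill() is the real helper from the top of the file):

    static int handle_bad_pmd(bool need_fault, unsigned long start,
                              unsigned long end, struct hmm_range *range)
    {
            if (need_fault)
                    return -EFAULT;         /* cannot produce the requested pages */
            /* snapshot only: mark the pages and keep walking */
            return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
    }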
@@ -487,8 +463,8 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
 
        pud = READ_ONCE(*pudp);
        if (pud_none(pud)) {
-               ret = hmm_vma_walk_hole(start, end, -1, walk);
-               goto out_unlock;
+               spin_unlock(ptl);
+               return hmm_vma_walk_hole(start, end, -1, walk);
        }
 
        if (pud_huge(pud) && pud_devmap(pud)) {
@@ -497,8 +473,8 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                bool fault, write_fault;
 
                if (!pud_present(pud)) {
-                       ret = hmm_vma_walk_hole(start, end, -1, walk);
-                       goto out_unlock;
+                       spin_unlock(ptl);
+                       return hmm_vma_walk_hole(start, end, -1, walk);
                }
 
                i = (addr - range->start) >> PAGE_SHIFT;
@@ -509,9 +485,9 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
                hmm_range_need_fault(hmm_vma_walk, pfns, npages,
                                     cpu_flags, &fault, &write_fault);
                if (fault || write_fault) {
-                       ret = hmm_vma_walk_hole_(addr, end, fault,
-                                                write_fault, walk);
-                       goto out_unlock;
+                       spin_unlock(ptl);
+                       return hmm_vma_fault(addr, end, fault, write_fault,
+                                                 walk);
                }
 
                pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
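Both branches of the pud path now drop the page table spinlock before calling hmm_vma_walk_hole() or hmm_vma_fault(), since those paths can end up in handle_mm_fault() and sleep; the hugetlb hunks below make the same change. The rule these hunks enforce, as a minimal sketch:

    spin_lock(ptl);                         /* page table lock: atomic context */
    /* ... inspect the entry ... */
    if (fault || write_fault) {
            spin_unlock(ptl);               /* never sleep under a spinlock */
            return hmm_vma_fault(addr, end, fault, write_fault, walk);
    }
    /* ... fill range->pfns[] ... */
    spin_unlock(ptl);
    return 0;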
@@ -557,7 +533,6 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
        bool fault, write_fault;
        spinlock_t *ptl;
        pte_t entry;
-       int ret = 0;
 
        ptl = huge_pte_lock(hstate_vma(vma), walk->mm, pte);
        entry = huge_ptep_get(pte);
@@ -570,8 +545,8 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
        hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
                           &fault, &write_fault);
        if (fault || write_fault) {
-               ret = -ENOENT;
-               goto unlock;
+               spin_unlock(ptl);
+               return hmm_vma_fault(addr, end, fault, write_fault, walk);
        }
 
        pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
@@ -579,14 +554,8 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
                range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
                                 cpu_flags;
        hmm_vma_walk->last = end;
-
-unlock:
        spin_unlock(ptl);
-
-       if (ret == -ENOENT)
-               return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
-
-       return ret;
+       return 0;
 }
 #else
 #define hmm_vma_walk_hugetlb_entry NULL
@@ -600,18 +569,15 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
        struct vm_area_struct *vma = walk->vma;
 
        /*
-        * Skip vma ranges that don't have struct page backing them or
-        * map I/O devices directly.
-        */
-       if (vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP))
-               return -EFAULT;
-
-       /*
+        * Skip vma ranges that don't have struct page backing them or map I/O
+        * devices directly.
+        *
         * If the vma does not allow read access, then assume that it does not
-        * allow write access either. HMM does not support architectures
-        * that allow write without read.
+        * allow write access either. HMM does not support architectures that
+        * allow write without read.
         */
-       if (!(vma->vm_flags & VM_READ)) {
+       if ((vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP)) ||
+           !(vma->vm_flags & VM_READ)) {
                bool fault, write_fault;
 
                /*
@@ -625,7 +591,7 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
                if (fault || write_fault)
                        return -EFAULT;
 
-               hmm_pfns_fill(start, end, range, HMM_PFN_NONE);
+               hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
                hmm_vma_walk->last = end;
 
                /* Skip this vma and continue processing the next vma. */
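hmm_vma_walk_test() now treats "no struct page backing" (VM_IO / VM_PFNMAP / VM_MIXEDMAP) and "no read permission" identically, and in the snapshot case it fills HMM_PFN_ERROR rather than HMM_PFN_NONE, since pages in such a vma can never become valid for HMM; if a fault was requested the walk fails with -EFAULT instead. A simplified model of the test (hmm_can_walk_vma is a made-up name):

    static bool hmm_can_walk_vma(unsigned long vm_flags)
    {
            if (vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP))
                    return false;           /* no struct page / direct device mapping */
            return vm_flags & VM_READ;      /* write-only vmas are not supported */
    }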
@@ -657,7 +623,6 @@ static const struct mm_walk_ops hmm_walk_ops = {
  * -ENOMEM:    Out of memory.
  * -EPERM:     Invalid permission (e.g., asking for write and range is read
  *             only).
- * -EAGAIN:    A page fault needs to be retried and mmap_sem was dropped.
  * -EBUSY:     The range has been invalidated and the caller needs to wait for
  *             the invalidation to finish.
  * -EFAULT:    Invalid (i.e., either no valid vma or it is illegal to access