mm/memory.c: fix race when faulting a device private page
[linux-block.git] / drivers / gpu / drm / amd / amdkfd / kfd_migrate.c
index b059a77b6081d8c003127420bb08f72238e410f1..776448bd9fe4abfa6c6a4d16b667b9eb451199f0 100644 (file)
@@ -409,7 +409,7 @@ svm_migrate_vma_to_vram(struct amdgpu_device *adev, struct svm_range *prange,
        uint64_t npages = (end - start) >> PAGE_SHIFT;
        struct kfd_process_device *pdd;
        struct dma_fence *mfence = NULL;
-       struct migrate_vma migrate;
+       struct migrate_vma migrate = { 0 };
        unsigned long cpages = 0;
        dma_addr_t *scratch;
        void *buf;
@@ -668,7 +668,7 @@ out_oom:
 static long
 svm_migrate_vma_to_ram(struct amdgpu_device *adev, struct svm_range *prange,
                       struct vm_area_struct *vma, uint64_t start, uint64_t end,
-                      uint32_t trigger)
+                      uint32_t trigger, struct page *fault_page)
 {
        struct kfd_process *p = container_of(prange->svms, struct kfd_process, svms);
        uint64_t npages = (end - start) >> PAGE_SHIFT;
@@ -676,7 +676,7 @@ svm_migrate_vma_to_ram(struct amdgpu_device *adev, struct svm_range *prange,
        unsigned long cpages = 0;
        struct kfd_process_device *pdd;
        struct dma_fence *mfence = NULL;
-       struct migrate_vma migrate;
+       struct migrate_vma migrate = { 0 };
        dma_addr_t *scratch;
        void *buf;
        int r = -ENOMEM;
@@ -699,6 +699,7 @@ svm_migrate_vma_to_ram(struct amdgpu_device *adev, struct svm_range *prange,
 
        migrate.src = buf;
        migrate.dst = migrate.src + npages;
+       migrate.fault_page = fault_page;
        scratch = (dma_addr_t *)(migrate.dst + npages);
 
        kfd_smi_event_migration_start(adev->kfd.dev, p->lead_thread->pid,
@@ -766,7 +767,7 @@ out:
  * 0 - OK, otherwise error code
  */
 int svm_migrate_vram_to_ram(struct svm_range *prange, struct mm_struct *mm,
-                           uint32_t trigger)
+                           uint32_t trigger, struct page *fault_page)
 {
        struct amdgpu_device *adev;
        struct vm_area_struct *vma;
@@ -807,7 +808,8 @@ int svm_migrate_vram_to_ram(struct svm_range *prange, struct mm_struct *mm,
                }
 
                next = min(vma->vm_end, end);
-               r = svm_migrate_vma_to_ram(adev, prange, vma, addr, next, trigger);
+               r = svm_migrate_vma_to_ram(adev, prange, vma, addr, next, trigger,
+                       fault_page);
                if (r < 0) {
                        pr_debug("failed %ld to migrate prange %p\n", r, prange);
                        break;
@@ -851,7 +853,7 @@ svm_migrate_vram_to_vram(struct svm_range *prange, uint32_t best_loc,
        pr_debug("from gpu 0x%x to gpu 0x%x\n", prange->actual_loc, best_loc);
 
        do {
-               r = svm_migrate_vram_to_ram(prange, mm, trigger);
+               r = svm_migrate_vram_to_ram(prange, mm, trigger, NULL);
                if (r)
                        return r;
        } while (prange->actual_loc && --retries);
@@ -938,7 +940,8 @@ static vm_fault_t svm_migrate_to_ram(struct vm_fault *vmf)
                goto out_unlock_prange;
        }
 
-       r = svm_migrate_vram_to_ram(prange, mm, KFD_MIGRATE_TRIGGER_PAGEFAULT_CPU);
+       r = svm_migrate_vram_to_ram(prange, mm, KFD_MIGRATE_TRIGGER_PAGEFAULT_CPU,
+                               vmf->page);
        if (r)
                pr_debug("failed %d migrate 0x%p [0x%lx 0x%lx] to ram\n", r,
                         prange, prange->start, prange->last);