mm/memory.c: fix race when faulting a device private page
[linux-block.git] / drivers / gpu / drm / amd / amdkfd / kfd_svm.c
index a67ba8879a56730226cc0ebbdd21a35cce8d68ba..9139e5a0b2a07cc46029424f29f7842d7e192826 100644 (file)
@@ -541,7 +541,6 @@ svm_range_vram_node_new(struct amdgpu_device *adev, struct svm_range *prange,
                kfree(svm_bo);
                return -ESRCH;
        }
-       svm_bo->svms = prange->svms;
        svm_bo->eviction_fence =
                amdgpu_amdkfd_fence_create(dma_fence_context_alloc(1),
                                           mm,
@@ -2914,13 +2913,15 @@ retry_write_locked:
                                 */
                                if (prange->actual_loc)
                                        r = svm_migrate_vram_to_ram(prange, mm,
-                                          KFD_MIGRATE_TRIGGER_PAGEFAULT_GPU);
+                                          KFD_MIGRATE_TRIGGER_PAGEFAULT_GPU,
+                                          NULL);
                                else
                                        r = 0;
                        }
                } else {
                        r = svm_migrate_vram_to_ram(prange, mm,
-                                       KFD_MIGRATE_TRIGGER_PAGEFAULT_GPU);
+                                       KFD_MIGRATE_TRIGGER_PAGEFAULT_GPU,
+                                       NULL);
                }
                if (r) {
                        pr_debug("failed %d to migrate svms %p [0x%lx 0x%lx]\n",
@@ -3243,7 +3244,8 @@ svm_range_trigger_migration(struct mm_struct *mm, struct svm_range *prange,
                return 0;
 
        if (!best_loc) {
-               r = svm_migrate_vram_to_ram(prange, mm, KFD_MIGRATE_TRIGGER_PREFETCH);
+               r = svm_migrate_vram_to_ram(prange, mm,
+                                       KFD_MIGRATE_TRIGGER_PREFETCH, NULL);
                *migrated = !r;
                return r;
        }
@@ -3273,7 +3275,6 @@ int svm_range_schedule_evict_svm_bo(struct amdgpu_amdkfd_fence *fence)
 static void svm_range_evict_svm_bo_worker(struct work_struct *work)
 {
        struct svm_range_bo *svm_bo;
-       struct kfd_process *p;
        struct mm_struct *mm;
        int r = 0;
 
@@ -3281,13 +3282,12 @@ static void svm_range_evict_svm_bo_worker(struct work_struct *work)
        if (!svm_bo_ref_unless_zero(svm_bo))
                return; /* svm_bo was freed while eviction was pending */
 
-       /* svm_range_bo_release destroys this worker thread. So during
-        * the lifetime of this thread, kfd_process and mm will be valid.
-        */
-       p = container_of(svm_bo->svms, struct kfd_process, svms);
-       mm = p->mm;
-       if (!mm)
+       if (mmget_not_zero(svm_bo->eviction_fence->mm)) {
+               mm = svm_bo->eviction_fence->mm;
+       } else {
+               svm_range_bo_unref(svm_bo);
                return;
+       }
 
        mmap_read_lock(mm);
        spin_lock(&svm_bo->list_lock);
@@ -3305,9 +3305,8 @@ static void svm_range_evict_svm_bo_worker(struct work_struct *work)
 
                mutex_lock(&prange->migrate_mutex);
                do {
-                       r = svm_migrate_vram_to_ram(prange,
-                                               svm_bo->eviction_fence->mm,
-                                               KFD_MIGRATE_TRIGGER_TTM_EVICTION);
+                       r = svm_migrate_vram_to_ram(prange, mm,
+                                       KFD_MIGRATE_TRIGGER_TTM_EVICTION, NULL);
                } while (!r && prange->actual_loc && --retries);
 
                if (!r && prange->actual_loc)
@@ -3324,6 +3323,7 @@ static void svm_range_evict_svm_bo_worker(struct work_struct *work)
        }
        spin_unlock(&svm_bo->list_lock);
        mmap_read_unlock(mm);
+       mmput(mm);
 
        dma_fence_signal(&svm_bo->eviction_fence->base);