mm/hugetlb: fix huge_pmd_unshare() vs GUP-fast race
author: Jann Horn <jannh@google.com>
Tue, 27 May 2025 21:23:54 +0000 (23:23 +0200)
committer: Andrew Morton <akpm@linux-foundation.org>
Fri, 6 Jun 2025 05:02:24 +0000 (22:02 -0700)

huge_pmd_unshare() drops a reference on a page table that may have
previously been shared across processes, potentially turning it into a
normal page table used in another process in which unrelated VMAs can
afterwards be installed.

If this happens in the middle of a concurrent gup_fast(), gup_fast() could
end up walking the page tables of another process.  While I don't see any
way in which that immediately leads to kernel memory corruption, it is
really weird and unexpected.

Fix it with an explicit broadcast IPI through tlb_remove_table_sync_one(),
just like we do in khugepaged when removing page tables for a THP
collapse.

Link: https://lkml.kernel.org/r/20250528-hugetlb-fixes-splitrace-v2-2-1329349bad1a@google.com
Link: https://lkml.kernel.org/r/20250527-hugetlb-fixes-splitrace-v1-2-f4136f5ec58a@google.com
Fixes: 39dde65c9940 ("[PATCH] shared page table for hugetlb page")
Signed-off-by: Jann Horn <jannh@google.com>
Reviewed-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
Cc: Liam Howlett <liam.howlett@oracle.com>
Cc: Muchun Song <muchun.song@linux.dev>
Cc: Oscar Salvador <osalvador@suse.de>
Cc: Vlastimil Babka <vbabka@suse.cz>
Cc: <stable@vger.kernel.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index 7ba020d489d45643a7eb0807de6b35e2d9bfae9b..8746ed2fec135b646ba28536fbb4c816461a058a 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -7629,6 +7629,13 @@ int huge_pmd_unshare(struct mm_struct *mm, struct vm_area_struct *vma,
                return 0;
 
        pud_clear(pud);
+       /*
+        * Once our caller drops the rmap lock, some other process might be
+        * using this page table as a normal, non-hugetlb page table.
+        * Wait for pending gup_fast() in other threads to finish before letting
+        * that happen.
+        */
+       tlb_remove_table_sync_one();
        ptdesc_pmd_pts_dec(virt_to_ptdesc(ptep));
        mm_dec_nr_pmds(mm);
        return 1;