Merge tag 'mm-stable-2023-04-27-15-30' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-block.git] / mm / mmap.c
index eefa6f0cda28e27bc5394a6aeb06b7405a7b204c..5522130ae6065d256c8cd80616378137c417132f 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -46,6 +46,7 @@
 #include <linux/pkeys.h>
 #include <linux/oom.h>
 #include <linux/sched/mm.h>
+#include <linux/ksm.h>
 
 #include <linux/uaccess.h>
 #include <asm/cacheflush.h>
@@ -133,7 +134,7 @@ void unlink_file_vma(struct vm_area_struct *vma)
 /*
  * Close a vm structure and free it.
  */
-static void remove_vma(struct vm_area_struct *vma)
+static void remove_vma(struct vm_area_struct *vma, bool unreachable)
 {
        might_sleep();
        if (vma->vm_ops && vma->vm_ops->close)
@@ -141,7 +142,10 @@ static void remove_vma(struct vm_area_struct *vma)
        if (vma->vm_file)
                fput(vma->vm_file);
        mpol_put(vma_policy(vma));
-       vm_area_free(vma);
+       if (unreachable)
+               __vm_area_free(vma);
+       else
+               vm_area_free(vma);
 }
 
 static inline struct vm_area_struct *vma_prev_limit(struct vma_iterator *vmi,
@@ -502,6 +506,15 @@ static inline void init_vma_prep(struct vma_prepare *vp,
  */
 static inline void vma_prepare(struct vma_prepare *vp)
 {
+       vma_start_write(vp->vma);
+       if (vp->adj_next)
+               vma_start_write(vp->adj_next);
+       /* vp->insert is always a newly created VMA, no need for locking */
+       if (vp->remove)
+               vma_start_write(vp->remove);
+       if (vp->remove2)
+               vma_start_write(vp->remove2);
+
        if (vp->file) {
                uprobe_munmap(vp->vma, vp->vma->vm_start, vp->vma->vm_end);
 
@@ -590,6 +603,7 @@ static inline void vma_complete(struct vma_prepare *vp,
 
        if (vp->remove) {
 again:
+               vma_mark_detached(vp->remove, true);
                if (vp->file) {
                        uprobe_munmap(vp->remove, vp->remove->vm_start,
                                      vp->remove->vm_end);
@@ -605,7 +619,7 @@ again:
 
                /*
                 * In mprotect's case 6 (see comments on vma_merge),
-                * we must remove the one after next as well.
+                * we are removing both curr and next vmas
                 */
                if (vp->remove2) {
                        vp->remove = vp->remove2;
@@ -683,12 +697,12 @@ int vma_expand(struct vma_iterator *vmi, struct vm_area_struct *vma,
        if (vma_iter_prealloc(vmi))
                goto nomem;
 
+       vma_prepare(&vp);
        vma_adjust_trans_huge(vma, start, end, 0);
        /* VMA iterator points to previous, so set to start if necessary */
        if (vma_iter_addr(vmi) != start)
                vma_iter_set(vmi, start);
 
-       vma_prepare(&vp);
        vma->vm_start = start;
        vma->vm_end = end;
        vma->vm_pgoff = pgoff;
@@ -723,8 +737,8 @@ int vma_shrink(struct vma_iterator *vmi, struct vm_area_struct *vma,
                return -ENOMEM;
 
        init_vma_prep(&vp, vma);
-       vma_adjust_trans_huge(vma, start, end, 0);
        vma_prepare(&vp);
+       vma_adjust_trans_huge(vma, start, end, 0);
 
        if (vma->vm_start < start)
                vma_iter_clear(vmi, vma->vm_start, start);
@@ -742,12 +756,13 @@ int vma_shrink(struct vma_iterator *vmi, struct vm_area_struct *vma,
 
 /*
  * If the vma has a ->close operation then the driver probably needs to release
- * per-vma resources, so we don't attempt to merge those.
+ * per-vma resources, so we don't attempt to merge those if the caller indicates
+ * the current vma may be removed as part of the merge.
  */
-static inline int is_mergeable_vma(struct vm_area_struct *vma,
-                                  struct file *file, unsigned long vm_flags,
-                                  struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
-                                  struct anon_vma_name *anon_name)
+static inline bool is_mergeable_vma(struct vm_area_struct *vma,
+               struct file *file, unsigned long vm_flags,
+               struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
+               struct anon_vma_name *anon_name, bool may_remove_vma)
 {
        /*
         * VM_SOFTDIRTY should not prevent from VMA merging, if we
@@ -758,21 +773,20 @@ static inline int is_mergeable_vma(struct vm_area_struct *vma,
         * extended instead.
         */
        if ((vma->vm_flags ^ vm_flags) & ~VM_SOFTDIRTY)
-               return 0;
+               return false;
        if (vma->vm_file != file)
-               return 0;
-       if (vma->vm_ops && vma->vm_ops->close)
-               return 0;
+               return false;
+       if (may_remove_vma && vma->vm_ops && vma->vm_ops->close)
+               return false;
        if (!is_mergeable_vm_userfaultfd_ctx(vma, vm_userfaultfd_ctx))
-               return 0;
+               return false;
        if (!anon_vma_name_eq(anon_vma_name(vma), anon_name))
-               return 0;
-       return 1;
+               return false;
+       return true;
 }
 
-static inline int is_mergeable_anon_vma(struct anon_vma *anon_vma1,
-                                       struct anon_vma *anon_vma2,
-                                       struct vm_area_struct *vma)
+static inline bool is_mergeable_anon_vma(struct anon_vma *anon_vma1,
+                struct anon_vma *anon_vma2, struct vm_area_struct *vma)
 {
        /*
         * The list_is_singular() test is to avoid merging VMA cloned from
@@ -780,7 +794,7 @@ static inline int is_mergeable_anon_vma(struct anon_vma *anon_vma1,
         */
        if ((!anon_vma1 || !anon_vma2) && (!vma ||
                list_is_singular(&vma->anon_vma_chain)))
-               return 1;
+               return true;
        return anon_vma1 == anon_vma2;
 }
 
@@ -794,20 +808,21 @@ static inline int is_mergeable_anon_vma(struct anon_vma *anon_vma1,
  * We don't check here for the merged mmap wrapping around the end of pagecache
  * indices (16TB on ia32) because do_mmap() does not permit mmap's which
  * wrap, nor mmaps which cover the final page at index -1UL.
+ *
+ * We assume the vma may be removed as part of the merge.
  */
-static int
+static bool
 can_vma_merge_before(struct vm_area_struct *vma, unsigned long vm_flags,
-                    struct anon_vma *anon_vma, struct file *file,
-                    pgoff_t vm_pgoff,
-                    struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
-                    struct anon_vma_name *anon_name)
+               struct anon_vma *anon_vma, struct file *file,
+               pgoff_t vm_pgoff, struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
+               struct anon_vma_name *anon_name)
 {
-       if (is_mergeable_vma(vma, file, vm_flags, vm_userfaultfd_ctx, anon_name) &&
+       if (is_mergeable_vma(vma, file, vm_flags, vm_userfaultfd_ctx, anon_name, true) &&
            is_mergeable_anon_vma(anon_vma, vma->anon_vma, vma)) {
                if (vma->vm_pgoff == vm_pgoff)
-                       return 1;
+                       return true;
        }
-       return 0;
+       return false;
 }
 
 /*
@@ -816,22 +831,23 @@ can_vma_merge_before(struct vm_area_struct *vma, unsigned long vm_flags,
  *
  * We cannot merge two vmas if they have differently assigned (non-NULL)
  * anon_vmas, nor if same anon_vma is assigned but offsets incompatible.
+ *
+ * We assume that vma is not removed as part of the merge.
  */
-static int
+static bool
 can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
-                   struct anon_vma *anon_vma, struct file *file,
-                   pgoff_t vm_pgoff,
-                   struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
-                   struct anon_vma_name *anon_name)
+               struct anon_vma *anon_vma, struct file *file,
+               pgoff_t vm_pgoff, struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
+               struct anon_vma_name *anon_name)
 {
-       if (is_mergeable_vma(vma, file, vm_flags, vm_userfaultfd_ctx, anon_name) &&
+       if (is_mergeable_vma(vma, file, vm_flags, vm_userfaultfd_ctx, anon_name, false) &&
            is_mergeable_anon_vma(anon_vma, vma->anon_vma, vma)) {
                pgoff_t vm_pglen;
                vm_pglen = vma_pages(vma);
                if (vma->vm_pgoff + vm_pglen == vm_pgoff)
-                       return 1;
+                       return true;
        }
-       return 0;
+       return false;
 }
 
 /*
@@ -846,42 +862,45 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
  * this area are about to be changed to vm_flags - and the no-change
  * case has already been eliminated.
  *
- * The following mprotect cases have to be considered, where AAAA is
+ * The following mprotect cases have to be considered, where **** is
  * the area passed down from mprotect_fixup, never extending beyond one
- * vma, PPPPPP is the prev vma specified, and NNNNNN the next vma after:
+ * vma, PPPP is the previous vma, CCCC is a concurrent vma that starts
+ * at the same address as **** and is of the same or larger span, and
+ * NNNN the next vma after ****:
  *
- *     AAAA             AAAA                   AAAA
- *    PPPPPPNNNNNN    PPPPPPNNNNNN       PPPPPPNNNNNN
+ *     ****             ****                   ****
+ *    PPPPPPNNNNNN    PPPPPPNNNNNN       PPPPPPCCCCCC
  *    cannot merge    might become       might become
- *                    PPNNNNNNNNNN       PPPPPPPPPPNN
+ *                    PPNNNNNNNNNN       PPPPPPPPPPCC
  *    mmap, brk or    case 4 below       case 5 below
  *    mremap move:
- *                        AAAA               AAAA
- *                    PPPP    NNNN       PPPPNNNNXXXX
+ *                        ****               ****
+ *                    PPPP    NNNN       PPPPCCCCNNNN
  *                    might become       might become
  *                    PPPPPPPPPPPP 1 or  PPPPPPPPPPPP 6 or
- *                    PPPPPPPPNNNN 2 or  PPPPPPPPXXXX 7 or
- *                    PPPPNNNNNNNN 3     PPPPXXXXXXXX 8
+ *                    PPPPPPPPNNNN 2 or  PPPPPPPPNNNN 7 or
+ *                    PPPPNNNNNNNN 3     PPPPNNNNNNNN 8
  *
- * It is important for case 8 that the vma NNNN overlapping the
- * region AAAA is never going to extended over XXXX. Instead XXXX must
- * be extended in region AAAA and NNNN must be removed. This way in
+ * It is important for case 8 that the vma CCCC overlapping the
+ * region **** is never going to be extended over NNNN. Instead NNNN must
+ * be extended in region **** and CCCC must be removed. This way in
  * all cases where vma_merge succeeds, the moment vma_merge drops the
  * rmap_locks, the properties of the merged vma will be already
  * correct for the whole merged range. Some of those properties like
  * vm_page_prot/vm_flags may be accessed by rmap_walks and they must
  * be correct for the whole merged range immediately after the
- * rmap_locks are released. Otherwise if XXXX would be removed and
- * NNNN would be extended over the XXXX range, remove_migration_ptes
+ * rmap_locks are released. Otherwise if NNNN would be removed and
+ * CCCC would be extended over the NNNN range, remove_migration_ptes
  * or other rmap walkers (if working on addresses beyond the "end"
- * parameter) may establish ptes with the wrong permissions of NNNN
- * instead of the right permissions of XXXX.
+ * parameter) may establish ptes with the wrong permissions of CCCC
+ * instead of the right permissions of NNNN.
  *
  * In the code below:
  * PPPP is represented by *prev
- * NNNN is represented by *mid (and possibly equal to *next)
- * XXXX is represented by *next or not represented at all.
- * AAAA is not represented - it will be merged or the function will return NULL
+ * CCCC is represented by *curr or not represented at all (NULL)
+ * NNNN is represented by *next or not represented at all (NULL)
+ * **** is not represented - it will be merged and the vma containing the
+ *      area is returned, or the function will return NULL
  */
 struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
                        struct vm_area_struct *prev, unsigned long addr,
@@ -891,18 +910,18 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
                        struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
                        struct anon_vma_name *anon_name)
 {
-       pgoff_t pglen = (end - addr) >> PAGE_SHIFT;
-       pgoff_t vma_pgoff;
-       struct vm_area_struct *mid, *next, *res = NULL;
+       struct vm_area_struct *curr, *next, *res;
        struct vm_area_struct *vma, *adjust, *remove, *remove2;
-       int err = -1;
+       struct vma_prepare vp;
+       pgoff_t vma_pgoff;
+       int err = 0;
        bool merge_prev = false;
        bool merge_next = false;
        bool vma_expanded = false;
-       struct vma_prepare vp;
-       unsigned long vma_end = end;
-       long adj_next = 0;
        unsigned long vma_start = addr;
+       unsigned long vma_end = end;
+       pgoff_t pglen = (end - addr) >> PAGE_SHIFT;
+       long adj_start = 0;
 
        validate_mm(mm);
        /*
@@ -912,94 +931,105 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
        if (vm_flags & VM_SPECIAL)
                return NULL;
 
-       next = find_vma(mm, prev ? prev->vm_end : 0);
-       mid = next;
-       if (next && next->vm_end == end)                /* cases 6, 7, 8 */
-               next = find_vma(mm, next->vm_end);
+       /* Does the input range span an existing VMA? (cases 5 - 8) */
+       curr = find_vma_intersection(mm, prev ? prev->vm_end : 0, end);
 
-       /* verify some invariant that must be enforced by the caller */
-       VM_WARN_ON(prev && addr <= prev->vm_start);
-       VM_WARN_ON(mid && end > mid->vm_end);
-       VM_WARN_ON(addr >= end);
+       if (!curr ||                    /* cases 1 - 4 */
+           end == curr->vm_end)        /* cases 6 - 8, adjacent VMA */
+               next = vma_lookup(mm, end);
+       else
+               next = NULL;            /* case 5 */
 
        if (prev) {
-               res = prev;
-               vma = prev;
                vma_start = prev->vm_start;
                vma_pgoff = prev->vm_pgoff;
+
                /* Can we merge the predecessor? */
-               if (prev->vm_end == addr && mpol_equal(vma_policy(prev), policy)
+               if (addr == prev->vm_end && mpol_equal(vma_policy(prev), policy)
                    && can_vma_merge_after(prev, vm_flags, anon_vma, file,
-                                  pgoff, vm_userfaultfd_ctx, anon_name)) {
+                                          pgoff, vm_userfaultfd_ctx, anon_name)) {
                        merge_prev = true;
                        vma_prev(vmi);
                }
        }
+
        /* Can we merge the successor? */
-       if (next && end == next->vm_start &&
-                       mpol_equal(policy, vma_policy(next)) &&
-                       can_vma_merge_before(next, vm_flags,
-                                            anon_vma, file, pgoff+pglen,
-                                            vm_userfaultfd_ctx, anon_name)) {
+       if (next && mpol_equal(policy, vma_policy(next)) &&
+           can_vma_merge_before(next, vm_flags, anon_vma, file, pgoff+pglen,
+                                vm_userfaultfd_ctx, anon_name)) {
                merge_next = true;
        }
 
+       if (!merge_prev && !merge_next)
+               return NULL; /* Not mergeable. */
+
+       res = vma = prev;
        remove = remove2 = adjust = NULL;
+
+       /* Verify some invariant that must be enforced by the caller. */
+       VM_WARN_ON(prev && addr <= prev->vm_start);
+       VM_WARN_ON(curr && (addr != curr->vm_start || end > curr->vm_end));
+       VM_WARN_ON(addr >= end);
+
        /* Can we merge both the predecessor and the successor? */
        if (merge_prev && merge_next &&
            is_mergeable_anon_vma(prev->anon_vma, next->anon_vma, NULL)) {
-               remove = mid;                           /* case 1 */
+               remove = next;                          /* case 1 */
                vma_end = next->vm_end;
-               err = dup_anon_vma(res, remove);
-               if (mid != next) {                      /* case 6 */
+               err = dup_anon_vma(prev, next);
+               if (curr) {                             /* case 6 */
+                       remove = curr;
                        remove2 = next;
-                       if (!remove->anon_vma)
-                               err = dup_anon_vma(res, remove2);
+                       if (!next->anon_vma)
+                               err = dup_anon_vma(prev, curr);
                }
-       } else if (merge_prev) {
-               err = 0;                                /* case 2 */
-               if (mid && end > mid->vm_start) {
-                       err = dup_anon_vma(res, mid);
-                       if (end == mid->vm_end) {       /* case 7 */
-                               remove = mid;
+       } else if (merge_prev) {                        /* case 2 */
+               if (curr) {
+                       err = dup_anon_vma(prev, curr);
+                       if (end == curr->vm_end) {      /* case 7 */
+                               remove = curr;
                        } else {                        /* case 5 */
-                               adjust = mid;
-                               adj_next = (end - mid->vm_start);
+                               adjust = curr;
+                               adj_start = (end - curr->vm_start);
                        }
                }
-       } else if (merge_next) {
+       } else { /* merge_next */
                res = next;
                if (prev && addr < prev->vm_end) {      /* case 4 */
                        vma_end = addr;
-                       adjust = mid;
-                       adj_next = -(vma->vm_end - addr);
-                       err = dup_anon_vma(adjust, prev);
+                       adjust = next;
+                       adj_start = -(prev->vm_end - addr);
+                       err = dup_anon_vma(next, prev);
                } else {
+                       /*
+                        * Note that cases 3 and 8 are the ONLY ones where prev
+                        * is permitted to be (but is not necessarily) NULL.
+                        */
                        vma = next;                     /* case 3 */
                        vma_start = addr;
                        vma_end = next->vm_end;
                        vma_pgoff = next->vm_pgoff - pglen;
-                       err = 0;
-                       if (mid != next) {              /* case 8 */
-                               remove = mid;
-                               err = dup_anon_vma(res, remove);
+                       if (curr) {                     /* case 8 */
+                               vma_pgoff = curr->vm_pgoff;
+                               remove = curr;
+                               err = dup_anon_vma(next, curr);
                        }
                }
        }
 
-       /* Cannot merge or error in anon_vma clone */
+       /* Error in anon_vma clone. */
        if (err)
                return NULL;
 
        if (vma_iter_prealloc(vmi))
                return NULL;
 
-       vma_adjust_trans_huge(vma, vma_start, vma_end, adj_next);
        init_multi_vma_prep(&vp, vma, adjust, remove, remove2);
        VM_WARN_ON(vp.anon_vma && adjust && adjust->anon_vma &&
                   vp.anon_vma != adjust->anon_vma);
 
        vma_prepare(&vp);
+       vma_adjust_trans_huge(vma, vma_start, vma_end, adj_start);
        if (vma_start < vma->vm_start || vma_end > vma->vm_end)
                vma_expanded = true;
 
@@ -1010,10 +1040,10 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
        if (vma_expanded)
                vma_iter_store(vmi, vma);
 
-       if (adj_next) {
-               adjust->vm_start += adj_next;
-               adjust->vm_pgoff += adj_next >> PAGE_SHIFT;
-               if (adj_next < 0) {
+       if (adj_start) {
+               adjust->vm_start += adj_start;
+               adjust->vm_pgoff += adj_start >> PAGE_SHIFT;
+               if (adj_start < 0) {
                        WARN_ON(vma_expanded);
                        vma_iter_store(vmi, next);
                }
@@ -1518,7 +1548,8 @@ static inline int accountable_mapping(struct file *file, vm_flags_t vm_flags)
  */
 static unsigned long unmapped_area(struct vm_unmapped_area_info *info)
 {
-       unsigned long length, gap, low_limit;
+       unsigned long length, gap;
+       unsigned long low_limit, high_limit;
        struct vm_area_struct *tmp;
 
        MA_STATE(mas, &current->mm->mm_mt, 0, 0);
@@ -1529,8 +1560,11 @@ static unsigned long unmapped_area(struct vm_unmapped_area_info *info)
                return -ENOMEM;
 
        low_limit = info->low_limit;
+       if (low_limit < mmap_min_addr)
+               low_limit = mmap_min_addr;
+       high_limit = info->high_limit;
 retry:
-       if (mas_empty_area(&mas, low_limit, info->high_limit - 1, length))
+       if (mas_empty_area(&mas, low_limit, high_limit - 1, length))
                return -ENOMEM;
 
        gap = mas.index;
@@ -1566,7 +1600,8 @@ retry:
  */
 static unsigned long unmapped_area_topdown(struct vm_unmapped_area_info *info)
 {
-       unsigned long length, gap, high_limit, gap_end;
+       unsigned long length, gap, gap_end;
+       unsigned long low_limit, high_limit;
        struct vm_area_struct *tmp;
 
        MA_STATE(mas, &current->mm->mm_mt, 0, 0);
@@ -1575,10 +1610,12 @@ static unsigned long unmapped_area_topdown(struct vm_unmapped_area_info *info)
        if (length < info->length)
                return -ENOMEM;
 
+       low_limit = info->low_limit;
+       if (low_limit < mmap_min_addr)
+               low_limit = mmap_min_addr;
        high_limit = info->high_limit;
 retry:
-       if (mas_empty_area_rev(&mas, info->low_limit, high_limit - 1,
-                               length))
+       if (mas_empty_area_rev(&mas, low_limit, high_limit - 1, length))
                return -ENOMEM;
 
        gap = mas.last + 1 - info->length;
@@ -1713,7 +1750,7 @@ generic_get_unmapped_area_topdown(struct file *filp, unsigned long addr,
 
        info.flags = VM_UNMAPPED_AREA_TOPDOWN;
        info.length = len;
-       info.low_limit = max(PAGE_SIZE, mmap_min_addr);
+       info.low_limit = PAGE_SIZE;
        info.high_limit = arch_get_mmap_base(addr, mm->mmap_base);
        info.align_mask = 0;
        info.align_offset = 0;
@@ -2157,7 +2194,7 @@ static inline void remove_mt(struct mm_struct *mm, struct ma_state *mas)
                if (vma->vm_flags & VM_ACCOUNT)
                        nr_accounted += nrpages;
                vm_stat_account(mm, vma->vm_flags, -nrpages);
-               remove_vma(vma);
+               remove_vma(vma, false);
        }
        vm_unacct_memory(nr_accounted);
        validate_mm(mm);
@@ -2180,7 +2217,8 @@ static void unmap_region(struct mm_struct *mm, struct maple_tree *mt,
        update_hiwater_rss(mm);
        unmap_vmas(&tlb, mt, vma, start, end, mm_wr_locked);
        free_pgtables(&tlb, mt, vma, prev ? prev->vm_end : FIRST_USER_ADDRESS,
-                                next ? next->vm_start : USER_PGTABLES_CEILING);
+                                next ? next->vm_start : USER_PGTABLES_CEILING,
+                                mm_wr_locked);
        tlb_finish_mmu(&tlb);
 }
 
@@ -2236,10 +2274,10 @@ int __split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
        if (new->vm_ops && new->vm_ops->open)
                new->vm_ops->open(new);
 
-       vma_adjust_trans_huge(vma, vma->vm_start, addr, 0);
        init_vma_prep(&vp, vma);
        vp.insert = new;
        vma_prepare(&vp);
+       vma_adjust_trans_huge(vma, vma->vm_start, addr, 0);
 
        if (new_below) {
                vma->vm_start = addr;
@@ -2283,10 +2321,12 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
 static inline int munmap_sidetree(struct vm_area_struct *vma,
                                   struct ma_state *mas_detach)
 {
+       vma_start_write(vma);
        mas_set_range(mas_detach, vma->vm_start, vma->vm_end - 1);
        if (mas_store_gfp(mas_detach, vma, GFP_KERNEL))
                return -ENOMEM;
 
+       vma_mark_detached(vma, true);
        if (vma->vm_flags & VM_LOCKED)
                vma->vm_mm->locked_vm -= vma_pages(vma);
 
@@ -2697,6 +2737,7 @@ unmap_writable:
        if (file && vm_flags & VM_SHARED)
                mapping_unmap_writable(file->f_mapping);
        file = vma->vm_file;
+       ksm_add_vma(vma);
 expanded:
        perf_event_mmap(vma);
 
@@ -2942,9 +2983,9 @@ static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
                if (vma_iter_prealloc(vmi))
                        goto unacct_fail;
 
-               vma_adjust_trans_huge(vma, vma->vm_start, addr + len, 0);
                init_vma_prep(&vp, vma);
                vma_prepare(&vp);
+               vma_adjust_trans_huge(vma, vma->vm_start, addr + len, 0);
                vma->vm_end = addr + len;
                vm_flags_set(vma, VM_SOFTDIRTY);
                vma_iter_store(vmi, vma);
@@ -2969,6 +3010,7 @@ static int do_brk_flags(struct vma_iterator *vmi, struct vm_area_struct *vma,
                goto mas_store_fail;
 
        mm->map_count++;
+       ksm_add_vma(vma);
 out:
        perf_event_mmap(vma);
        mm->total_vm += len >> PAGE_SHIFT;
@@ -3077,7 +3119,7 @@ void exit_mmap(struct mm_struct *mm)
        mmap_write_lock(mm);
        mt_clear_in_rcu(&mm->mm_mt);
        free_pgtables(&tlb, &mm->mm_mt, vma, FIRST_USER_ADDRESS,
-                     USER_PGTABLES_CEILING);
+                     USER_PGTABLES_CEILING, true);
        tlb_finish_mmu(&tlb);
 
        /*
@@ -3088,7 +3130,7 @@ void exit_mmap(struct mm_struct *mm)
        do {
                if (vma->vm_flags & VM_ACCOUNT)
                        nr_accounted += vma_pages(vma);
-               remove_vma(vma);
+               remove_vma(vma, true);
                count++;
                cond_resched();
        } while ((vma = mas_find(&mas, ULONG_MAX)) != NULL);
@@ -3211,6 +3253,7 @@ struct vm_area_struct *copy_vma(struct vm_area_struct **vmap,
                        get_file(new_vma->vm_file);
                if (new_vma->vm_ops && new_vma->vm_ops->open)
                        new_vma->vm_ops->open(new_vma);
+               vma_start_write(new_vma);
                if (vma_link(mm, new_vma))
                        goto out_vma_link;
                *need_rmap_locks = false;
@@ -3505,6 +3548,7 @@ static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping)
  * of mm/rmap.c:
  *   - all hugetlbfs_i_mmap_rwsem_key locks (aka mapping->i_mmap_rwsem for
  *     hugetlb mapping);
+ *   - all vmas marked locked
  *   - all i_mmap_rwsem locks;
 *   - all anon_vma->rwsem locks
  *
@@ -3527,6 +3571,13 @@ int mm_take_all_locks(struct mm_struct *mm)
 
        mutex_lock(&mm_all_locks_mutex);
 
+       mas_for_each(&mas, vma, ULONG_MAX) {
+               if (signal_pending(current))
+                       goto out_unlock;
+               vma_start_write(vma);
+       }
+
+       mas_set(&mas, 0);
        mas_for_each(&mas, vma, ULONG_MAX) {
                if (signal_pending(current))
                        goto out_unlock;
@@ -3616,6 +3667,7 @@ void mm_drop_all_locks(struct mm_struct *mm)
                if (vma->vm_file && vma->vm_file->f_mapping)
                        vm_unlock_mapping(vma->vm_file->f_mapping);
        }
+       vma_end_write_all(mm);
 
        mutex_unlock(&mm_all_locks_mutex);
 }