mm/mmap/vma_merge: set mid to NULL if not applicable
authorVlastimil Babka <vbabka@suse.cz>
Thu, 9 Mar 2023 11:12:54 +0000 (12:12 +0100)
committerAndrew Morton <akpm@linux-foundation.org>
Thu, 6 Apr 2023 02:42:49 +0000 (19:42 -0700)
There are several places where we test if 'mid' is really the area NNNN in
the diagram and the tests have two variants and are non-obvious to follow.
Instead, set 'mid' to NULL up-front if it's not the NNNN area, and
simplify the tests.

Also update the description in comment accordingly.

[vbabka@suse.cz: adjust/add comments as suggested by Lorenzo]
Link: https://lkml.kernel.org/r/def43190-53f7-a607-d1b0-b657565f4288@suse.cz
Link: https://lkml.kernel.org/r/20230309111258.24079-7-vbabka@suse.cz
Signed-off-by: Vlastimil Babka <vbabka@suse.cz>
Reviewed-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Cc: Lorenzo Stoakes <lstoakes@gmail.com>
Cc: Matthew Wilcox (Oracle) <willy@infradead.org>
Cc: Suren Baghdasaryan <surenb@google.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
mm/mmap.c

index d3765dcd9a15519132ab17228ce563d9d2fea572..259b5e54baeb1a6adedec95d480b1ca00a0f55c4 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -848,10 +848,12 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
  *
  * The following mprotect cases have to be considered, where AAAA is
  * the area passed down from mprotect_fixup, never extending beyond one
- * vma, PPPPPP is the prev vma specified, and NNNNNN the next vma after:
+ * vma, PPPP is the previous vma, NNNN is a vma that starts at the same
+ * address as AAAA and is of the same or larger span, and XXXX the next
+ * vma after AAAA:
  *
  *     AAAA             AAAA                   AAAA
- *    PPPPPPNNNNNN    PPPPPPXXXXXX       PPPPPPNNNNNN
+ *    PPPPPPXXXXXX    PPPPPPXXXXXX       PPPPPPNNNNNN
  *    cannot merge    might become       might become
  *                    PPXXXXXXXXXX       PPPPPPPPPPNN
  *    mmap, brk or    case 4 below       case 5 below
@@ -879,9 +881,10 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
  *
  * In the code below:
  * PPPP is represented by *prev
- * NNNN is represented by *mid (and possibly equal to *next)
- * XXXX is represented by *next or not represented at all.
- * AAAA is not represented - it will be merged or the function will return NULL
+ * NNNN is represented by *mid or not represented at all (NULL)
+ * XXXX is represented by *next or not represented at all (NULL)
+ * AAAA is not represented - it will be merged and the vma containing the
+ *      area is returned, or the function will return NULL
  */
 struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
                        struct vm_area_struct *prev, unsigned long addr,
@@ -918,6 +921,10 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
        else
                next = mid;
 
+       /* In cases 1 - 4 there's no NNNN vma */
+       if (mid && end <= mid->vm_start)
+               mid = NULL;
+
        /* verify some invariant that must be enforced by the caller */
        VM_WARN_ON(prev && addr <= prev->vm_start);
        VM_WARN_ON(mid && end > mid->vm_end);
@@ -952,7 +959,7 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
                remove = next;                          /* case 1 */
                vma_end = next->vm_end;
                err = dup_anon_vma(prev, next);
-               if (mid != next) {                      /* case 6 */
+               if (mid) {                              /* case 6 */
                        remove = mid;
                        remove2 = next;
                        if (!next->anon_vma)
@@ -960,7 +967,7 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
                }
        } else if (merge_prev) {
                err = 0;                                /* case 2 */
-               if (mid && end > mid->vm_start) {
+               if (mid) {
                        err = dup_anon_vma(prev, mid);
                        if (end == mid->vm_end) {       /* case 7 */
                                remove = mid;
@@ -982,7 +989,7 @@ struct vm_area_struct *vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
                        vma_end = next->vm_end;
                        vma_pgoff = next->vm_pgoff;
                        err = 0;
-                       if (mid != next) {              /* case 8 */
+                       if (mid) {                      /* case 8 */
                                vma_pgoff = mid->vm_pgoff;
                                remove = mid;
                                err = dup_anon_vma(next, mid);