augmented rbtree: add new RB_DECLARE_CALLBACKS_MAX macro
authorMichel Lespinasse <walken@google.com>
Wed, 25 Sep 2019 23:46:07 +0000 (16:46 -0700)
committerLinus Torvalds <torvalds@linux-foundation.org>
Thu, 26 Sep 2019 00:51:39 +0000 (17:51 -0700)
Add RB_DECLARE_CALLBACKS_MAX, which generates augmented rbtree callbacks
for the case where the augmented value is a scalar whose definition
follows a max(f(node)) pattern.  This actually covers all present uses of
RB_DECLARE_CALLBACKS, and saves some (source) code duplication in the
various RBCOMPUTE function definitions.

[walken@google.com: fix mm/vmalloc.c]
Link: http://lkml.kernel.org/r/CANN689FXgK13wDYNh1zKxdipeTuALG4eKvKpsdZqKFJ-rvtGiQ@mail.gmail.com
[walken@google.com: re-add check to check_augmented()]
Link: http://lkml.kernel.org/r/20190727022027.GA86863@google.com
Link: http://lkml.kernel.org/r/20190703040156.56953-3-walken@google.com
Signed-off-by: Michel Lespinasse <walken@google.com>
Acked-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Cc: David Howells <dhowells@redhat.com>
Cc: Davidlohr Bueso <dbueso@suse.de>
Cc: Uladzislau Rezki <urezki@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
arch/x86/mm/pat_rbtree.c
drivers/block/drbd/drbd_interval.c
include/linux/interval_tree_generic.h
include/linux/rbtree_augmented.h
lib/rbtree_test.c
mm/mmap.c
mm/vmalloc.c
tools/include/linux/rbtree_augmented.h

index fa16036fa5929b44fa5c2d4df1a2c5df41fc2d84..65ebe4b88f7cbb634d64ba1fd810afef58b85f0e 100644 (file)
@@ -54,23 +54,10 @@ static u64 get_subtree_max_end(struct rb_node *node)
        return ret;
 }
 
-static u64 compute_subtree_max_end(struct memtype *data)
-{
-       u64 max_end = data->end, child_max_end;
-
-       child_max_end = get_subtree_max_end(data->rb.rb_right);
-       if (child_max_end > max_end)
-               max_end = child_max_end;
-
-       child_max_end = get_subtree_max_end(data->rb.rb_left);
-       if (child_max_end > max_end)
-               max_end = child_max_end;
-
-       return max_end;
-}
+#define NODE_END(node) ((node)->end)
 
-RB_DECLARE_CALLBACKS(static, memtype_rb_augment_cb, struct memtype, rb,
-                    u64, subtree_max_end, compute_subtree_max_end)
+RB_DECLARE_CALLBACKS_MAX(static, memtype_rb_augment_cb,
+                        struct memtype, rb, u64, subtree_max_end, NODE_END)
 
 /* Find the first (lowest start addr) overlapping range from rb tree */
 static struct memtype *memtype_rb_lowest_match(struct rb_root *root,
index c58986556161ff5a46e092bb8bb55c1fec87f0a6..651bd0236a996a0c1cd14aa6fdd97ade33e1aac2 100644 (file)
@@ -13,33 +13,10 @@ sector_t interval_end(struct rb_node *node)
        return this->end;
 }
 
-/**
- * compute_subtree_last  -  compute end of @node
- *
- * The end of an interval is the highest (start + (size >> 9)) value of this
- * node and of its children.  Called for @node and its parents whenever the end
- * may have changed.
- */
-static inline sector_t
-compute_subtree_last(struct drbd_interval *node)
-{
-       sector_t max = node->sector + (node->size >> 9);
-
-       if (node->rb.rb_left) {
-               sector_t left = interval_end(node->rb.rb_left);
-               if (left > max)
-                       max = left;
-       }
-       if (node->rb.rb_right) {
-               sector_t right = interval_end(node->rb.rb_right);
-               if (right > max)
-                       max = right;
-       }
-       return max;
-}
+#define NODE_END(node) ((node)->sector + ((node)->size >> 9))
 
-RB_DECLARE_CALLBACKS(static, augment_callbacks, struct drbd_interval, rb,
-                    sector_t, end, compute_subtree_last);
+RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
+                        struct drbd_interval, rb, sector_t, end, NODE_END);
 
 /**
  * drbd_insert_interval  -  insert a new interval into a tree
index 855476145fe18e1a2bf9f9a0e6aa0b054d1c287b..aaa8a0767aa3a512c978d047556af3cee07af66f 100644 (file)
                                                                              \
 /* Callbacks for augmented rbtree insert and remove */                       \
                                                                              \
-static inline ITTYPE ITPREFIX ## _compute_subtree_last(ITSTRUCT *node)       \
-{                                                                            \
-       ITTYPE max = ITLAST(node), subtree_last;                              \
-       if (node->ITRB.rb_left) {                                             \
-               subtree_last = rb_entry(node->ITRB.rb_left,                   \
-                                       ITSTRUCT, ITRB)->ITSUBTREE;           \
-               if (max < subtree_last)                                       \
-                       max = subtree_last;                                   \
-       }                                                                     \
-       if (node->ITRB.rb_right) {                                            \
-               subtree_last = rb_entry(node->ITRB.rb_right,                  \
-                                       ITSTRUCT, ITRB)->ITSUBTREE;           \
-               if (max < subtree_last)                                       \
-                       max = subtree_last;                                   \
-       }                                                                     \
-       return max;                                                           \
-}                                                                            \
-                                                                             \
-RB_DECLARE_CALLBACKS(static, ITPREFIX ## _augment, ITSTRUCT, ITRB,           \
-                    ITTYPE, ITSUBTREE, ITPREFIX ## _compute_subtree_last)    \
+RB_DECLARE_CALLBACKS_MAX(static, ITPREFIX ## _augment,                       \
+                        ITSTRUCT, ITRB, ITTYPE, ITSUBTREE, ITLAST)           \
                                                                              \
 /* Insert / remove interval nodes from the tree */                           \
                                                                              \
index 97994160008223030d8581ba014338fd3b6a0e29..e5937e387e02ec20bf2107db23c8316e929117ae 100644 (file)
@@ -61,7 +61,7 @@ rb_insert_augmented_cached(struct rb_node *node,
 }
 
 /*
- * Template for declaring augmented rbtree callbacks
+ * Template for declaring augmented rbtree callbacks (generic case)
  *
  * RBSTATIC:    'static' or empty
  * RBNAME:      name of the rb_augment_callbacks structure
@@ -107,6 +107,40 @@ RBSTATIC const struct rb_augment_callbacks RBNAME = {                      \
        .rotate = RBNAME ## _rotate                                     \
 };
 
+/*
+ * Template for declaring augmented rbtree callbacks,
+ * computing RBAUGMENTED scalar as max(RBCOMPUTE(node)) for all subtree nodes.
+ *
+ * RBSTATIC:    'static' or empty
+ * RBNAME:      name of the rb_augment_callbacks structure
+ * RBSTRUCT:    struct type of the tree nodes
+ * RBFIELD:     name of struct rb_node field within RBSTRUCT
+ * RBTYPE:      type of the RBAUGMENTED field
+ * RBAUGMENTED: name of RBTYPE field within RBSTRUCT holding data for subtree
+ * RBCOMPUTE:   name of function that returns the per-node RBTYPE scalar
+ */
+
+#define RB_DECLARE_CALLBACKS_MAX(RBSTATIC, RBNAME, RBSTRUCT, RBFIELD,        \
+                                RBTYPE, RBAUGMENTED, RBCOMPUTE)              \
+static inline RBTYPE RBNAME ## _compute_max(RBSTRUCT *node)                  \
+{                                                                            \
+       RBSTRUCT *child;                                                      \
+       RBTYPE max = RBCOMPUTE(node);                                         \
+       if (node->RBFIELD.rb_left) {                                          \
+               child = rb_entry(node->RBFIELD.rb_left, RBSTRUCT, RBFIELD);   \
+               if (child->RBAUGMENTED > max)                                 \
+                       max = child->RBAUGMENTED;                             \
+       }                                                                     \
+       if (node->RBFIELD.rb_right) {                                         \
+               child = rb_entry(node->RBFIELD.rb_right, RBSTRUCT, RBFIELD);  \
+               if (child->RBAUGMENTED > max)                                 \
+                       max = child->RBAUGMENTED;                             \
+       }                                                                     \
+       return max;                                                           \
+}                                                                            \
+RB_DECLARE_CALLBACKS(RBSTATIC, RBNAME, RBSTRUCT, RBFIELD,                    \
+                    RBTYPE, RBAUGMENTED, RBNAME ## _compute_max)
+
 
 #define        RB_RED          0
 #define        RB_BLACK        1
index 62b8ee92643dc5de6f400f24e45e3ddbe475bc2a..41ae3c7570d3971478be9a396557cac8f1ab4dc3 100644 (file)
@@ -77,26 +77,10 @@ static inline void erase_cached(struct test_node *node, struct rb_root_cached *r
 }
 
 
-static inline u32 augment_recompute(struct test_node *node)
-{
-       u32 max = node->val, child_augmented;
-       if (node->rb.rb_left) {
-               child_augmented = rb_entry(node->rb.rb_left, struct test_node,
-                                          rb)->augmented;
-               if (max < child_augmented)
-                       max = child_augmented;
-       }
-       if (node->rb.rb_right) {
-               child_augmented = rb_entry(node->rb.rb_right, struct test_node,
-                                          rb)->augmented;
-               if (max < child_augmented)
-                       max = child_augmented;
-       }
-       return max;
-}
+#define NODE_VAL(node) ((node)->val)
 
-RB_DECLARE_CALLBACKS(static, augment_callbacks, struct test_node, rb,
-                    u32, augmented, augment_recompute)
+RB_DECLARE_CALLBACKS_MAX(static, augment_callbacks,
+                        struct test_node, rb, u32, augmented, NODE_VAL)
 
 static void insert_augmented(struct test_node *node,
                             struct rb_root_cached *root)
@@ -238,7 +222,20 @@ static void check_augmented(int nr_nodes)
        check(nr_nodes);
        for (rb = rb_first(&root.rb_root); rb; rb = rb_next(rb)) {
                struct test_node *node = rb_entry(rb, struct test_node, rb);
-               WARN_ON_ONCE(node->augmented != augment_recompute(node));
+               u32 subtree, max = node->val;
+               if (node->rb.rb_left) {
+                       subtree = rb_entry(node->rb.rb_left, struct test_node,
+                                          rb)->augmented;
+                       if (max < subtree)
+                               max = subtree;
+               }
+               if (node->rb.rb_right) {
+                       subtree = rb_entry(node->rb.rb_right, struct test_node,
+                                          rb)->augmented;
+                       if (max < subtree)
+                               max = subtree;
+               }
+               WARN_ON_ONCE(node->augmented != max);
        }
 }
 
index f1e8c7f93e04c61f825dbb22b2a8e3e25370f031..14b7da317ec044500fe9e8c33c0758ecc6b9c575 100644 (file)
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -289,9 +289,9 @@ out:
        return retval;
 }
 
-static long vma_compute_subtree_gap(struct vm_area_struct *vma)
+static inline unsigned long vma_compute_gap(struct vm_area_struct *vma)
 {
-       unsigned long max, prev_end, subtree_gap;
+       unsigned long gap, prev_end;
 
        /*
         * Note: in the rare case of a VM_GROWSDOWN above a VM_GROWSUP, we
@@ -299,14 +299,21 @@ static long vma_compute_subtree_gap(struct vm_area_struct *vma)
         * an unmapped area; whereas when expanding we only require one.
         * That's a little inconsistent, but keeps the code here simpler.
         */
-       max = vm_start_gap(vma);
+       gap = vm_start_gap(vma);
        if (vma->vm_prev) {
                prev_end = vm_end_gap(vma->vm_prev);
-               if (max > prev_end)
-                       max -= prev_end;
+               if (gap > prev_end)
+                       gap -= prev_end;
                else
-                       max = 0;
+                       gap = 0;
        }
+       return gap;
+}
+
+#ifdef CONFIG_DEBUG_VM_RB
+static unsigned long vma_compute_subtree_gap(struct vm_area_struct *vma)
+{
+       unsigned long max = vma_compute_gap(vma), subtree_gap;
        if (vma->vm_rb.rb_left) {
                subtree_gap = rb_entry(vma->vm_rb.rb_left,
                                struct vm_area_struct, vm_rb)->rb_subtree_gap;
@@ -322,7 +329,6 @@ static long vma_compute_subtree_gap(struct vm_area_struct *vma)
        return max;
 }
 
-#ifdef CONFIG_DEBUG_VM_RB
 static int browse_rb(struct mm_struct *mm)
 {
        struct rb_root *root = &mm->mm_rb;
@@ -428,8 +434,9 @@ static void validate_mm(struct mm_struct *mm)
 #define validate_mm(mm) do { } while (0)
 #endif
 
-RB_DECLARE_CALLBACKS(static, vma_gap_callbacks, struct vm_area_struct, vm_rb,
-                    unsigned long, rb_subtree_gap, vma_compute_subtree_gap)
+RB_DECLARE_CALLBACKS_MAX(static, vma_gap_callbacks,
+                        struct vm_area_struct, vm_rb,
+                        unsigned long, rb_subtree_gap, vma_compute_gap)
 
 /*
  * Update augmented rbtree rb_subtree_gap values after vma->vm_start or
@@ -439,8 +446,8 @@ RB_DECLARE_CALLBACKS(static, vma_gap_callbacks, struct vm_area_struct, vm_rb,
 static void vma_gap_update(struct vm_area_struct *vma)
 {
        /*
-        * As it turns out, RB_DECLARE_CALLBACKS() already created a callback
-        * function that does exactly what we want.
+        * As it turns out, RB_DECLARE_CALLBACKS_MAX() already created
+        * a callback function that does exactly what we want.
         */
        vma_gap_callbacks_propagate(&vma->vm_rb, NULL);
 }
index fcadd3e25c0c8e1964506c4027c7e59ebc5e2c8d..a3c70e275f4e5331fd69eff970cf831021a1dced 100644 (file)
@@ -396,9 +396,8 @@ compute_subtree_max_size(struct vmap_area *va)
                get_subtree_max_size(va->rb_node.rb_right));
 }
 
-RB_DECLARE_CALLBACKS(static, free_vmap_area_rb_augment_cb,
-       struct vmap_area, rb_node, unsigned long, subtree_max_size,
-       compute_subtree_max_size)
+RB_DECLARE_CALLBACKS_MAX(static, free_vmap_area_rb_augment_cb,
+       struct vmap_area, rb_node, unsigned long, subtree_max_size, va_size)
 
 static void purge_vmap_area_lazy(void);
 static BLOCKING_NOTIFIER_HEAD(vmap_notify_list);
index de3a480204ba075946b5b0f7ec602a28e80b47d7..4e8c4c76e9a2628bde041e62b2ba220a3ba7a2e5 100644 (file)
@@ -63,7 +63,7 @@ rb_insert_augmented_cached(struct rb_node *node,
 }
 
 /*
- * Template for declaring augmented rbtree callbacks
+ * Template for declaring augmented rbtree callbacks (generic case)
  *
  * RBSTATIC:    'static' or empty
  * RBNAME:      name of the rb_augment_callbacks structure
@@ -109,6 +109,40 @@ RBSTATIC const struct rb_augment_callbacks RBNAME = {                      \
        .rotate = RBNAME ## _rotate                                     \
 };
 
+/*
+ * Template for declaring augmented rbtree callbacks,
+ * computing RBAUGMENTED scalar as max(RBCOMPUTE(node)) for all subtree nodes.
+ *
+ * RBSTATIC:    'static' or empty
+ * RBNAME:      name of the rb_augment_callbacks structure
+ * RBSTRUCT:    struct type of the tree nodes
+ * RBFIELD:     name of struct rb_node field within RBSTRUCT
+ * RBTYPE:      type of the RBAUGMENTED field
+ * RBAUGMENTED: name of RBTYPE field within RBSTRUCT holding data for subtree
+ * RBCOMPUTE:   name of function that returns the per-node RBTYPE scalar
+ */
+
+#define RB_DECLARE_CALLBACKS_MAX(RBSTATIC, RBNAME, RBSTRUCT, RBFIELD,        \
+                                RBTYPE, RBAUGMENTED, RBCOMPUTE)              \
+static inline RBTYPE RBNAME ## _compute_max(RBSTRUCT *node)                  \
+{                                                                            \
+       RBSTRUCT *child;                                                      \
+       RBTYPE max = RBCOMPUTE(node);                                         \
+       if (node->RBFIELD.rb_left) {                                          \
+               child = rb_entry(node->RBFIELD.rb_left, RBSTRUCT, RBFIELD);   \
+               if (child->RBAUGMENTED > max)                                 \
+                       max = child->RBAUGMENTED;                             \
+       }                                                                     \
+       if (node->RBFIELD.rb_right) {                                         \
+               child = rb_entry(node->RBFIELD.rb_right, RBSTRUCT, RBFIELD);  \
+               if (child->RBAUGMENTED > max)                                 \
+                       max = child->RBAUGMENTED;                             \
+       }                                                                     \
+       return max;                                                           \
+}                                                                            \
+RB_DECLARE_CALLBACKS(RBSTATIC, RBNAME, RBSTRUCT, RBFIELD,                    \
+                    RBTYPE, RBAUGMENTED, RBNAME ## _compute_max)
+
 
 #define        RB_RED          0
 #define        RB_BLACK        1