Merge tag 'wireless-next-2023-04-21' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-block.git] / lib / maple_tree.c
index 9e2735cbc2b49441d03d66fcf04b6d07ebed35c6..1281a40d5735c423c47a451c53eec273dd8deda1 100644 (file)
@@ -185,7 +185,7 @@ static void mt_free_rcu(struct rcu_head *head)
  */
 static void ma_free_rcu(struct maple_node *node)
 {
-       node->parent = ma_parent_ptr(node);
+       WARN_ON(node->parent != ma_parent_ptr(node));
        call_rcu(&node->rcu, mt_free_rcu);
 }
 
@@ -539,11 +539,14 @@ static inline struct maple_node *mte_parent(const struct maple_enode *enode)
  */
 static inline bool ma_dead_node(const struct maple_node *node)
 {
-       struct maple_node *parent = (void *)((unsigned long)
-                                            node->parent & ~MAPLE_NODE_MASK);
+       struct maple_node *parent;
 
+       /* Do not reorder reads from the node prior to the parent check */
+       smp_rmb();
+       parent = (void *)((unsigned long) node->parent & ~MAPLE_NODE_MASK);
        return (parent == node);
 }
+
 /*
  * mte_dead_node() - check if the @enode is dead.
  * @enode: The encoded maple node
@@ -555,6 +558,8 @@ static inline bool mte_dead_node(const struct maple_enode *enode)
        struct maple_node *parent, *node;
 
        node = mte_to_node(enode);
+       /* Do not reorder reads from the node prior to the parent check */
+       smp_rmb();
        parent = mte_parent(enode);
        return (parent == node);
 }
@@ -625,6 +630,8 @@ static inline unsigned int mas_alloc_req(const struct ma_state *mas)
  * @node - the maple node
  * @type - the node type
  *
+ * In the event of a dead node, this array may be %NULL
+ *
  * Return: A pointer to the maple node pivots
  */
 static inline unsigned long *ma_pivots(struct maple_node *node,
@@ -817,6 +824,11 @@ static inline void *mt_slot(const struct maple_tree *mt,
        return rcu_dereference_check(slots[offset], mt_locked(mt));
 }
 
+static inline void *mt_slot_locked(struct maple_tree *mt, void __rcu **slots,
+                                  unsigned char offset)
+{
+       return rcu_dereference_protected(slots[offset], mt_locked(mt));
+}
 /*
  * mas_slot_locked() - Get the slot value when holding the maple tree lock.
  * @mas: The maple state
@@ -828,7 +840,7 @@ static inline void *mt_slot(const struct maple_tree *mt,
 static inline void *mas_slot_locked(struct ma_state *mas, void __rcu **slots,
                                       unsigned char offset)
 {
-       return rcu_dereference_protected(slots[offset], mt_locked(mas->tree));
+       return mt_slot_locked(mas->tree, slots, offset);
 }
 
 /*
@@ -899,6 +911,45 @@ static inline void ma_set_meta(struct maple_node *mn, enum maple_type mt,
        meta->end = end;
 }
 
+/*
+ * mt_clear_meta() - clear the metadata information of a node, if it exists
+ * @mt: The maple tree
+ * @mn: The maple node
+ * @type: The maple node type
+ * @offset: The offset of the highest sub-gap in this node.
+ * @end: The end of the data in this node.
+ */
+static inline void mt_clear_meta(struct maple_tree *mt, struct maple_node *mn,
+                                 enum maple_type type)
+{
+       struct maple_metadata *meta;
+       unsigned long *pivots;
+       void __rcu **slots;
+       void *next;
+
+       switch (type) {
+       case maple_range_64:
+               pivots = mn->mr64.pivot;
+               if (unlikely(pivots[MAPLE_RANGE64_SLOTS - 2])) {
+                       slots = mn->mr64.slot;
+                       next = mt_slot_locked(mt, slots,
+                                             MAPLE_RANGE64_SLOTS - 1);
+                       if (unlikely((mte_to_node(next) &&
+                                     mte_node_type(next))))
+                               return; /* no metadata, could be node */
+               }
+               fallthrough;
+       case maple_arange_64:
+               meta = ma_meta(mn, type);
+               break;
+       default:
+               return;
+       }
+
+       meta->gap = 0;
+       meta->end = 0;
+}
+
 /*
  * ma_meta_end() - Get the data end of a node from the metadata
  * @mn: The maple node
@@ -1096,8 +1147,11 @@ static int mas_ascend(struct ma_state *mas)
                a_type = mas_parent_enum(mas, p_enode);
                a_node = mte_parent(p_enode);
                a_slot = mte_parent_slot(p_enode);
-               pivots = ma_pivots(a_node, a_type);
                a_enode = mt_mk_node(a_node, a_type);
+               pivots = ma_pivots(a_node, a_type);
+
+               if (unlikely(ma_dead_node(a_node)))
+                       return 1;
 
                if (!set_min && a_slot) {
                        set_min = true;
@@ -1249,26 +1303,21 @@ static inline void mas_alloc_nodes(struct ma_state *mas, gfp_t gfp)
        node = mas->alloc;
        node->request_count = 0;
        while (requested) {
-               max_req = MAPLE_ALLOC_SLOTS;
-               if (node->node_count) {
-                       unsigned int offset = node->node_count;
-
-                       slots = (void **)&node->slot[offset];
-                       max_req -= offset;
-               } else {
-                       slots = (void **)&node->slot;
-               }
-
+               max_req = MAPLE_ALLOC_SLOTS - node->node_count;
+               slots = (void **)&node->slot[node->node_count];
                max_req = min(requested, max_req);
                count = mt_alloc_bulk(gfp, max_req, slots);
                if (!count)
                        goto nomem_bulk;
 
+               if (node->node_count == 0) {
+                       node->slot[0]->node_count = 0;
+                       node->slot[0]->request_count = 0;
+               }
+
                node->node_count += count;
                allocated += count;
                node = node->slot[0];
-               node->node_count = 0;
-               node->request_count = 0;
                requested -= count;
        }
        mas->alloc->total = allocated;
@@ -1354,12 +1403,16 @@ static inline struct maple_enode *mas_start(struct ma_state *mas)
                mas->max = ULONG_MAX;
                mas->depth = 0;
 
+retry:
                root = mas_root(mas);
                /* Tree with nodes */
                if (likely(xa_is_node(root))) {
                        mas->depth = 1;
                        mas->node = mte_safe_root(root);
                        mas->offset = 0;
+                       if (mte_dead_node(mas->node))
+                               goto retry;
+
                        return NULL;
                }
 
@@ -1401,6 +1454,9 @@ static inline unsigned char ma_data_end(struct maple_node *node,
 {
        unsigned char offset;
 
+       if (!pivots)
+               return 0;
+
        if (type == maple_arange_64)
                return ma_meta_end(node, type);
 
@@ -1436,6 +1492,9 @@ static inline unsigned char mas_data_end(struct ma_state *mas)
                return ma_meta_end(node, type);
 
        pivots = ma_pivots(node, type);
+       if (unlikely(ma_dead_node(node)))
+               return 0;
+
        offset = mt_pivots[type] - 1;
        if (likely(!pivots[offset]))
                return ma_meta_end(node, type);
@@ -1724,8 +1783,10 @@ static inline void mas_replace(struct ma_state *mas, bool advanced)
                rcu_assign_pointer(slots[offset], mas->node);
        }
 
-       if (!advanced)
+       if (!advanced) {
+               mte_set_node_dead(old_enode);
                mas_free(mas, old_enode);
+       }
 }
 
 /*
@@ -3659,10 +3720,9 @@ static inline int mas_root_expand(struct ma_state *mas, void *entry)
                slot++;
        mas->depth = 1;
        mas_set_height(mas);
-
+       ma_set_meta(node, maple_leaf_64, 0, slot);
        /* swap the new root into the tree */
        rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
-       ma_set_meta(node, maple_leaf_64, 0, slot);
        return slot;
 }
 
@@ -3875,18 +3935,13 @@ static inline void *mtree_lookup_walk(struct ma_state *mas)
                end = ma_data_end(node, type, pivots, max);
                if (unlikely(ma_dead_node(node)))
                        goto dead_node;
-
-               if (pivots[offset] >= mas->index)
-                       goto next;
-
                do {
-                       offset++;
-               } while ((offset < end) && (pivots[offset] < mas->index));
-
-               if (likely(offset > end))
-                       max = pivots[offset];
+                       if (pivots[offset] >= mas->index) {
+                               max = pivots[offset];
+                               break;
+                       }
+               } while (++offset < end);
 
-next:
                slots = ma_slots(node, type);
                next = mt_slot(mas->tree, slots, offset);
                if (unlikely(ma_dead_node(node)))
@@ -4164,6 +4219,7 @@ static inline bool mas_wr_node_store(struct ma_wr_state *wr_mas)
 done:
        mas_leaf_set_meta(mas, newnode, dst_pivots, maple_leaf_64, new_end);
        if (in_rcu) {
+               mte_set_node_dead(mas->node);
                mas->node = mt_mk_node(newnode, wr_mas->type);
                mas_replace(mas, false);
        } else {
@@ -4505,6 +4561,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
        node = mas_mn(mas);
        slots = ma_slots(node, mt);
        pivots = ma_pivots(node, mt);
+       if (unlikely(ma_dead_node(node)))
+               return 1;
+
        mas->max = pivots[offset];
        if (offset)
                mas->min = pivots[offset - 1] + 1;
@@ -4526,6 +4585,9 @@ static inline int mas_prev_node(struct ma_state *mas, unsigned long min)
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
                offset = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                if (offset)
                        mas->min = pivots[offset - 1] + 1;
 
@@ -4574,6 +4636,7 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
        struct maple_enode *enode;
        int level = 0;
        unsigned char offset;
+       unsigned char node_end;
        enum maple_type mt;
        void __rcu **slots;
 
@@ -4597,7 +4660,11 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                node = mas_mn(mas);
                mt = mte_node_type(mas->node);
                pivots = ma_pivots(node, mt);
-       } while (unlikely(offset == ma_data_end(node, mt, pivots, mas->max)));
+               node_end = ma_data_end(node, mt, pivots, mas->max);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
+       } while (unlikely(offset == node_end));
 
        slots = ma_slots(node, mt);
        pivot = mas_safe_pivot(mas, pivots, ++offset, mt);
@@ -4613,6 +4680,9 @@ static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
                mt = mte_node_type(mas->node);
                slots = ma_slots(node, mt);
                pivots = ma_pivots(node, mt);
+               if (unlikely(ma_dead_node(node)))
+                       return 1;
+
                offset = 0;
                pivot = pivots[0];
        }
@@ -4659,11 +4729,14 @@ static inline void *mas_next_nentry(struct ma_state *mas,
                return NULL;
        }
 
-       pivots = ma_pivots(node, type);
        slots = ma_slots(node, type);
-       mas->index = mas_safe_min(mas, pivots, mas->offset);
+       pivots = ma_pivots(node, type);
        count = ma_data_end(node, type, pivots, mas->max);
-       if (ma_dead_node(node))
+       if (unlikely(ma_dead_node(node)))
+               return NULL;
+
+       mas->index = mas_safe_min(mas, pivots, mas->offset);
+       if (unlikely(ma_dead_node(node)))
                return NULL;
 
        if (mas->index > max)
@@ -4817,6 +4890,11 @@ retry:
 
        slots = ma_slots(mn, mt);
        pivots = ma_pivots(mn, mt);
+       if (unlikely(ma_dead_node(mn))) {
+               mas_rewalk(mas, index);
+               goto retry;
+       }
+
        if (offset == mt_pivots[mt])
                pivot = mas->max;
        else
@@ -4887,7 +4965,8 @@ not_found:
  * Return: True if found in a leaf, false otherwise.
  *
  */
-static bool mas_rev_awalk(struct ma_state *mas, unsigned long size)
+static bool mas_rev_awalk(struct ma_state *mas, unsigned long size,
+               unsigned long *gap_min, unsigned long *gap_max)
 {
        enum maple_type type = mte_node_type(mas->node);
        struct maple_node *node = mas_mn(mas);
@@ -4952,8 +5031,8 @@ static bool mas_rev_awalk(struct ma_state *mas, unsigned long size)
 
        if (unlikely(ma_is_leaf(type))) {
                mas->offset = offset;
-               mas->min = min;
-               mas->max = min + gap - 1;
+               *gap_min = min;
+               *gap_max = min + gap - 1;
                return true;
        }
 
@@ -4977,10 +5056,10 @@ static inline bool mas_anode_descend(struct ma_state *mas, unsigned long size)
 {
        enum maple_type type = mte_node_type(mas->node);
        unsigned long pivot, min, gap = 0;
-       unsigned char offset;
-       unsigned long *gaps;
-       unsigned long *pivots = ma_pivots(mas_mn(mas), type);
-       void __rcu **slots = ma_slots(mas_mn(mas), type);
+       unsigned char offset, data_end;
+       unsigned long *gaps, *pivots;
+       void __rcu **slots;
+       struct maple_node *node;
        bool found = false;
 
        if (ma_is_dense(type)) {
@@ -4988,13 +5067,15 @@ static inline bool mas_anode_descend(struct ma_state *mas, unsigned long size)
                return true;
        }
 
-       gaps = ma_gaps(mte_to_node(mas->node), type);
+       node = mas_mn(mas);
+       pivots = ma_pivots(node, type);
+       slots = ma_slots(node, type);
+       gaps = ma_gaps(node, type);
        offset = mas->offset;
        min = mas_safe_min(mas, pivots, offset);
-       for (; offset < mt_slots[type]; offset++) {
-               pivot = mas_safe_pivot(mas, pivots, offset, type);
-               if (offset && !pivot)
-                       break;
+       data_end = ma_data_end(node, type, pivots, mas->max);
+       for (; offset <= data_end; offset++) {
+               pivot = mas_logical_pivot(mas, pivots, offset, type);
 
                /* Not within lower bounds */
                if (mas->index > pivot)
@@ -5229,6 +5310,9 @@ int mas_empty_area(struct ma_state *mas, unsigned long min,
        unsigned long *pivots;
        enum maple_type mt;
 
+       if (min >= max)
+               return -EINVAL;
+
        if (mas_is_start(mas))
                mas_start(mas);
        else if (mas->offset >= 2)
@@ -5283,6 +5367,9 @@ int mas_empty_area_rev(struct ma_state *mas, unsigned long min,
 {
        struct maple_enode *last = mas->node;
 
+       if (min >= max)
+               return -EINVAL;
+
        if (mas_is_start(mas)) {
                mas_start(mas);
                mas->offset = mas_data_end(mas);
@@ -5302,7 +5389,7 @@ int mas_empty_area_rev(struct ma_state *mas, unsigned long min,
        mas->index = min;
        mas->last = max;
 
-       while (!mas_rev_awalk(mas, size)) {
+       while (!mas_rev_awalk(mas, size, &min, &max)) {
                if (last == mas->node) {
                        if (!mas_rewind_node(mas))
                                return -EBUSY;
@@ -5317,17 +5404,9 @@ int mas_empty_area_rev(struct ma_state *mas, unsigned long min,
        if (unlikely(mas->offset == MAPLE_NODE_SLOTS))
                return -EBUSY;
 
-       /*
-        * mas_rev_awalk() has set mas->min and mas->max to the gap values.  If
-        * the maximum is outside the window we are searching, then use the last
-        * location in the search.
-        * mas->max and mas->min is the range of the gap.
-        * mas->index and mas->last are currently set to the search range.
-        */
-
        /* Trim the upper limit to the max. */
-       if (mas->max <= mas->last)
-               mas->last = mas->max;
+       if (max <= mas->last)
+               mas->last = max;
 
        mas->index = mas->last - size + 1;
        return 0;
@@ -5400,24 +5479,26 @@ no_gap:
 }
 
 /*
- * mas_dead_leaves() - Mark all leaves of a node as dead.
+ * mte_dead_leaves() - Mark all leaves of a node as dead.
  * @mas: The maple state
  * @slots: Pointer to the slot array
+ * @type: The maple node type
  *
  * Must hold the write lock.
  *
  * Return: The number of leaves marked as dead.
  */
 static inline
-unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots)
+unsigned char mte_dead_leaves(struct maple_enode *enode, struct maple_tree *mt,
+                             void __rcu **slots)
 {
        struct maple_node *node;
        enum maple_type type;
        void *entry;
        int offset;
 
-       for (offset = 0; offset < mt_slot_count(mas->node); offset++) {
-               entry = mas_slot_locked(mas, slots, offset);
+       for (offset = 0; offset < mt_slot_count(enode); offset++) {
+               entry = mt_slot(mt, slots, offset);
                type = mte_node_type(entry);
                node = mte_to_node(entry);
                /* Use both node and type to catch LE & BE metadata */
@@ -5425,7 +5506,6 @@ unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots)
                        break;
 
                mte_set_node_dead(entry);
-               smp_wmb(); /* Needed for RCU */
                node->type = type;
                rcu_assign_pointer(slots[offset], node);
        }
@@ -5433,151 +5513,160 @@ unsigned char mas_dead_leaves(struct ma_state *mas, void __rcu **slots)
        return offset;
 }
 
-static void __rcu **mas_dead_walk(struct ma_state *mas, unsigned char offset)
+/**
+ * mte_dead_walk() - Walk down a dead tree to just before the leaves
+ * @enode: The maple encoded node
+ * @offset: The starting offset
+ *
+ * Note: This can only be used from the RCU callback context.
+ */
+static void __rcu **mte_dead_walk(struct maple_enode **enode, unsigned char offset)
 {
        struct maple_node *node, *next;
        void __rcu **slots = NULL;
 
-       next = mas_mn(mas);
+       next = mte_to_node(*enode);
        do {
-               mas->node = ma_enode_ptr(next);
-               node = mas_mn(mas);
+               *enode = ma_enode_ptr(next);
+               node = mte_to_node(*enode);
                slots = ma_slots(node, node->type);
-               next = mas_slot_locked(mas, slots, offset);
+               next = rcu_dereference_protected(slots[offset],
+                                       lock_is_held(&rcu_callback_map));
                offset = 0;
        } while (!ma_is_leaf(next->type));
 
        return slots;
 }
 
+/**
+ * mt_free_walk() - Walk & free a tree in the RCU callback context
+ * @head: The RCU head that's within the node.
+ *
+ * Note: This can only be used from the RCU callback context.
+ */
 static void mt_free_walk(struct rcu_head *head)
 {
        void __rcu **slots;
        struct maple_node *node, *start;
-       struct maple_tree mt;
+       struct maple_enode *enode;
        unsigned char offset;
        enum maple_type type;
-       MA_STATE(mas, &mt, 0, 0);
 
        node = container_of(head, struct maple_node, rcu);
 
        if (ma_is_leaf(node->type))
                goto free_leaf;
 
-       mt_init_flags(&mt, node->ma_flags);
-       mas_lock(&mas);
        start = node;
-       mas.node = mt_mk_node(node, node->type);
-       slots = mas_dead_walk(&mas, 0);
-       node = mas_mn(&mas);
+       enode = mt_mk_node(node, node->type);
+       slots = mte_dead_walk(&enode, 0);
+       node = mte_to_node(enode);
        do {
                mt_free_bulk(node->slot_len, slots);
                offset = node->parent_slot + 1;
-               mas.node = node->piv_parent;
-               if (mas_mn(&mas) == node)
-                       goto start_slots_free;
-
-               type = mte_node_type(mas.node);
-               slots = ma_slots(mte_to_node(mas.node), type);
-               if ((offset < mt_slots[type]) && (slots[offset]))
-                       slots = mas_dead_walk(&mas, offset);
-
-               node = mas_mn(&mas);
+               enode = node->piv_parent;
+               if (mte_to_node(enode) == node)
+                       goto free_leaf;
+
+               type = mte_node_type(enode);
+               slots = ma_slots(mte_to_node(enode), type);
+               if ((offset < mt_slots[type]) &&
+                   rcu_dereference_protected(slots[offset],
+                                             lock_is_held(&rcu_callback_map)))
+                       slots = mte_dead_walk(&enode, offset);
+               node = mte_to_node(enode);
        } while ((node != start) || (node->slot_len < offset));
 
        slots = ma_slots(node, node->type);
        mt_free_bulk(node->slot_len, slots);
 
-start_slots_free:
-       mas_unlock(&mas);
 free_leaf:
        mt_free_rcu(&node->rcu);
 }
 
-static inline void __rcu **mas_destroy_descend(struct ma_state *mas,
-                       struct maple_enode *prev, unsigned char offset)
+static inline void __rcu **mte_destroy_descend(struct maple_enode **enode,
+       struct maple_tree *mt, struct maple_enode *prev, unsigned char offset)
 {
        struct maple_node *node;
-       struct maple_enode *next = mas->node;
+       struct maple_enode *next = *enode;
        void __rcu **slots = NULL;
+       enum maple_type type;
+       unsigned char next_offset = 0;
 
        do {
-               mas->node = next;
-               node = mas_mn(mas);
-               slots = ma_slots(node, mte_node_type(mas->node));
-               next = mas_slot_locked(mas, slots, 0);
+               *enode = next;
+               node = mte_to_node(*enode);
+               type = mte_node_type(*enode);
+               slots = ma_slots(node, type);
+               next = mt_slot_locked(mt, slots, next_offset);
                if ((mte_dead_node(next)))
-                       next = mas_slot_locked(mas, slots, 1);
+                       next = mt_slot_locked(mt, slots, ++next_offset);
 
-               mte_set_node_dead(mas->node);
-               node->type = mte_node_type(mas->node);
+               mte_set_node_dead(*enode);
+               node->type = type;
                node->piv_parent = prev;
                node->parent_slot = offset;
-               offset = 0;
-               prev = mas->node;
+               offset = next_offset;
+               next_offset = 0;
+               prev = *enode;
        } while (!mte_is_leaf(next));
 
        return slots;
 }
 
-static void mt_destroy_walk(struct maple_enode *enode, unsigned char ma_flags,
+static void mt_destroy_walk(struct maple_enode *enode, struct maple_tree *mt,
                            bool free)
 {
        void __rcu **slots;
        struct maple_node *node = mte_to_node(enode);
        struct maple_enode *start;
-       struct maple_tree mt;
 
-       MA_STATE(mas, &mt, 0, 0);
-
-       if (mte_is_leaf(enode))
+       if (mte_is_leaf(enode)) {
+               node->type = mte_node_type(enode);
                goto free_leaf;
+       }
 
-       mt_init_flags(&mt, ma_flags);
-       mas_lock(&mas);
-
-       mas.node = start = enode;
-       slots = mas_destroy_descend(&mas, start, 0);
-       node = mas_mn(&mas);
+       start = enode;
+       slots = mte_destroy_descend(&enode, mt, start, 0);
+       node = mte_to_node(enode); // Updated in the above call.
        do {
                enum maple_type type;
                unsigned char offset;
                struct maple_enode *parent, *tmp;
 
-               node->slot_len = mas_dead_leaves(&mas, slots);
+               node->slot_len = mte_dead_leaves(enode, mt, slots);
                if (free)
                        mt_free_bulk(node->slot_len, slots);
                offset = node->parent_slot + 1;
-               mas.node = node->piv_parent;
-               if (mas_mn(&mas) == node)
-                       goto start_slots_free;
+               enode = node->piv_parent;
+               if (mte_to_node(enode) == node)
+                       goto free_leaf;
 
-               type = mte_node_type(mas.node);
-               slots = ma_slots(mte_to_node(mas.node), type);
+               type = mte_node_type(enode);
+               slots = ma_slots(mte_to_node(enode), type);
                if (offset >= mt_slots[type])
                        goto next;
 
-               tmp = mas_slot_locked(&mas, slots, offset);
+               tmp = mt_slot_locked(mt, slots, offset);
                if (mte_node_type(tmp) && mte_to_node(tmp)) {
-                       parent = mas.node;
-                       mas.node = tmp;
-                       slots = mas_destroy_descend(&mas, parent, offset);
+                       parent = enode;
+                       enode = tmp;
+                       slots = mte_destroy_descend(&enode, mt, parent, offset);
                }
 next:
-               node = mas_mn(&mas);
-       } while (start != mas.node);
+               node = mte_to_node(enode);
+       } while (start != enode);
 
-       node = mas_mn(&mas);
-       node->slot_len = mas_dead_leaves(&mas, slots);
+       node = mte_to_node(enode);
+       node->slot_len = mte_dead_leaves(enode, mt, slots);
        if (free)
                mt_free_bulk(node->slot_len, slots);
 
-start_slots_free:
-       mas_unlock(&mas);
-
 free_leaf:
        if (free)
                mt_free_rcu(&node->rcu);
+       else
+               mt_clear_meta(mt, node, node->type);
 }
 
 /*
@@ -5593,10 +5682,10 @@ static inline void mte_destroy_walk(struct maple_enode *enode,
        struct maple_node *node = mte_to_node(enode);
 
        if (mt_in_rcu(mt)) {
-               mt_destroy_walk(enode, mt->ma_flags, false);
+               mt_destroy_walk(enode, mt, false);
                call_rcu(&node->rcu, mt_free_walk);
        } else {
-               mt_destroy_walk(enode, mt->ma_flags, true);
+               mt_destroy_walk(enode, mt, true);
        }
 }
 
@@ -6617,11 +6706,11 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        while (likely(!ma_is_leaf(mt))) {
                MT_BUG_ON(mas->tree, mte_dead_node(mas->node));
                slots = ma_slots(mn, mt);
-               pivots = ma_pivots(mn, mt);
-               max = pivots[0];
                entry = mas_slot(mas, slots, 0);
+               pivots = ma_pivots(mn, mt);
                if (unlikely(ma_dead_node(mn)))
                        return NULL;
+               max = pivots[0];
                mas->node = entry;
                mn = mas_mn(mas);
                mt = mte_node_type(mas->node);
@@ -6641,13 +6730,13 @@ static inline void *mas_first_entry(struct ma_state *mas, struct maple_node *mn,
        if (likely(entry))
                return entry;
 
-       pivots = ma_pivots(mn, mt);
-       mas->index = pivots[0] + 1;
        mas->offset = 1;
        entry = mas_slot(mas, slots, 1);
+       pivots = ma_pivots(mn, mt);
        if (unlikely(ma_dead_node(mn)))
                return NULL;
 
+       mas->index = pivots[0] + 1;
        if (mas->index > limit)
                goto none;