KVM: arm64: Convert pkvm_mappings to interval tree
author Quentin Perret <qperret@google.com>
Wed, 21 May 2025 12:48:31 +0000 (13:48 +0100)
committer Marc Zyngier <maz@kernel.org>
Wed, 21 May 2025 13:33:51 +0000 (14:33 +0100)
In preparation for supporting stage-2 huge mappings for np-guest, let's
convert pgt.pkvm_mappings to an interval tree.

No functional change intended.

Suggested-by: Vincent Donnefort <vdonnefort@google.com>
Signed-off-by: Quentin Perret <qperret@google.com>
Signed-off-by: Vincent Donnefort <vdonnefort@google.com>
Link: https://lore.kernel.org/r/20250521124834.1070650-8-vdonnefort@google.com
Signed-off-by: Marc Zyngier <maz@kernel.org>
arch/arm64/include/asm/kvm_pgtable.h
arch/arm64/include/asm/kvm_pkvm.h
arch/arm64/kvm/pkvm.c
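
For readers unfamiliar with <linux/interval_tree_generic.h>: the
INTERVAL_TREE_DEFINE() invocation added below expands into a small set of
static helpers named after the chosen prefix. A rough sketch of the generated
signatures this patch relies on (simplified; the real expansion also keeps the
__subtree_last augmentation up to date across insertions, removals and
rebalancing):

    /* Generated for the "pkvm_mapping" prefix used in pkvm.c below. */
    static void pkvm_mapping_insert(struct pkvm_mapping *node,
                                    struct rb_root_cached *root);
    static void pkvm_mapping_remove(struct pkvm_mapping *node,
                                    struct rb_root_cached *root);
    /* First mapping overlapping [start, last]; both bounds are inclusive. */
    static struct pkvm_mapping *
    pkvm_mapping_iter_first(struct rb_root_cached *root, u64 start, u64 last);
    /* Next mapping after @node overlapping [start, last], or NULL. */
    static struct pkvm_mapping *
    pkvm_mapping_iter_next(struct pkvm_mapping *node, u64 start, u64 last);

The inclusive 'last' bound is why for_each_mapping_in_range_safe() below passes
__end - 1, and why __pkvm_mapping_end() returns the address of the last byte of
the page rather than one past it.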

index 6b9d274052c7a777729da58934c5da684dad1a65..1b43bcd2a679af29295ee07090216387003e9868 100644
@@ -413,7 +413,7 @@ static inline bool kvm_pgtable_walk_lock_held(void)
  */
 struct kvm_pgtable {
        union {
-               struct rb_root                                  pkvm_mappings;
+               struct rb_root_cached                           pkvm_mappings;
                struct {
                        u32                                     ia_bits;
                        s8                                      start_level;
index d91bfcf2db56db32d37aed1497aa84d1a22e7808..da75d41c948c1a86d962a09ca857ef7bf0dd0017 100644
@@ -173,6 +173,7 @@ struct pkvm_mapping {
        struct rb_node node;
        u64 gfn;
        u64 pfn;
+       u64 __subtree_last;     /* Internal member for interval tree */
 };
 
 int pkvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm_s2_mmu *mmu,
index f2c1d4c4e27ed19d83ab14bfa73fc612dea775af..0562da0249c331842a062e205de0fd063e103bb9 100644
@@ -5,6 +5,7 @@
  */
 
 #include <linux/init.h>
+#include <linux/interval_tree_generic.h>
 #include <linux/kmemleak.h>
 #include <linux/kvm_host.h>
 #include <asm/kvm_mmu.h>
@@ -275,80 +276,67 @@ static int __init finalize_pkvm(void)
 }
 device_initcall_sync(finalize_pkvm);
 
-static int cmp_mappings(struct rb_node *node, const struct rb_node *parent)
+static u64 __pkvm_mapping_start(struct pkvm_mapping *m)
 {
-       struct pkvm_mapping *a = rb_entry(node, struct pkvm_mapping, node);
-       struct pkvm_mapping *b = rb_entry(parent, struct pkvm_mapping, node);
-
-       if (a->gfn < b->gfn)
-               return -1;
-       if (a->gfn > b->gfn)
-               return 1;
-       return 0;
+       return m->gfn * PAGE_SIZE;
 }
 
-static struct rb_node *find_first_mapping_node(struct rb_root *root, u64 gfn)
+static u64 __pkvm_mapping_end(struct pkvm_mapping *m)
 {
-       struct rb_node *node = root->rb_node, *prev = NULL;
-       struct pkvm_mapping *mapping;
-
-       while (node) {
-               mapping = rb_entry(node, struct pkvm_mapping, node);
-               if (mapping->gfn == gfn)
-                       return node;
-               prev = node;
-               node = (gfn < mapping->gfn) ? node->rb_left : node->rb_right;
-       }
-
-       return prev;
+       return (m->gfn + 1) * PAGE_SIZE - 1;
 }
 
+INTERVAL_TREE_DEFINE(struct pkvm_mapping, node, u64, __subtree_last,
+                    __pkvm_mapping_start, __pkvm_mapping_end, static,
+                    pkvm_mapping);
+
 /*
- * __tmp is updated to rb_next(__tmp) *before* entering the body of the loop to allow freeing
- * of __map inline.
+ * __tmp is updated to iter_first(pkvm_mappings) *before* entering the body of the loop to allow
+ * freeing of __map inline.
  */
 #define for_each_mapping_in_range_safe(__pgt, __start, __end, __map)                           \
-       for (struct rb_node *__tmp = find_first_mapping_node(&(__pgt)->pkvm_mappings,           \
-                                                            ((__start) >> PAGE_SHIFT));        \
+       for (struct pkvm_mapping *__tmp = pkvm_mapping_iter_first(&(__pgt)->pkvm_mappings,      \
+                                                                 __start, __end - 1);          \
             __tmp && ({                                                                        \
-                               __map = rb_entry(__tmp, struct pkvm_mapping, node);             \
-                               __tmp = rb_next(__tmp);                                         \
+                               __map = __tmp;                                                  \
+                               __tmp = pkvm_mapping_iter_next(__map, __start, __end - 1);      \
                                true;                                                           \
                       });                                                                      \
-           )                                                                                   \
-               if (__map->gfn < ((__start) >> PAGE_SHIFT))                                     \
-                       continue;                                                               \
-               else if (__map->gfn >= ((__end) >> PAGE_SHIFT))                                 \
-                       break;                                                                  \
-               else
+           )
 
 int pkvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm_s2_mmu *mmu,
                             struct kvm_pgtable_mm_ops *mm_ops)
 {
-       pgt->pkvm_mappings      = RB_ROOT;
+       pgt->pkvm_mappings      = RB_ROOT_CACHED;
        pgt->mmu                = mmu;
 
        return 0;
 }
 
-void pkvm_pgtable_stage2_destroy(struct kvm_pgtable *pgt)
+static int __pkvm_pgtable_stage2_unmap(struct kvm_pgtable *pgt, u64 start, u64 end)
 {
        struct kvm *kvm = kvm_s2_mmu_to_kvm(pgt->mmu);
        pkvm_handle_t handle = kvm->arch.pkvm.handle;
        struct pkvm_mapping *mapping;
-       struct rb_node *node;
+       int ret;
 
        if (!handle)
-               return;
+               return 0;
 
-       node = rb_first(&pgt->pkvm_mappings);
-       while (node) {
-               mapping = rb_entry(node, struct pkvm_mapping, node);
-               kvm_call_hyp_nvhe(__pkvm_host_unshare_guest, handle, mapping->gfn);
-               node = rb_next(node);
-               rb_erase(&mapping->node, &pgt->pkvm_mappings);
+       for_each_mapping_in_range_safe(pgt, start, end, mapping) {
+               ret = kvm_call_hyp_nvhe(__pkvm_host_unshare_guest, handle, mapping->gfn, 1);
+               if (WARN_ON(ret))
+                       return ret;
+               pkvm_mapping_remove(mapping, &pgt->pkvm_mappings);
                kfree(mapping);
        }
+
+       return 0;
+}
+
+void pkvm_pgtable_stage2_destroy(struct kvm_pgtable *pgt)
+{
+       __pkvm_pgtable_stage2_unmap(pgt, 0, ~(0ULL));
 }
 
 int pkvm_pgtable_stage2_map(struct kvm_pgtable *pgt, u64 addr, u64 size,
@@ -376,28 +364,16 @@ int pkvm_pgtable_stage2_map(struct kvm_pgtable *pgt, u64 addr, u64 size,
        swap(mapping, cache->mapping);
        mapping->gfn = gfn;
        mapping->pfn = pfn;
-       WARN_ON(rb_find_add(&mapping->node, &pgt->pkvm_mappings, cmp_mappings));
+       pkvm_mapping_insert(mapping, &pgt->pkvm_mappings);
 
        return ret;
 }
 
 int pkvm_pgtable_stage2_unmap(struct kvm_pgtable *pgt, u64 addr, u64 size)
 {
-       struct kvm *kvm = kvm_s2_mmu_to_kvm(pgt->mmu);
-       pkvm_handle_t handle = kvm->arch.pkvm.handle;
-       struct pkvm_mapping *mapping;
-       int ret = 0;
-
-       lockdep_assert_held_write(&kvm->mmu_lock);
-       for_each_mapping_in_range_safe(pgt, addr, addr + size, mapping) {
-               ret = kvm_call_hyp_nvhe(__pkvm_host_unshare_guest, handle, mapping->gfn, 1);
-               if (WARN_ON(ret))
-                       break;
-               rb_erase(&mapping->node, &pgt->pkvm_mappings);
-               kfree(mapping);
-       }
+       lockdep_assert_held_write(&kvm_s2_mmu_to_kvm(pgt->mmu)->mmu_lock);
 
-       return ret;
+       return __pkvm_pgtable_stage2_unmap(pgt, addr, addr + size);
 }
 
 int pkvm_pgtable_stage2_wrprotect(struct kvm_pgtable *pgt, u64 addr, u64 size)