alloc_tag: allocate percpu counters for module tags dynamically
authorSuren Baghdasaryan <surenb@google.com>
Sat, 17 May 2025 00:07:39 +0000 (17:07 -0700)
committerAndrew Morton <akpm@linux-foundation.org>
Sun, 25 May 2025 07:53:48 +0000 (00:53 -0700)
When a module gets unloaded it checks whether any of its tags are still in
use and if so, we keep the memory containing module's allocation tags
alive until all tags are unused.  However percpu counters referenced by
the tags are freed by free_module().  This will lead to UAF if the memory
allocated by a module is accessed after module was unloaded.

To fix this we allocate percpu counters for module allocation tags
dynamically and we keep it alive for tags which are still in use after
module unloading.  This also removes the requirement of a larger
PERCPU_MODULE_RESERVE when memory allocation profiling is enabled because
percpu memory for counters does not need to be reserved anymore.

Link: https://lkml.kernel.org/r/20250517000739.5930-1-surenb@google.com
Fixes: 0db6f8d7820a ("alloc_tag: load module tags into separate contiguous memory")
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
Reported-by: David Wang <00107082@163.com>
Closes: https://lore.kernel.org/all/20250516131246.6244-1-00107082@163.com/
Tested-by: David Wang <00107082@163.com>
Cc: Christoph Lameter (Ampere) <cl@gentwo.org>
Cc: Dennis Zhou <dennis@kernel.org>
Cc: Kent Overstreet <kent.overstreet@linux.dev>
Cc: Pasha Tatashin <pasha.tatashin@soleen.com>
Cc: Tejun Heo <tj@kernel.org>
Cc: <stable@vger.kernel.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
include/linux/alloc_tag.h
include/linux/codetag.h
include/linux/percpu.h
lib/alloc_tag.c
lib/codetag.c

index a946e0203e6d600d451542da0e6cd90d576fb9b6..8f7931eb7d164c13b175921a4e020deb8a7ce259 100644 (file)
@@ -104,6 +104,16 @@ DECLARE_PER_CPU(struct alloc_tag_counters, _shared_alloc_tag);
 
 #else /* ARCH_NEEDS_WEAK_PER_CPU */
 
+#ifdef MODULE
+
+#define DEFINE_ALLOC_TAG(_alloc_tag)                                           \
+       static struct alloc_tag _alloc_tag __used __aligned(8)                  \
+       __section(ALLOC_TAG_SECTION_NAME) = {                                   \
+               .ct = CODE_TAG_INIT,                                            \
+               .counters = NULL };
+
+#else  /* MODULE */
+
 #define DEFINE_ALLOC_TAG(_alloc_tag)                                           \
        static DEFINE_PER_CPU(struct alloc_tag_counters, _alloc_tag_cntr);      \
        static struct alloc_tag _alloc_tag __used __aligned(8)                  \
@@ -111,6 +121,8 @@ DECLARE_PER_CPU(struct alloc_tag_counters, _shared_alloc_tag);
                .ct = CODE_TAG_INIT,                                            \
                .counters = &_alloc_tag_cntr };
 
+#endif /* MODULE */
+
 #endif /* ARCH_NEEDS_WEAK_PER_CPU */
 
 DECLARE_STATIC_KEY_MAYBE(CONFIG_MEM_ALLOC_PROFILING_ENABLED_BY_DEFAULT,
index d14dbd26b37085787b456acfb5a875ea6651af16..0ee4c21c6dbc7cb838f79d8f62255a3824167a59 100644 (file)
@@ -36,10 +36,10 @@ union codetag_ref {
 struct codetag_type_desc {
        const char *section;
        size_t tag_size;
-       void (*module_load)(struct codetag_type *cttype,
-                           struct codetag_module *cmod);
-       void (*module_unload)(struct codetag_type *cttype,
-                             struct codetag_module *cmod);
+       void (*module_load)(struct module *mod,
+                           struct codetag *start, struct codetag *end);
+       void (*module_unload)(struct module *mod,
+                             struct codetag *start, struct codetag *end);
 #ifdef CONFIG_MODULES
        void (*module_replaced)(struct module *mod, struct module *new_mod);
        bool (*needs_section_mem)(struct module *mod, unsigned long size);
index 52b5ea663b9f092a9628dc71009e1e9ffd3724bc..85bf8dd9f08740cb4eb2dbf1c8c961af753d70aa 100644 (file)
 
 /* enough to cover all DEFINE_PER_CPUs in modules */
 #ifdef CONFIG_MODULES
-#ifdef CONFIG_MEM_ALLOC_PROFILING
-#define PERCPU_MODULE_RESERVE          (8 << 13)
-#else
 #define PERCPU_MODULE_RESERVE          (8 << 10)
-#endif
 #else
 #define PERCPU_MODULE_RESERVE          0
 #endif
index 25ecc1334b67ddd238b86d2ec785d1b6b3b7ce85..c7f602fa7b23fce782bf89b8b84446b3816d678c 100644 (file)
@@ -350,18 +350,28 @@ static bool needs_section_mem(struct module *mod, unsigned long size)
        return size >= sizeof(struct alloc_tag);
 }
 
-static struct alloc_tag *find_used_tag(struct alloc_tag *from, struct alloc_tag *to)
+static bool clean_unused_counters(struct alloc_tag *start_tag,
+                                 struct alloc_tag *end_tag)
 {
-       while (from <= to) {
+       struct alloc_tag *tag;
+       bool ret = true;
+
+       for (tag = start_tag; tag <= end_tag; tag++) {
                struct alloc_tag_counters counter;
 
-               counter = alloc_tag_read(from);
-               if (counter.bytes)
-                       return from;
-               from++;
+               if (!tag->counters)
+                       continue;
+
+               counter = alloc_tag_read(tag);
+               if (!counter.bytes) {
+                       free_percpu(tag->counters);
+                       tag->counters = NULL;
+               } else {
+                       ret = false;
+               }
        }
 
-       return NULL;
+       return ret;
 }
 
 /* Called with mod_area_mt locked */
@@ -371,12 +381,16 @@ static void clean_unused_module_areas_locked(void)
        struct module *val;
 
        mas_for_each(&mas, val, module_tags.size) {
+               struct alloc_tag *start_tag;
+               struct alloc_tag *end_tag;
+
                if (val != &unloaded_mod)
                        continue;
 
                /* Release area if all tags are unused */
-               if (!find_used_tag((struct alloc_tag *)(module_tags.start_addr + mas.index),
-                                  (struct alloc_tag *)(module_tags.start_addr + mas.last)))
+               start_tag = (struct alloc_tag *)(module_tags.start_addr + mas.index);
+               end_tag = (struct alloc_tag *)(module_tags.start_addr + mas.last);
+               if (clean_unused_counters(start_tag, end_tag))
                        mas_erase(&mas);
        }
 }
@@ -561,7 +575,8 @@ unlock:
 static void release_module_tags(struct module *mod, bool used)
 {
        MA_STATE(mas, &mod_area_mt, module_tags.size, module_tags.size);
-       struct alloc_tag *tag;
+       struct alloc_tag *start_tag;
+       struct alloc_tag *end_tag;
        struct module *val;
 
        mas_lock(&mas);
@@ -575,15 +590,22 @@ static void release_module_tags(struct module *mod, bool used)
        if (!used)
                goto release_area;
 
-       /* Find out if the area is used */
-       tag = find_used_tag((struct alloc_tag *)(module_tags.start_addr + mas.index),
-                           (struct alloc_tag *)(module_tags.start_addr + mas.last));
-       if (tag) {
-               struct alloc_tag_counters counter = alloc_tag_read(tag);
+       start_tag = (struct alloc_tag *)(module_tags.start_addr + mas.index);
+       end_tag = (struct alloc_tag *)(module_tags.start_addr + mas.last);
+       if (!clean_unused_counters(start_tag, end_tag)) {
+               struct alloc_tag *tag;
+
+               for (tag = start_tag; tag <= end_tag; tag++) {
+                       struct alloc_tag_counters counter;
+
+                       if (!tag->counters)
+                               continue;
 
-               pr_info("%s:%u module %s func:%s has %llu allocated at module unload\n",
-                       tag->ct.filename, tag->ct.lineno, tag->ct.modname,
-                       tag->ct.function, counter.bytes);
+                       counter = alloc_tag_read(tag);
+                       pr_info("%s:%u module %s func:%s has %llu allocated at module unload\n",
+                               tag->ct.filename, tag->ct.lineno, tag->ct.modname,
+                               tag->ct.function, counter.bytes);
+               }
        } else {
                used = false;
        }
@@ -596,6 +618,34 @@ out:
        mas_unlock(&mas);
 }
 
+static void load_module(struct module *mod, struct codetag *start, struct codetag *stop)
+{
+       /* Allocate module alloc_tag percpu counters */
+       struct alloc_tag *start_tag;
+       struct alloc_tag *stop_tag;
+       struct alloc_tag *tag;
+
+       if (!mod)
+               return;
+
+       start_tag = ct_to_alloc_tag(start);
+       stop_tag = ct_to_alloc_tag(stop);
+       for (tag = start_tag; tag < stop_tag; tag++) {
+               WARN_ON(tag->counters);
+               tag->counters = alloc_percpu(struct alloc_tag_counters);
+               if (!tag->counters) {
+                       while (--tag >= start_tag) {
+                               free_percpu(tag->counters);
+                               tag->counters = NULL;
+                       }
+                       shutdown_mem_profiling(true);
+                       pr_err("Failed to allocate memory for allocation tag percpu counters in the module %s. Memory allocation profiling is disabled!\n",
+                              mod->name);
+                       break;
+               }
+       }
+}
+
 static void replace_module(struct module *mod, struct module *new_mod)
 {
        MA_STATE(mas, &mod_area_mt, 0, module_tags.size);
@@ -757,6 +807,7 @@ static int __init alloc_tag_init(void)
                .needs_section_mem      = needs_section_mem,
                .alloc_section_mem      = reserve_module_tags,
                .free_section_mem       = release_module_tags,
+               .module_load            = load_module,
                .module_replaced        = replace_module,
 #endif
        };
index 42aadd6c14549987d51f9760f148d688dabfe928..de332e98d6f5b5e6be53939bff2a8fd1d7862ff2 100644 (file)
@@ -194,7 +194,7 @@ static int codetag_module_init(struct codetag_type *cttype, struct module *mod)
        if (err >= 0) {
                cttype->count += range_size(cttype, &range);
                if (cttype->desc.module_load)
-                       cttype->desc.module_load(cttype, cmod);
+                       cttype->desc.module_load(mod, range.start, range.stop);
        }
        up_write(&cttype->mod_lock);
 
@@ -333,7 +333,8 @@ void codetag_unload_module(struct module *mod)
                }
                if (found) {
                        if (cttype->desc.module_unload)
-                               cttype->desc.module_unload(cttype, cmod);
+                               cttype->desc.module_unload(cmod->mod,
+                                       cmod->range.start, cmod->range.stop);
 
                        cttype->count -= range_size(cttype, &cmod->range);
                        idr_remove(&cttype->mod_idr, mod_id);