memcg: do obj_cgroup_put inside drain_obj_stock
authorShakeel Butt <shakeel.butt@linux.dev>
Fri, 4 Apr 2025 01:39:10 +0000 (18:39 -0700)
committerAndrew Morton <akpm@linux-foundation.org>
Mon, 12 May 2025 00:48:11 +0000 (17:48 -0700)
Previously we could not call obj_cgroup_put() inside the local lock
because on the put on the last reference, the release function
obj_cgroup_release() may try to re-acquire the local lock.  However that
chain has been broken.  Now simply do obj_cgroup_put() inside
drain_obj_stock() instead of returning the old objcg.

Link: https://lkml.kernel.org/r/20250404013913.1663035-7-shakeel.butt@linux.dev
Signed-off-by: Shakeel Butt <shakeel.butt@linux.dev>
Reviewed-by: Roman Gushchin <roman.gushchin@linux.dev>
Acked-by: Vlastimil Babka <vbabka@suse.cz>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Michal Hocko <mhocko@kernel.org>
Cc: Muchun Song <muchun.song@linux.dev>
Cc: Sebastian Andrzej Siewior <bigeasy@linutronix.de>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
mm/memcontrol.c

index df52084e90f464df17f483d96e5078411af45200..7988a42b29bfcb3120cba93ec51754e53cbfaac3 100644 (file)
@@ -1785,7 +1785,7 @@ static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock) = {
 };
 static DEFINE_MUTEX(percpu_charge_mutex);
 
-static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock);
+static void drain_obj_stock(struct memcg_stock_pcp *stock);
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
                                     struct mem_cgroup *root_memcg);
 
@@ -1859,7 +1859,6 @@ static void drain_stock(struct memcg_stock_pcp *stock)
 static void drain_local_stock(struct work_struct *dummy)
 {
        struct memcg_stock_pcp *stock;
-       struct obj_cgroup *old = NULL;
        unsigned long flags;
 
        /*
@@ -1870,12 +1869,11 @@ static void drain_local_stock(struct work_struct *dummy)
        local_lock_irqsave(&memcg_stock.stock_lock, flags);
 
        stock = this_cpu_ptr(&memcg_stock);
-       old = drain_obj_stock(stock);
+       drain_obj_stock(stock);
        drain_stock(stock);
        clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
-       obj_cgroup_put(old);
 }
 
 static void refill_stock(struct mem_cgroup *memcg, unsigned int nr_pages)
@@ -1958,18 +1956,16 @@ void drain_all_stock(struct mem_cgroup *root_memcg)
 static int memcg_hotplug_cpu_dead(unsigned int cpu)
 {
        struct memcg_stock_pcp *stock;
-       struct obj_cgroup *old;
        unsigned long flags;
 
        stock = &per_cpu(memcg_stock, cpu);
 
        /* drain_obj_stock requires stock_lock */
        local_lock_irqsave(&memcg_stock.stock_lock, flags);
-       old = drain_obj_stock(stock);
+       drain_obj_stock(stock);
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
 
        drain_stock(stock);
-       obj_cgroup_put(old);
 
        return 0;
 }
@@ -2766,24 +2762,20 @@ void __memcg_kmem_uncharge_page(struct page *page, int order)
 }
 
 /* Replace the stock objcg with objcg, return the old objcg */
-static struct obj_cgroup *replace_stock_objcg(struct memcg_stock_pcp *stock,
-                                            struct obj_cgroup *objcg)
+static void replace_stock_objcg(struct memcg_stock_pcp *stock,
+                               struct obj_cgroup *objcg)
 {
-       struct obj_cgroup *old = NULL;
-
-       old = drain_obj_stock(stock);
+       drain_obj_stock(stock);
        obj_cgroup_get(objcg);
        stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
                        ? atomic_xchg(&objcg->nr_charged_bytes, 0) : 0;
        WRITE_ONCE(stock->cached_objcg, objcg);
-       return old;
 }
 
 static void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
                     enum node_stat_item idx, int nr)
 {
        struct memcg_stock_pcp *stock;
-       struct obj_cgroup *old = NULL;
        unsigned long flags;
        int *bytes;
 
@@ -2796,7 +2788,7 @@ static void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
         * changes.
         */
        if (READ_ONCE(stock->cached_objcg) != objcg) {
-               old = replace_stock_objcg(stock, objcg);
+               replace_stock_objcg(stock, objcg);
                stock->cached_pgdat = pgdat;
        } else if (stock->cached_pgdat != pgdat) {
                /* Flush the existing cached vmstat data */
@@ -2837,7 +2829,6 @@ static void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
                __mod_objcg_mlstate(objcg, pgdat, idx, nr);
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
-       obj_cgroup_put(old);
 }
 
 static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
@@ -2859,12 +2850,12 @@ static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
        return ret;
 }
 
-static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
+static void drain_obj_stock(struct memcg_stock_pcp *stock)
 {
        struct obj_cgroup *old = READ_ONCE(stock->cached_objcg);
 
        if (!old)
-               return NULL;
+               return;
 
        if (stock->nr_bytes) {
                unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
@@ -2917,11 +2908,7 @@ static struct obj_cgroup *drain_obj_stock(struct memcg_stock_pcp *stock)
        }
 
        WRITE_ONCE(stock->cached_objcg, NULL);
-       /*
-        * The `old' objects needs to be released by the caller via
-        * obj_cgroup_put() outside of memcg_stock_pcp::stock_lock.
-        */
-       return old;
+       obj_cgroup_put(old);
 }
 
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
@@ -2943,7 +2930,6 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
                             bool allow_uncharge)
 {
        struct memcg_stock_pcp *stock;
-       struct obj_cgroup *old = NULL;
        unsigned long flags;
        unsigned int nr_pages = 0;
 
@@ -2951,7 +2937,7 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
 
        stock = this_cpu_ptr(&memcg_stock);
        if (READ_ONCE(stock->cached_objcg) != objcg) { /* reset if necessary */
-               old = replace_stock_objcg(stock, objcg);
+               replace_stock_objcg(stock, objcg);
                allow_uncharge = true;  /* Allow uncharge when objcg changes */
        }
        stock->nr_bytes += nr_bytes;
@@ -2962,7 +2948,6 @@ static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
        }
 
        local_unlock_irqrestore(&memcg_stock.stock_lock, flags);
-       obj_cgroup_put(old);
 
        if (nr_pages)
                obj_cgroup_uncharge_pages(objcg, nr_pages);