mm: kmem: scoped objcg protection
authorRoman Gushchin <roman.gushchin@linux.dev>
Thu, 19 Oct 2023 22:53:44 +0000 (15:53 -0700)
committerAndrew Morton <akpm@linux-foundation.org>
Wed, 25 Oct 2023 23:47:11 +0000 (16:47 -0700)
Switch to a scope-based protection of the objcg pointer on slab/kmem
allocation paths.  Instead of using the get_() semantics in the
pre-allocation hook and put the reference afterwards, let's rely on the
fact that objcg is pinned by the scope.

It's possible because:
1) if the objcg is received from the current task struct, the task is
   keeping a reference to the objcg.
2) if the objcg is received from an active memcg (remote charging),
   the memcg is pinned by the scope and has a reference to the
   corresponding objcg.

Link: https://lkml.kernel.org/r/20231019225346.1822282-5-roman.gushchin@linux.dev
Signed-off-by: Roman Gushchin (Cruise) <roman.gushchin@linux.dev>
Tested-by: Naresh Kamboju <naresh.kamboju@linaro.org>
Acked-by: Shakeel Butt <shakeelb@google.com>
Reviewed-by: Vlastimil Babka <vbabka@suse.cz>
Cc: David Rientjes <rientjes@google.com>
Cc: Dennis Zhou <dennis@kernel.org>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Michal Hocko <mhocko@kernel.org>
Cc: Muchun Song <muchun.song@linux.dev>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
include/linux/memcontrol.h
include/linux/sched/mm.h
mm/memcontrol.c
mm/slab.h

index cc110cc8fdfc3d284cec61245fc34cb8cc533e74..8006bc3bd7bf0adffdfc702ce962e1dbc72a328f 100644 (file)
@@ -1796,6 +1796,15 @@ bool mem_cgroup_kmem_disabled(void);
 int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order);
 void __memcg_kmem_uncharge_page(struct page *page, int order);
 
+/*
+ * The returned objcg pointer is safe to use without additional
+ * protection within a scope. The scope is defined either by
+ * the current task (similar to the "current" global variable)
+ * or by set_active_memcg() pair.
+ * Please, use obj_cgroup_get() to get a reference if the pointer
+ * needs to be used outside of the local scope.
+ */
+struct obj_cgroup *current_obj_cgroup(void);
 struct obj_cgroup *get_obj_cgroup_from_current(void);
 struct obj_cgroup *get_obj_cgroup_from_folio(struct folio *folio);
 
index 8d89c8c4fac1f2db1fc278478486aa71516b52bd..9a19f1b42f64129936dd763e1f9436536e60d47e 100644 (file)
@@ -403,6 +403,10 @@ DECLARE_PER_CPU(struct mem_cgroup *, int_active_memcg);
  * __GFP_ACCOUNT allocations till the end of the scope will be charged to the
  * given memcg.
  *
+ * Please, make sure that caller has a reference to the passed memcg structure,
+ * so its lifetime is guaranteed to exceed the scope between two
+ * set_active_memcg() calls.
+ *
  * NOTE: This function can nest. Users must save the return value and
  * reset the previous value after their own charging scope is over.
  */
index ff036d5d339d94e465fc8badfcef14e534dbdb09..a6457c8b5e16189cfa65939a54eca412aa1a9d02 100644 (file)
@@ -3170,6 +3170,49 @@ from_memcg:
        return objcg;
 }
 
+__always_inline struct obj_cgroup *current_obj_cgroup(void)
+{
+       struct mem_cgroup *memcg;
+       struct obj_cgroup *objcg;
+
+       if (in_task()) {
+               memcg = current->active_memcg;
+               if (unlikely(memcg))
+                       goto from_memcg;
+
+               objcg = READ_ONCE(current->objcg);
+               if (unlikely((unsigned long)objcg & CURRENT_OBJCG_UPDATE_FLAG))
+                       objcg = current_objcg_update();
+               /*
+                * Objcg reference is kept by the task, so it's safe
+                * to use the objcg by the current task.
+                */
+               return objcg;
+       }
+
+       memcg = this_cpu_read(int_active_memcg);
+       if (unlikely(memcg))
+               goto from_memcg;
+
+       return NULL;
+
+from_memcg:
+       for (; !mem_cgroup_is_root(memcg); memcg = parent_mem_cgroup(memcg)) {
+               /*
+                * Memcg pointer is protected by scope (see set_active_memcg())
+                * and is pinning the corresponding objcg, so objcg can't go
+                * away and can be used within the scope without any additional
+                * protection.
+                */
+               objcg = rcu_dereference_check(memcg->objcg, 1);
+               if (likely(objcg))
+                       break;
+               objcg = NULL;
+       }
+
+       return objcg;
+}
+
 struct obj_cgroup *get_obj_cgroup_from_folio(struct folio *folio)
 {
        struct obj_cgroup *objcg;
@@ -3264,15 +3307,15 @@ int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
        struct obj_cgroup *objcg;
        int ret = 0;
 
-       objcg = get_obj_cgroup_from_current();
+       objcg = current_obj_cgroup();
        if (objcg) {
                ret = obj_cgroup_charge_pages(objcg, gfp, 1 << order);
                if (!ret) {
+                       obj_cgroup_get(objcg);
                        page->memcg_data = (unsigned long)objcg |
                                MEMCG_DATA_KMEM;
                        return 0;
                }
-               obj_cgroup_put(objcg);
        }
        return ret;
 }
index 799a315695c6791c860f4b464c54552053e8e05a..3d07fb428393fe14a3adf7cc75940bd0c58431ee 100644 (file)
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -484,7 +484,12 @@ static inline bool memcg_slab_pre_alloc_hook(struct kmem_cache *s,
        if (!(flags & __GFP_ACCOUNT) && !(s->flags & SLAB_ACCOUNT))
                return true;
 
-       objcg = get_obj_cgroup_from_current();
+       /*
+        * The obtained objcg pointer is safe to use within the current scope,
+        * defined by current task or set_active_memcg() pair.
+        * obj_cgroup_get() is used to get a permanent reference.
+        */
+       objcg = current_obj_cgroup();
        if (!objcg)
                return true;
 
@@ -497,17 +502,14 @@ static inline bool memcg_slab_pre_alloc_hook(struct kmem_cache *s,
                css_put(&memcg->css);
 
                if (ret)
-                       goto out;
+                       return false;
        }
 
        if (obj_cgroup_charge(objcg, flags, objects * obj_full_size(s)))
-               goto out;
+               return false;
 
        *objcgp = objcg;
        return true;
-out:
-       obj_cgroup_put(objcg);
-       return false;
 }
 
 static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
@@ -542,7 +544,6 @@ static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
                        obj_cgroup_uncharge(objcg, obj_full_size(s));
                }
        }
-       obj_cgroup_put(objcg);
 }
 
 static inline void memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,