Merge branch 'for-5.14' of git://git.kernel.org/pub/scm/linux/kernel/git/dennis/percpu
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 1fa9b00ec71d90c56e12901f169f4c13bd9c1016..b80aae448a49da5bb71ca17822861decae362c4e 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -78,12 +78,13 @@ struct mem_cgroup *root_mem_cgroup __read_mostly;
 
 /* Active memory cgroup to use from an interrupt context */
 DEFINE_PER_CPU(struct mem_cgroup *, int_active_memcg);
+EXPORT_PER_CPU_SYMBOL_GPL(int_active_memcg);
 
 /* Socket memory accounting disabled? */
 static bool cgroup_memory_nosocket __ro_after_init;
 
 /* Kernel memory accounting disabled? */
-static bool cgroup_memory_nokmem __ro_after_init;
+bool cgroup_memory_nokmem __ro_after_init;
 
 /* Whether the swap controller is active */
 #ifdef CONFIG_MEMCG_SWAP
@@ -215,7 +216,7 @@ enum res_type {
 #define MEMFILE_PRIVATE(x, val)        ((x) << 16 | (val))
 #define MEMFILE_TYPE(val)      ((val) >> 16 & 0xffff)
 #define MEMFILE_ATTR(val)      ((val) & 0xffff)
-/* Used for OOM nofiier */
+/* Used for OOM notifier */
 #define OOM_CONTROL            (0)
 
 /*
@@ -260,15 +261,12 @@ bool mem_cgroup_kmem_disabled(void)
        return cgroup_memory_nokmem;
 }
 
-static int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
-                              unsigned int nr_pages);
-static void __memcg_kmem_uncharge(struct mem_cgroup *memcg,
-                                 unsigned int nr_pages);
+static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
+                                     unsigned int nr_pages);
 
 static void obj_cgroup_release(struct percpu_ref *ref)
 {
        struct obj_cgroup *objcg = container_of(ref, struct obj_cgroup, refcnt);
-       struct mem_cgroup *memcg;
        unsigned int nr_bytes;
        unsigned int nr_pages;
        unsigned long flags;
@@ -297,12 +295,11 @@ static void obj_cgroup_release(struct percpu_ref *ref)
        WARN_ON_ONCE(nr_bytes & (PAGE_SIZE - 1));
        nr_pages = nr_bytes >> PAGE_SHIFT;
 
-       spin_lock_irqsave(&css_set_lock, flags);
-       memcg = obj_cgroup_memcg(objcg);
        if (nr_pages)
-               __memcg_kmem_uncharge(memcg, nr_pages);
+               obj_cgroup_uncharge_pages(objcg, nr_pages);
+
+       spin_lock_irqsave(&css_set_lock, flags);
        list_del(&objcg->list);
-       mem_cgroup_put(memcg);
        spin_unlock_irqrestore(&css_set_lock, flags);
 
        percpu_ref_exit(ref);
@@ -337,17 +334,12 @@ static void memcg_reparent_objcgs(struct mem_cgroup *memcg,
 
        spin_lock_irq(&css_set_lock);
 
-       /* Move active objcg to the parent's list */
-       xchg(&objcg->memcg, parent);
-       css_get(&parent->css);
-       list_add(&objcg->list, &parent->objcg_list);
-
-       /* Move already reparented objcgs to the parent's list */
-       list_for_each_entry(iter, &memcg->objcg_list, list) {
-               css_get(&parent->css);
-               xchg(&iter->memcg, parent);
-               css_put(&memcg->css);
-       }
+       /* 1) Ready to reparent active objcg. */
+       list_add(&objcg->list, &memcg->objcg_list);
+       /* 2) Reparent active objcg and already reparented objcgs to parent. */
+       list_for_each_entry(iter, &memcg->objcg_list, list)
+               WRITE_ONCE(iter->memcg, parent);
+       /* 3) Move already reparented objcgs to the parent's list */
        list_splice(&memcg->objcg_list, &parent->objcg_list);
 
        spin_unlock_irq(&css_set_lock);
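To spell out the new ordering: if memcg C with active objcg O is offlined and P is its parent, step 1 links O into C's own objcg_list, step 2 rewrites iter->memcg to P for O and for every objcg previously reparented onto C, and step 3 splices the whole list onto P's objcg_list, all under css_set_lock. A concurrent obj_cgroup_memcg(O) under RCU may therefore observe either C or P, which is why readers such as get_mem_cgroup_from_objcg() added further down retry with css_tryget() until they pin a live css.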
@@ -407,129 +399,6 @@ DEFINE_STATIC_KEY_FALSE(memcg_kmem_enabled_key);
 EXPORT_SYMBOL(memcg_kmem_enabled_key);
 #endif
 
-static int memcg_shrinker_map_size;
-static DEFINE_MUTEX(memcg_shrinker_map_mutex);
-
-static void memcg_free_shrinker_map_rcu(struct rcu_head *head)
-{
-       kvfree(container_of(head, struct memcg_shrinker_map, rcu));
-}
-
-static int memcg_expand_one_shrinker_map(struct mem_cgroup *memcg,
-                                        int size, int old_size)
-{
-       struct memcg_shrinker_map *new, *old;
-       int nid;
-
-       lockdep_assert_held(&memcg_shrinker_map_mutex);
-
-       for_each_node(nid) {
-               old = rcu_dereference_protected(
-                       mem_cgroup_nodeinfo(memcg, nid)->shrinker_map, true);
-               /* Not yet online memcg */
-               if (!old)
-                       return 0;
-
-               new = kvmalloc_node(sizeof(*new) + size, GFP_KERNEL, nid);
-               if (!new)
-                       return -ENOMEM;
-
-               /* Set all old bits, clear all new bits */
-               memset(new->map, (int)0xff, old_size);
-               memset((void *)new->map + old_size, 0, size - old_size);
-
-               rcu_assign_pointer(memcg->nodeinfo[nid]->shrinker_map, new);
-               call_rcu(&old->rcu, memcg_free_shrinker_map_rcu);
-       }
-
-       return 0;
-}
-
-static void memcg_free_shrinker_maps(struct mem_cgroup *memcg)
-{
-       struct mem_cgroup_per_node *pn;
-       struct memcg_shrinker_map *map;
-       int nid;
-
-       if (mem_cgroup_is_root(memcg))
-               return;
-
-       for_each_node(nid) {
-               pn = mem_cgroup_nodeinfo(memcg, nid);
-               map = rcu_dereference_protected(pn->shrinker_map, true);
-               kvfree(map);
-               rcu_assign_pointer(pn->shrinker_map, NULL);
-       }
-}
-
-static int memcg_alloc_shrinker_maps(struct mem_cgroup *memcg)
-{
-       struct memcg_shrinker_map *map;
-       int nid, size, ret = 0;
-
-       if (mem_cgroup_is_root(memcg))
-               return 0;
-
-       mutex_lock(&memcg_shrinker_map_mutex);
-       size = memcg_shrinker_map_size;
-       for_each_node(nid) {
-               map = kvzalloc_node(sizeof(*map) + size, GFP_KERNEL, nid);
-               if (!map) {
-                       memcg_free_shrinker_maps(memcg);
-                       ret = -ENOMEM;
-                       break;
-               }
-               rcu_assign_pointer(memcg->nodeinfo[nid]->shrinker_map, map);
-       }
-       mutex_unlock(&memcg_shrinker_map_mutex);
-
-       return ret;
-}
-
-int memcg_expand_shrinker_maps(int new_id)
-{
-       int size, old_size, ret = 0;
-       struct mem_cgroup *memcg;
-
-       size = DIV_ROUND_UP(new_id + 1, BITS_PER_LONG) * sizeof(unsigned long);
-       old_size = memcg_shrinker_map_size;
-       if (size <= old_size)
-               return 0;
-
-       mutex_lock(&memcg_shrinker_map_mutex);
-       if (!root_mem_cgroup)
-               goto unlock;
-
-       for_each_mem_cgroup(memcg) {
-               if (mem_cgroup_is_root(memcg))
-                       continue;
-               ret = memcg_expand_one_shrinker_map(memcg, size, old_size);
-               if (ret) {
-                       mem_cgroup_iter_break(NULL, memcg);
-                       goto unlock;
-               }
-       }
-unlock:
-       if (!ret)
-               memcg_shrinker_map_size = size;
-       mutex_unlock(&memcg_shrinker_map_mutex);
-       return ret;
-}
-
-void memcg_set_shrinker_bit(struct mem_cgroup *memcg, int nid, int shrinker_id)
-{
-       if (shrinker_id >= 0 && memcg && !mem_cgroup_is_root(memcg)) {
-               struct memcg_shrinker_map *map;
-
-               rcu_read_lock();
-               map = rcu_dereference(memcg->nodeinfo[nid]->shrinker_map);
-               /* Pairs with smp mb in shrink_slab() */
-               smp_mb__before_atomic();
-               set_bit(shrinker_id, map->map);
-               rcu_read_unlock();
-       }
-}
-
 /**
  * mem_cgroup_css_from_page - css of the memcg associated with a page
  * @page: page of interest
@@ -718,7 +587,7 @@ static void mem_cgroup_remove_from_trees(struct mem_cgroup *memcg)
        int nid;
 
        for_each_node(nid) {
-               mz = mem_cgroup_nodeinfo(memcg, nid);
+               mz = memcg->nodeinfo[nid];
                mctz = soft_limit_tree_node(nid);
                if (mctz)
                        mem_cgroup_remove_exceeded(mz, mctz);
@@ -769,28 +638,37 @@ mem_cgroup_largest_soft_limit_node(struct mem_cgroup_tree_per_node *mctz)
  */
 void __mod_memcg_state(struct mem_cgroup *memcg, int idx, int val)
 {
-       long x, threshold = MEMCG_CHARGE_BATCH;
-
        if (mem_cgroup_disabled())
                return;
 
-       if (memcg_stat_item_in_bytes(idx))
-               threshold <<= PAGE_SHIFT;
+       __this_cpu_add(memcg->vmstats_percpu->state[idx], val);
+       cgroup_rstat_updated(memcg->css.cgroup, smp_processor_id());
+}
 
-       x = val + __this_cpu_read(memcg->vmstats_percpu->stat[idx]);
-       if (unlikely(abs(x) > threshold)) {
-               struct mem_cgroup *mi;
+/* idx can be of type enum memcg_stat_item or node_stat_item. */
+static unsigned long memcg_page_state(struct mem_cgroup *memcg, int idx)
+{
+       long x = READ_ONCE(memcg->vmstats.state[idx]);
+#ifdef CONFIG_SMP
+       if (x < 0)
+               x = 0;
+#endif
+       return x;
+}
 
-               /*
-                * Batch local counters to keep them in sync with
-                * the hierarchical ones.
-                */
-               __this_cpu_add(memcg->vmstats_local->stat[idx], x);
-               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                       atomic_long_add(x, &mi->vmstats[idx]);
+/* idx can be of type enum memcg_stat_item or node_stat_item. */
+static unsigned long memcg_page_state_local(struct mem_cgroup *memcg, int idx)
+{
+       long x = 0;
+       int cpu;
+
+       for_each_possible_cpu(cpu)
+               x += per_cpu(memcg->vmstats_percpu->state[idx], cpu);
+#ifdef CONFIG_SMP
+       if (x < 0)
                x = 0;
-       }
-       __this_cpu_write(memcg->vmstats_percpu->stat[idx], x);
+#endif
+       return x;
 }
 
 static struct mem_cgroup_per_node *
@@ -801,7 +679,7 @@ parent_nodeinfo(struct mem_cgroup_per_node *pn, int nid)
        parent = parent_mem_cgroup(pn->memcg);
        if (!parent)
                return NULL;
-       return mem_cgroup_nodeinfo(parent, nid);
+       return parent->nodeinfo[nid];
 }
 
 void __mod_memcg_lruvec_state(struct lruvec *lruvec, enum node_stat_item idx,
@@ -860,18 +738,22 @@ void __mod_lruvec_page_state(struct page *page, enum node_stat_item idx,
                             int val)
 {
        struct page *head = compound_head(page); /* rmap on tail pages */
-       struct mem_cgroup *memcg = page_memcg(head);
+       struct mem_cgroup *memcg;
        pg_data_t *pgdat = page_pgdat(page);
        struct lruvec *lruvec;
 
+       rcu_read_lock();
+       memcg = page_memcg(head);
        /* Untracked pages have no memcg, no lruvec. Update only the node */
        if (!memcg) {
+               rcu_read_unlock();
                __mod_node_page_state(pgdat, idx, val);
                return;
        }
 
        lruvec = mem_cgroup_lruvec(memcg, pgdat);
        __mod_lruvec_state(lruvec, idx, val);
+       rcu_read_unlock();
 }
 EXPORT_SYMBOL(__mod_lruvec_page_state);
 
@@ -899,39 +781,43 @@ void __mod_lruvec_kmem_state(void *p, enum node_stat_item idx, int val)
        rcu_read_unlock();
 }
 
+/*
+ * mod_objcg_mlstate() may be called with irq enabled, so
+ * mod_memcg_lruvec_state() should be used.
+ */
+static inline void mod_objcg_mlstate(struct obj_cgroup *objcg,
+                                    struct pglist_data *pgdat,
+                                    enum node_stat_item idx, int nr)
+{
+       struct mem_cgroup *memcg;
+       struct lruvec *lruvec;
+
+       rcu_read_lock();
+       memcg = obj_cgroup_memcg(objcg);
+       lruvec = mem_cgroup_lruvec(memcg, pgdat);
+       mod_memcg_lruvec_state(lruvec, idx, nr);
+       rcu_read_unlock();
+}
+
 /**
  * __count_memcg_events - account VM events in a cgroup
  * @memcg: the memory cgroup
  * @idx: the event item
- * @count: the number of events that occured
+ * @count: the number of events that occurred
  */
 void __count_memcg_events(struct mem_cgroup *memcg, enum vm_event_item idx,
                          unsigned long count)
 {
-       unsigned long x;
-
        if (mem_cgroup_disabled())
                return;
 
-       x = count + __this_cpu_read(memcg->vmstats_percpu->events[idx]);
-       if (unlikely(x > MEMCG_CHARGE_BATCH)) {
-               struct mem_cgroup *mi;
-
-               /*
-                * Batch local counters to keep them in sync with
-                * the hierarchical ones.
-                */
-               __this_cpu_add(memcg->vmstats_local->events[idx], x);
-               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                       atomic_long_add(x, &mi->vmevents[idx]);
-               x = 0;
-       }
-       __this_cpu_write(memcg->vmstats_percpu->events[idx], x);
+       __this_cpu_add(memcg->vmstats_percpu->events[idx], count);
+       cgroup_rstat_updated(memcg->css.cgroup, smp_processor_id());
 }
 
 static unsigned long memcg_events(struct mem_cgroup *memcg, int event)
 {
-       return atomic_long_read(&memcg->vmevents[event]);
+       return READ_ONCE(memcg->vmstats.events[event]);
 }
 
 static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
@@ -940,7 +826,7 @@ static unsigned long memcg_events_local(struct mem_cgroup *memcg, int event)
        int cpu;
 
        for_each_possible_cpu(cpu)
-               x += per_cpu(memcg->vmstats_local->events[event], cpu);
+               x += per_cpu(memcg->vmstats_percpu->events[event], cpu);
        return x;
 }
 
@@ -1017,13 +903,24 @@ struct mem_cgroup *mem_cgroup_from_task(struct task_struct *p)
 }
 EXPORT_SYMBOL(mem_cgroup_from_task);
 
+static __always_inline struct mem_cgroup *active_memcg(void)
+{
+       if (in_interrupt())
+               return this_cpu_read(int_active_memcg);
+       else
+               return current->active_memcg;
+}
+
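With int_active_memcg now exported above, modular code can scope an allocation to a remote memcg from task or interrupt context. A minimal sketch, assuming the caller already holds a reference on @memcg and uses set_active_memcg() from <linux/sched/mm.h> as the entry point (example_remote_kmalloc is a made-up name for illustration):

static void *example_remote_kmalloc(struct mem_cgroup *memcg, size_t size)
{
        struct mem_cgroup *old_memcg;
        void *p;

        old_memcg = set_active_memcg(memcg);
        /* Charged to @memcg via active_memcg(), not to current->mm->memcg. */
        p = kmalloc(size, GFP_KERNEL_ACCOUNT);
        set_active_memcg(old_memcg);

        return p;
}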
 /**
  * get_mem_cgroup_from_mm: Obtain a reference on given mm_struct's memcg.
  * @mm: mm from which memcg should be extracted. It can be NULL.
  *
- * Obtain a reference on mm->memcg and returns it if successful. Otherwise
- * root_mem_cgroup is returned. However if mem_cgroup is disabled, NULL is
- * returned.
+ * Obtain a reference on mm->memcg and returns it if successful. If mm
+ * is NULL, then the memcg is chosen as follows:
+ * 1) The active memcg, if set.
+ * 2) current->mm->memcg, if available
+ * 3) root memcg
+ * If mem_cgroup is disabled, NULL is returned.
  */
 struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
 {
@@ -1032,48 +929,38 @@ struct mem_cgroup *get_mem_cgroup_from_mm(struct mm_struct *mm)
        if (mem_cgroup_disabled())
                return NULL;
 
+       /*
+        * Page cache insertions can happen without an
+        * actual mm context, e.g. during disk probing
+        * on boot, loopback IO, acct() writes etc.
+        *
+        * No need to css_get on root memcg as the reference
+        * counting is disabled on the root level in the
+        * cgroup core. See CSS_NO_REF.
+        */
+       if (unlikely(!mm)) {
+               memcg = active_memcg();
+               if (unlikely(memcg)) {
+                       /* remote memcg must hold a ref */
+                       css_get(&memcg->css);
+                       return memcg;
+               }
+               mm = current->mm;
+               if (unlikely(!mm))
+                       return root_mem_cgroup;
+       }
+
        rcu_read_lock();
        do {
-               /*
-                * Page cache insertions can happen withou an
-                * actual mm context, e.g. during disk probing
-                * on boot, loopback IO, acct() writes etc.
-                */
-               if (unlikely(!mm))
+               memcg = mem_cgroup_from_task(rcu_dereference(mm->owner));
+               if (unlikely(!memcg))
                        memcg = root_mem_cgroup;
-               else {
-                       memcg = mem_cgroup_from_task(rcu_dereference(mm->owner));
-                       if (unlikely(!memcg))
-                               memcg = root_mem_cgroup;
-               }
        } while (!css_tryget(&memcg->css));
        rcu_read_unlock();
        return memcg;
 }
 EXPORT_SYMBOL(get_mem_cgroup_from_mm);
 
-static __always_inline struct mem_cgroup *active_memcg(void)
-{
-       if (in_interrupt())
-               return this_cpu_read(int_active_memcg);
-       else
-               return current->active_memcg;
-}
-
-static __always_inline struct mem_cgroup *get_active_memcg(void)
-{
-       struct mem_cgroup *memcg;
-
-       rcu_read_lock();
-       memcg = active_memcg();
-       /* remote memcg must hold a ref. */
-       if (memcg && WARN_ON_ONCE(!css_tryget(&memcg->css)))
-               memcg = root_mem_cgroup;
-       rcu_read_unlock();
-
-       return memcg;
-}
-
 static __always_inline bool memcg_kmem_bypass(void)
 {
        /* Allow remote memcg charging from any context. */
@@ -1087,20 +974,6 @@ static __always_inline bool memcg_kmem_bypass(void)
        return false;
 }
 
-/**
- * If active memcg is set, do not fallback to current->mm->memcg.
- */
-static __always_inline struct mem_cgroup *get_mem_cgroup_from_current(void)
-{
-       if (memcg_kmem_bypass())
-               return NULL;
-
-       if (unlikely(active_memcg()))
-               return get_active_memcg();
-
-       return get_mem_cgroup_from_mm(current->mm);
-}
-
 /**
  * mem_cgroup_iter - iterate over memory cgroup hierarchy
  * @root: hierarchy root
@@ -1141,7 +1014,7 @@ struct mem_cgroup *mem_cgroup_iter(struct mem_cgroup *root,
        if (reclaim) {
                struct mem_cgroup_per_node *mz;
 
-               mz = mem_cgroup_nodeinfo(root, reclaim->pgdat->node_id);
+               mz = root->nodeinfo[reclaim->pgdat->node_id];
                iter = &mz->iter;
 
                if (prev && reclaim->generation != iter->generation)
@@ -1243,7 +1116,7 @@ static void __invalidate_reclaim_iterators(struct mem_cgroup *from,
        int nid;
 
        for_each_node(nid) {
-               mz = mem_cgroup_nodeinfo(from, nid);
+               mz = from->nodeinfo[nid];
                iter = &mz->iter;
                cmpxchg(&iter->position, dead_memcg, NULL);
        }
@@ -1337,9 +1210,8 @@ void lruvec_memcg_debug(struct lruvec *lruvec, struct page *page)
 struct lruvec *lock_page_lruvec(struct page *page)
 {
        struct lruvec *lruvec;
-       struct pglist_data *pgdat = page_pgdat(page);
 
-       lruvec = mem_cgroup_page_lruvec(page, pgdat);
+       lruvec = mem_cgroup_page_lruvec(page);
        spin_lock(&lruvec->lru_lock);
 
        lruvec_memcg_debug(lruvec, page);
@@ -1350,9 +1222,8 @@ struct lruvec *lock_page_lruvec(struct page *page)
 struct lruvec *lock_page_lruvec_irq(struct page *page)
 {
        struct lruvec *lruvec;
-       struct pglist_data *pgdat = page_pgdat(page);
 
-       lruvec = mem_cgroup_page_lruvec(page, pgdat);
+       lruvec = mem_cgroup_page_lruvec(page);
        spin_lock_irq(&lruvec->lru_lock);
 
        lruvec_memcg_debug(lruvec, page);
@@ -1363,9 +1234,8 @@ struct lruvec *lock_page_lruvec_irq(struct page *page)
 struct lruvec *lock_page_lruvec_irqsave(struct page *page, unsigned long *flags)
 {
        struct lruvec *lruvec;
-       struct pglist_data *pgdat = page_pgdat(page);
 
-       lruvec = mem_cgroup_page_lruvec(page, pgdat);
+       lruvec = mem_cgroup_page_lruvec(page);
        spin_lock_irqsave(&lruvec->lru_lock, *flags);
 
        lruvec_memcg_debug(lruvec, page);
@@ -1576,6 +1446,7 @@ static char *memory_stat_format(struct mem_cgroup *memcg)
         *
         * Current memory state:
         */
+       cgroup_rstat_flush(memcg->css.cgroup);
 
        for (i = 0; i < ARRAY_SIZE(memory_stats); i++) {
                u64 size;
@@ -1870,7 +1741,7 @@ static void mem_cgroup_unmark_under_oom(struct mem_cgroup *memcg)
        struct mem_cgroup *iter;
 
        /*
-        * Be careful about under_oom underflows becase a child memcg
+        * Be careful about under_oom underflows because a child memcg
         * could have been added after mem_cgroup_mark_under_oom.
         */
        spin_lock(&memcg_oom_lock);
@@ -2042,7 +1913,7 @@ bool mem_cgroup_oom_synchronize(bool handle)
                /*
                 * There is no guarantee that an OOM-lock contender
                 * sees the wakeups triggered by the OOM kill
-                * uncharges.  Wake any sleepers explicitely.
+                * uncharges.  Wake any sleepers explicitly.
                 */
                memcg_oom_recover(memcg);
        }
@@ -2123,11 +1994,10 @@ void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
  * This function protects unlocked LRU pages from being moved to
  * another cgroup.
  *
- * It ensures lifetime of the returned memcg. Caller is responsible
- * for the lifetime of the page; __unlock_page_memcg() is available
- * when @page might get freed inside the locked section.
+ * It ensures lifetime of the locked memcg. Caller is responsible
+ * for the lifetime of the page.
  */
-struct mem_cgroup *lock_page_memcg(struct page *page)
+void lock_page_memcg(struct page *page)
 {
        struct page *head = compound_head(page); /* rmap on tail pages */
        struct mem_cgroup *memcg;
@@ -2137,21 +2007,15 @@ struct mem_cgroup *lock_page_memcg(struct page *page)
         * The RCU lock is held throughout the transaction.  The fast
         * path can get away without acquiring the memcg->move_lock
         * because page moving starts with an RCU grace period.
-        *
-        * The RCU lock also protects the memcg from being freed when
-        * the page state that is going to change is the only thing
-        * preventing the page itself from being freed. E.g. writeback
-        * doesn't hold a page reference and relies on PG_writeback to
-        * keep off truncation, migration and so forth.
          */
        rcu_read_lock();
 
        if (mem_cgroup_disabled())
-               return NULL;
+               return;
 again:
        memcg = page_memcg(head);
        if (unlikely(!memcg))
-               return NULL;
+               return;
 
 #ifdef CONFIG_PROVE_LOCKING
        local_irq_save(flags);
@@ -2160,7 +2024,7 @@ again:
 #endif
 
        if (atomic_read(&memcg->moving_account) <= 0)
-               return memcg;
+               return;
 
        spin_lock_irqsave(&memcg->move_lock, flags);
        if (memcg != page_memcg(head)) {
@@ -2169,24 +2033,17 @@ again:
        }
 
        /*
-        * When charge migration first begins, we can have locked and
-        * unlocked page stat updates happening concurrently.  Track
-        * the task who has the lock for unlock_page_memcg().
+        * When charge migration first begins, we can have multiple
+        * critical sections holding the fast-path RCU lock and one
+        * holding the slowpath move_lock. Track the task who has the
+        * move_lock for unlock_page_memcg().
         */
        memcg->move_lock_task = current;
        memcg->move_lock_flags = flags;
-
-       return memcg;
 }
 EXPORT_SYMBOL(lock_page_memcg);
 
-/**
- * __unlock_page_memcg - unlock and unpin a memcg
- * @memcg: the memcg
- *
- * Unlock and unpin a memcg returned by lock_page_memcg().
- */
-void __unlock_page_memcg(struct mem_cgroup *memcg)
+static void __unlock_page_memcg(struct mem_cgroup *memcg)
 {
        if (memcg && memcg->move_lock_task == current) {
                unsigned long flags = memcg->move_lock_flags;
@@ -2212,14 +2069,23 @@ void unlock_page_memcg(struct page *page)
 }
 EXPORT_SYMBOL(unlock_page_memcg);
 
-struct memcg_stock_pcp {
-       struct mem_cgroup *cached; /* this never be root cgroup */
-       unsigned int nr_pages;
-
+struct obj_stock {
 #ifdef CONFIG_MEMCG_KMEM
        struct obj_cgroup *cached_objcg;
+       struct pglist_data *cached_pgdat;
        unsigned int nr_bytes;
+       int nr_slab_reclaimable_b;
+       int nr_slab_unreclaimable_b;
+#else
+       int dummy[0];
 #endif
+};
+
+struct memcg_stock_pcp {
+       struct mem_cgroup *cached; /* this never be root cgroup */
+       unsigned int nr_pages;
+       struct obj_stock task_obj;
+       struct obj_stock irq_obj;
 
        struct work_struct work;
        unsigned long flags;
@@ -2229,12 +2095,12 @@ static DEFINE_PER_CPU(struct memcg_stock_pcp, memcg_stock);
 static DEFINE_MUTEX(percpu_charge_mutex);
 
 #ifdef CONFIG_MEMCG_KMEM
-static void drain_obj_stock(struct memcg_stock_pcp *stock);
+static void drain_obj_stock(struct obj_stock *stock);
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
                                     struct mem_cgroup *root_memcg);
 
 #else
-static inline void drain_obj_stock(struct memcg_stock_pcp *stock)
+static inline void drain_obj_stock(struct obj_stock *stock)
 {
 }
 static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
@@ -2244,6 +2110,41 @@ static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
 }
 #endif
 
+/*
+ * Most kmem_cache_alloc() calls are from user context. The irq disable/enable
+ * sequence used in this case to access content from object stock is slow.
+ * To optimize for user context access, there are now two object stocks for
+ * task context and interrupt context access respectively.
+ *
+ * The task context object stock can be accessed by disabling preemption only
+ * which is cheap in non-preempt kernel. The interrupt context object stock
+ * can only be accessed after disabling interrupt. User context code can
+ * access interrupt object stock, but not vice versa.
+ */
+static inline struct obj_stock *get_obj_stock(unsigned long *pflags)
+{
+       struct memcg_stock_pcp *stock;
+
+       if (likely(in_task())) {
+               *pflags = 0UL;
+               preempt_disable();
+               stock = this_cpu_ptr(&memcg_stock);
+               return &stock->task_obj;
+       }
+
+       local_irq_save(*pflags);
+       stock = this_cpu_ptr(&memcg_stock);
+       return &stock->irq_obj;
+}
+
+static inline void put_obj_stock(unsigned long flags)
+{
+       if (likely(in_task()))
+               preempt_enable();
+       else
+               local_irq_restore(flags);
+}
+
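A caller of the pair above looks roughly like the consume/refill paths rewritten later in this patch; the returned stock pointer is only valid between get_obj_stock() and put_obj_stock(), since that window is what keeps the task on its CPU (sketch only, example_update_stock is not a real helper):

static void example_update_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
{
        unsigned long flags;
        struct obj_stock *stock = get_obj_stock(&flags);

        /* task context: preemption is off; irq context: interrupts are off */
        if (stock->cached_objcg == objcg)
                stock->nr_bytes += nr_bytes;

        put_obj_stock(flags);
}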
 /**
  * consume_stock: Try to consume stocked charge on this cpu.
  * @memcg: memcg to consume from.
@@ -2310,7 +2211,9 @@ static void drain_local_stock(struct work_struct *dummy)
        local_irq_save(flags);
 
        stock = this_cpu_ptr(&memcg_stock);
-       drain_obj_stock(stock);
+       drain_obj_stock(&stock->irq_obj);
+       if (in_task())
+               drain_obj_stock(&stock->task_obj);
        drain_stock(stock);
        clear_bit(FLUSHING_CACHED_CHARGE, &stock->flags);
 
@@ -2386,50 +2289,39 @@ static void drain_all_stock(struct mem_cgroup *root_memcg)
        mutex_unlock(&percpu_charge_mutex);
 }
 
-static int memcg_hotplug_cpu_dead(unsigned int cpu)
+static void memcg_flush_lruvec_page_state(struct mem_cgroup *memcg, int cpu)
 {
-       struct memcg_stock_pcp *stock;
-       struct mem_cgroup *memcg, *mi;
-
-       stock = &per_cpu(memcg_stock, cpu);
-       drain_stock(stock);
+       int nid;
 
-       for_each_mem_cgroup(memcg) {
+       for_each_node(nid) {
+               struct mem_cgroup_per_node *pn = memcg->nodeinfo[nid];
+               unsigned long stat[NR_VM_NODE_STAT_ITEMS];
+               struct batched_lruvec_stat *lstatc;
                int i;
 
-               for (i = 0; i < MEMCG_NR_STAT; i++) {
-                       int nid;
-                       long x;
-
-                       x = this_cpu_xchg(memcg->vmstats_percpu->stat[i], 0);
-                       if (x)
-                               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                                       atomic_long_add(x, &memcg->vmstats[i]);
-
-                       if (i >= NR_VM_NODE_STAT_ITEMS)
-                               continue;
+               lstatc = per_cpu_ptr(pn->lruvec_stat_cpu, cpu);
+               for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++) {
+                       stat[i] = lstatc->count[i];
+                       lstatc->count[i] = 0;
+               }
 
-                       for_each_node(nid) {
-                               struct mem_cgroup_per_node *pn;
+               do {
+                       for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
+                               atomic_long_add(stat[i], &pn->lruvec_stat[i]);
+               } while ((pn = parent_nodeinfo(pn, nid)));
+       }
+}
 
-                               pn = mem_cgroup_nodeinfo(memcg, nid);
-                               x = this_cpu_xchg(pn->lruvec_stat_cpu->count[i], 0);
-                               if (x)
-                                       do {
-                                               atomic_long_add(x, &pn->lruvec_stat[i]);
-                                       } while ((pn = parent_nodeinfo(pn, nid)));
-                       }
-               }
+static int memcg_hotplug_cpu_dead(unsigned int cpu)
+{
+       struct memcg_stock_pcp *stock;
+       struct mem_cgroup *memcg;
 
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++) {
-                       long x;
+       stock = &per_cpu(memcg_stock, cpu);
+       drain_stock(stock);
 
-                       x = this_cpu_xchg(memcg->vmstats_percpu->events[i], 0);
-                       if (x)
-                               for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-                                       atomic_long_add(x, &memcg->vmevents[i]);
-               }
-       }
+       for_each_mem_cgroup(memcg)
+               memcg_flush_lruvec_page_state(memcg, cpu);
 
        return 0;
 }
@@ -2687,8 +2579,8 @@ out:
        css_put(&memcg->css);
 }
 
-static int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
-                     unsigned int nr_pages)
+static int try_charge_memcg(struct mem_cgroup *memcg, gfp_t gfp_mask,
+                       unsigned int nr_pages)
 {
        unsigned int batch = max(MEMCG_CHARGE_BATCH, nr_pages);
        int nr_retries = MAX_RECLAIM_RETRIES;
@@ -2700,8 +2592,6 @@ static int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
        bool drained = false;
        unsigned long pflags;
 
-       if (mem_cgroup_is_root(memcg))
-               return 0;
 retry:
        if (consume_stock(memcg, nr_pages))
                return 0;
@@ -2798,9 +2688,6 @@ retry:
        if (gfp_mask & __GFP_RETRY_MAYFAIL)
                goto nomem;
 
-       if (gfp_mask & __GFP_NOFAIL)
-               goto force;
-
        if (fatal_signal_pending(current))
                goto force;
 
@@ -2884,6 +2771,15 @@ done_restock:
        return 0;
 }
 
+static inline int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
+                            unsigned int nr_pages)
+{
+       if (mem_cgroup_is_root(memcg))
+               return 0;
+
+       return try_charge_memcg(memcg, gfp_mask, nr_pages);
+}
+
 #if defined(CONFIG_MEMCG_KMEM) || defined(CONFIG_MMU)
 static void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
@@ -2910,7 +2806,28 @@ static void commit_charge(struct page *page, struct mem_cgroup *memcg)
        page->memcg_data = (unsigned long)memcg;
 }
 
+static struct mem_cgroup *get_mem_cgroup_from_objcg(struct obj_cgroup *objcg)
+{
+       struct mem_cgroup *memcg;
+
+       rcu_read_lock();
+retry:
+       memcg = obj_cgroup_memcg(objcg);
+       if (unlikely(!css_tryget(&memcg->css)))
+               goto retry;
+       rcu_read_unlock();
+
+       return memcg;
+}
+
 #ifdef CONFIG_MEMCG_KMEM
+/*
+ * The allocated objcg pointers array is not accounted directly.
+ * Moreover, it should not come from DMA buffer and is not readily
+ * reclaimable. So those GFP bits should be masked off.
+ */
+#define OBJCGS_CLEAR_MASK      (__GFP_DMA | __GFP_RECLAIMABLE | __GFP_ACCOUNT)
+
 int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
                                 gfp_t gfp, bool new_page)
 {
@@ -2918,6 +2835,7 @@ int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
        unsigned long memcg_data;
        void *vec;
 
+       gfp &= ~OBJCGS_CLEAR_MASK;
        vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
                           page_to_nid(page));
        if (!vec)
@@ -3061,23 +2979,45 @@ static void memcg_free_cache_id(int id)
        ida_simple_remove(&memcg_cache_ida, id);
 }
 
-/**
- * __memcg_kmem_charge: charge a number of kernel pages to a memcg
- * @memcg: memory cgroup to charge
+/*
+ * obj_cgroup_uncharge_pages: uncharge a number of kernel pages from an objcg
+ * @objcg: object cgroup to uncharge
+ * @nr_pages: number of pages to uncharge
+ */
+static void obj_cgroup_uncharge_pages(struct obj_cgroup *objcg,
+                                     unsigned int nr_pages)
+{
+       struct mem_cgroup *memcg;
+
+       memcg = get_mem_cgroup_from_objcg(objcg);
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               page_counter_uncharge(&memcg->kmem, nr_pages);
+       refill_stock(memcg, nr_pages);
+
+       css_put(&memcg->css);
+}
+
+/*
+ * obj_cgroup_charge_pages: charge a number of kernel pages to an objcg
+ * @objcg: object cgroup to charge
  * @gfp: reclaim mode
  * @nr_pages: number of pages to charge
  *
  * Returns 0 on success, an error code on failure.
  */
-static int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
-                              unsigned int nr_pages)
+static int obj_cgroup_charge_pages(struct obj_cgroup *objcg, gfp_t gfp,
+                                  unsigned int nr_pages)
 {
        struct page_counter *counter;
+       struct mem_cgroup *memcg;
        int ret;
 
-       ret = try_charge(memcg, gfp, nr_pages);
+       memcg = get_mem_cgroup_from_objcg(objcg);
+
+       ret = try_charge_memcg(memcg, gfp, nr_pages);
        if (ret)
-               return ret;
+               goto out;
 
        if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) &&
            !page_counter_try_charge(&memcg->kmem, nr_pages, &counter)) {
@@ -3089,25 +3029,15 @@ static int __memcg_kmem_charge(struct mem_cgroup *memcg, gfp_t gfp,
                 */
                if (gfp & __GFP_NOFAIL) {
                        page_counter_charge(&memcg->kmem, nr_pages);
-                       return 0;
+                       goto out;
                }
                cancel_charge(memcg, nr_pages);
-               return -ENOMEM;
+               ret = -ENOMEM;
        }
-       return 0;
-}
-
-/**
- * __memcg_kmem_uncharge: uncharge a number of kernel pages from a memcg
- * @memcg: memcg to uncharge
- * @nr_pages: number of pages to uncharge
- */
-static void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_pages)
-{
-       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
-               page_counter_uncharge(&memcg->kmem, nr_pages);
+out:
+       css_put(&memcg->css);
 
-       refill_stock(memcg, nr_pages);
+       return ret;
 }
 
 /**
@@ -3120,18 +3050,18 @@ static void __memcg_kmem_uncharge(struct mem_cgroup *memcg, unsigned int nr_page
  */
 int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
 {
-       struct mem_cgroup *memcg;
+       struct obj_cgroup *objcg;
        int ret = 0;
 
-       memcg = get_mem_cgroup_from_current();
-       if (memcg && !mem_cgroup_is_root(memcg)) {
-               ret = __memcg_kmem_charge(memcg, gfp, 1 << order);
+       objcg = get_obj_cgroup_from_current();
+       if (objcg) {
+               ret = obj_cgroup_charge_pages(objcg, gfp, 1 << order);
                if (!ret) {
-                       page->memcg_data = (unsigned long)memcg |
+                       page->memcg_data = (unsigned long)objcg |
                                MEMCG_DATA_KMEM;
                        return 0;
                }
-               css_put(&memcg->css);
+               obj_cgroup_put(objcg);
        }
        return ret;
 }
@@ -3143,38 +3073,93 @@ int __memcg_kmem_charge_page(struct page *page, gfp_t gfp, int order)
  */
 void __memcg_kmem_uncharge_page(struct page *page, int order)
 {
-       struct mem_cgroup *memcg = page_memcg(page);
+       struct obj_cgroup *objcg;
        unsigned int nr_pages = 1 << order;
 
-       if (!memcg)
+       if (!PageMemcgKmem(page))
                return;
 
-       VM_BUG_ON_PAGE(mem_cgroup_is_root(memcg), page);
-       __memcg_kmem_uncharge(memcg, nr_pages);
+       objcg = __page_objcg(page);
+       obj_cgroup_uncharge_pages(objcg, nr_pages);
        page->memcg_data = 0;
-       css_put(&memcg->css);
+       obj_cgroup_put(objcg);
+}
+
+void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
+                    enum node_stat_item idx, int nr)
+{
+       unsigned long flags;
+       struct obj_stock *stock = get_obj_stock(&flags);
+       int *bytes;
+
+       /*
+        * Save vmstat data in stock and skip vmstat array update unless
+        * accumulating over a page of vmstat data or when pgdat or idx
+        * changes.
+        */
+       if (stock->cached_objcg != objcg) {
+               drain_obj_stock(stock);
+               obj_cgroup_get(objcg);
+               stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
+                               ? atomic_xchg(&objcg->nr_charged_bytes, 0) : 0;
+               stock->cached_objcg = objcg;
+               stock->cached_pgdat = pgdat;
+       } else if (stock->cached_pgdat != pgdat) {
+               /* Flush the existing cached vmstat data */
+               if (stock->nr_slab_reclaimable_b) {
+                       mod_objcg_mlstate(objcg, pgdat, NR_SLAB_RECLAIMABLE_B,
+                                         stock->nr_slab_reclaimable_b);
+                       stock->nr_slab_reclaimable_b = 0;
+               }
+               if (stock->nr_slab_unreclaimable_b) {
+                       mod_objcg_mlstate(objcg, pgdat, NR_SLAB_UNRECLAIMABLE_B,
+                                         stock->nr_slab_unreclaimable_b);
+                       stock->nr_slab_unreclaimable_b = 0;
+               }
+               stock->cached_pgdat = pgdat;
+       }
+
+       bytes = (idx == NR_SLAB_RECLAIMABLE_B) ? &stock->nr_slab_reclaimable_b
+                                              : &stock->nr_slab_unreclaimable_b;
+       /*
+        * Even for large object >= PAGE_SIZE, the vmstat data will still be
+        * cached locally at least once before pushing it out.
+        */
+       if (!*bytes) {
+               *bytes = nr;
+               nr = 0;
+       } else {
+               *bytes += nr;
+               if (abs(*bytes) > PAGE_SIZE) {
+                       nr = *bytes;
+                       *bytes = 0;
+               } else {
+                       nr = 0;
+               }
+       }
+       if (nr)
+               mod_objcg_mlstate(objcg, pgdat, idx, nr);
+
+       put_obj_stock(flags);
 }
 
 static bool consume_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
 {
-       struct memcg_stock_pcp *stock;
        unsigned long flags;
+       struct obj_stock *stock = get_obj_stock(&flags);
        bool ret = false;
 
-       local_irq_save(flags);
-
-       stock = this_cpu_ptr(&memcg_stock);
        if (objcg == stock->cached_objcg && stock->nr_bytes >= nr_bytes) {
                stock->nr_bytes -= nr_bytes;
                ret = true;
        }
 
-       local_irq_restore(flags);
+       put_obj_stock(flags);
 
        return ret;
 }
 
-static void drain_obj_stock(struct memcg_stock_pcp *stock)
+static void drain_obj_stock(struct obj_stock *stock)
 {
        struct obj_cgroup *old = stock->cached_objcg;
 
@@ -3185,11 +3170,8 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
                unsigned int nr_pages = stock->nr_bytes >> PAGE_SHIFT;
                unsigned int nr_bytes = stock->nr_bytes & (PAGE_SIZE - 1);
 
-               if (nr_pages) {
-                       rcu_read_lock();
-                       __memcg_kmem_uncharge(obj_cgroup_memcg(old), nr_pages);
-                       rcu_read_unlock();
-               }
+               if (nr_pages)
+                       obj_cgroup_uncharge_pages(old, nr_pages);
 
                /*
                 * The leftover is flushed to the centralized per-memcg value.
@@ -3205,6 +3187,25 @@ static void drain_obj_stock(struct memcg_stock_pcp *stock)
                stock->nr_bytes = 0;
        }
 
+       /*
+        * Flush the vmstat data in current stock
+        */
+       if (stock->nr_slab_reclaimable_b || stock->nr_slab_unreclaimable_b) {
+               if (stock->nr_slab_reclaimable_b) {
+                       mod_objcg_mlstate(old, stock->cached_pgdat,
+                                         NR_SLAB_RECLAIMABLE_B,
+                                         stock->nr_slab_reclaimable_b);
+                       stock->nr_slab_reclaimable_b = 0;
+               }
+               if (stock->nr_slab_unreclaimable_b) {
+                       mod_objcg_mlstate(old, stock->cached_pgdat,
+                                         NR_SLAB_UNRECLAIMABLE_B,
+                                         stock->nr_slab_unreclaimable_b);
+                       stock->nr_slab_unreclaimable_b = 0;
+               }
+               stock->cached_pgdat = NULL;
+       }
+
        obj_cgroup_put(old);
        stock->cached_objcg = NULL;
 }
@@ -3214,8 +3215,13 @@ static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
 {
        struct mem_cgroup *memcg;
 
-       if (stock->cached_objcg) {
-               memcg = obj_cgroup_memcg(stock->cached_objcg);
+       if (in_task() && stock->task_obj.cached_objcg) {
+               memcg = obj_cgroup_memcg(stock->task_obj.cached_objcg);
+               if (memcg && mem_cgroup_is_descendant(memcg, root_memcg))
+                       return true;
+       }
+       if (stock->irq_obj.cached_objcg) {
+               memcg = obj_cgroup_memcg(stock->irq_obj.cached_objcg);
                if (memcg && mem_cgroup_is_descendant(memcg, root_memcg))
                        return true;
        }
@@ -3223,31 +3229,36 @@ static bool obj_stock_flush_required(struct memcg_stock_pcp *stock,
        return false;
 }
 
-static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes)
+static void refill_obj_stock(struct obj_cgroup *objcg, unsigned int nr_bytes,
+                            bool allow_uncharge)
 {
-       struct memcg_stock_pcp *stock;
        unsigned long flags;
+       struct obj_stock *stock = get_obj_stock(&flags);
+       unsigned int nr_pages = 0;
 
-       local_irq_save(flags);
-
-       stock = this_cpu_ptr(&memcg_stock);
        if (stock->cached_objcg != objcg) { /* reset if necessary */
                drain_obj_stock(stock);
                obj_cgroup_get(objcg);
                stock->cached_objcg = objcg;
-               stock->nr_bytes = atomic_xchg(&objcg->nr_charged_bytes, 0);
+               stock->nr_bytes = atomic_read(&objcg->nr_charged_bytes)
+                               ? atomic_xchg(&objcg->nr_charged_bytes, 0) : 0;
+               allow_uncharge = true;  /* Allow uncharge when objcg changes */
        }
        stock->nr_bytes += nr_bytes;
 
-       if (stock->nr_bytes > PAGE_SIZE)
-               drain_obj_stock(stock);
+       if (allow_uncharge && (stock->nr_bytes > PAGE_SIZE)) {
+               nr_pages = stock->nr_bytes >> PAGE_SHIFT;
+               stock->nr_bytes &= (PAGE_SIZE - 1);
+       }
 
-       local_irq_restore(flags);
+       put_obj_stock(flags);
+
+       if (nr_pages)
+               obj_cgroup_uncharge_pages(objcg, nr_pages);
 }
 
 int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
 {
-       struct mem_cgroup *memcg;
        unsigned int nr_pages, nr_bytes;
        int ret;
 
@@ -3255,39 +3266,44 @@ int obj_cgroup_charge(struct obj_cgroup *objcg, gfp_t gfp, size_t size)
                return 0;
 
        /*
-        * In theory, memcg->nr_charged_bytes can have enough
+        * In theory, objcg->nr_charged_bytes can have enough
         * pre-charged bytes to satisfy the allocation. However,
-        * flushing memcg->nr_charged_bytes requires two atomic
-        * operations, and memcg->nr_charged_bytes can't be big,
-        * so it's better to ignore it and try grab some new pages.
-        * memcg->nr_charged_bytes will be flushed in
-        * refill_obj_stock(), called from this function or
-        * independently later.
+        * flushing objcg->nr_charged_bytes requires two atomic
+        * operations, and objcg->nr_charged_bytes can't be big.
+        * The shared objcg->nr_charged_bytes can also become a
+        * performance bottleneck if all tasks of the same memcg are
+        * trying to update it. So it's better to ignore it and try
+        * grab some new pages. The stock's nr_bytes will be flushed to
+        * objcg->nr_charged_bytes later on when objcg changes.
+        *
+        * The stock's nr_bytes may contain enough pre-charged bytes
+        * to allow one less page from being charged, but we can't rely
+        * on the pre-charged bytes not being changed outside of
+        * consume_obj_stock() or refill_obj_stock(). So ignore those
+        * pre-charged bytes as well when charging pages. To avoid a
+        * page uncharge right after a page charge, we set the
+        * allow_uncharge flag to false when calling refill_obj_stock()
+        * to temporarily allow the pre-charged bytes to exceed the page
+        * size limit. The maximum reachable value of the pre-charged
+        * bytes is (sizeof(object) + PAGE_SIZE - 2) if there is no data
+        * race.
         */
-       rcu_read_lock();
-retry:
-       memcg = obj_cgroup_memcg(objcg);
-       if (unlikely(!css_tryget(&memcg->css)))
-               goto retry;
-       rcu_read_unlock();
-
        nr_pages = size >> PAGE_SHIFT;
        nr_bytes = size & (PAGE_SIZE - 1);
 
        if (nr_bytes)
                nr_pages += 1;
 
-       ret = __memcg_kmem_charge(memcg, gfp, nr_pages);
+       ret = obj_cgroup_charge_pages(objcg, gfp, nr_pages);
        if (!ret && nr_bytes)
-               refill_obj_stock(objcg, PAGE_SIZE - nr_bytes);
+               refill_obj_stock(objcg, PAGE_SIZE - nr_bytes, false);
 
-       css_put(&memcg->css);
        return ret;
 }
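To make the byte accounting above concrete (4K pages assumed): charging a 200-byte object yields nr_pages = 0 and nr_bytes = 200, so one full page is charged through obj_cgroup_charge_pages() and the leftover PAGE_SIZE - 200 = 3896 bytes are returned to the per-cpu stock with allow_uncharge == false, avoiding the immediate charge/uncharge ping-pong the comment warns about; whole pages are only given back later by a refill with allow_uncharge == true (i.e. obj_cgroup_uncharge()) or when the stock is drained.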
 
 void obj_cgroup_uncharge(struct obj_cgroup *objcg, size_t size)
 {
-       refill_obj_stock(objcg, size);
+       refill_obj_stock(objcg, size, true);
 }
 
 #endif /* CONFIG_MEMCG_KMEM */
@@ -3305,7 +3321,11 @@ void split_page_memcg(struct page *head, unsigned int nr)
 
        for (i = 1; i < nr; i++)
                head[i].memcg_data = head->memcg_data;
-       css_get_many(&memcg->css, nr - 1);
+
+       if (PageMemcgKmem(head))
+               obj_cgroup_get_many(__page_objcg(head), nr - 1);
+       else
+               css_get_many(&memcg->css, nr - 1);
 }
 
 #ifdef CONFIG_MEMCG_SWAP
@@ -3554,6 +3574,7 @@ static unsigned long mem_cgroup_usage(struct mem_cgroup *memcg, bool swap)
        unsigned long val;
 
        if (mem_cgroup_is_root(memcg)) {
+               cgroup_rstat_flush(memcg->css.cgroup);
                val = memcg_page_state(memcg, NR_FILE_PAGES) +
                        memcg_page_state(memcg, NR_ANON_MAPPED);
                if (swap)
@@ -3618,57 +3639,6 @@ static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
        }
 }
 
-static void memcg_flush_percpu_vmstats(struct mem_cgroup *memcg)
-{
-       unsigned long stat[MEMCG_NR_STAT] = {0};
-       struct mem_cgroup *mi;
-       int node, cpu, i;
-
-       for_each_online_cpu(cpu)
-               for (i = 0; i < MEMCG_NR_STAT; i++)
-                       stat[i] += per_cpu(memcg->vmstats_percpu->stat[i], cpu);
-
-       for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-               for (i = 0; i < MEMCG_NR_STAT; i++)
-                       atomic_long_add(stat[i], &mi->vmstats[i]);
-
-       for_each_node(node) {
-               struct mem_cgroup_per_node *pn = memcg->nodeinfo[node];
-               struct mem_cgroup_per_node *pi;
-
-               for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
-                       stat[i] = 0;
-
-               for_each_online_cpu(cpu)
-                       for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
-                               stat[i] += per_cpu(
-                                       pn->lruvec_stat_cpu->count[i], cpu);
-
-               for (pi = pn; pi; pi = parent_nodeinfo(pi, node))
-                       for (i = 0; i < NR_VM_NODE_STAT_ITEMS; i++)
-                               atomic_long_add(stat[i], &pi->lruvec_stat[i]);
-       }
-}
-
-static void memcg_flush_percpu_vmevents(struct mem_cgroup *memcg)
-{
-       unsigned long events[NR_VM_EVENT_ITEMS];
-       struct mem_cgroup *mi;
-       int cpu, i;
-
-       for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-               events[i] = 0;
-
-       for_each_online_cpu(cpu)
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-                       events[i] += per_cpu(memcg->vmstats_percpu->events[i],
-                                            cpu);
-
-       for (mi = memcg; mi; mi = parent_mem_cgroup(mi))
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-                       atomic_long_add(events[i], &mi->vmevents[i]);
-}
-
 #ifdef CONFIG_MEMCG_KMEM
 static int memcg_online_kmem(struct mem_cgroup *memcg)
 {
@@ -3985,6 +3955,8 @@ static int memcg_numa_stat_show(struct seq_file *m, void *v)
        int nid;
        struct mem_cgroup *memcg = mem_cgroup_from_seq(m);
 
+       cgroup_rstat_flush(memcg->css.cgroup);
+
        for (stat = stats; stat < stats + ARRAY_SIZE(stats); stat++) {
                seq_printf(m, "%s=%lu", stat->name,
                           mem_cgroup_nr_lru_pages(memcg, stat->lru_mask,
@@ -4055,6 +4027,8 @@ static int memcg_stat_show(struct seq_file *m, void *v)
 
        BUILD_BUG_ON(ARRAY_SIZE(memcg1_stat_names) != ARRAY_SIZE(memcg1_stats));
 
+       cgroup_rstat_flush(memcg->css.cgroup);
+
        for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
                unsigned long nr;
 
@@ -4113,7 +4087,7 @@ static int memcg_stat_show(struct seq_file *m, void *v)
                unsigned long file_cost = 0;
 
                for_each_online_pgdat(pgdat) {
-                       mz = mem_cgroup_nodeinfo(memcg, pgdat->node_id);
+                       mz = memcg->nodeinfo[pgdat->node_id];
 
                        anon_cost += mz->lruvec.anon_cost;
                        file_cost += mz->lruvec.file_cost;
@@ -4142,7 +4116,7 @@ static int mem_cgroup_swappiness_write(struct cgroup_subsys_state *css,
        if (val > 100)
                return -EINVAL;
 
-       if (css->parent)
+       if (!mem_cgroup_is_root(memcg))
                memcg->swappiness = val;
        else
                vm_swappiness = val;
@@ -4492,7 +4466,7 @@ static int mem_cgroup_oom_control_write(struct cgroup_subsys_state *css,
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
        /* cannot set to root cgroup and only 0 and 1 are allowed */
-       if (!css->parent || !((val == 0) || (val == 1)))
+       if (mem_cgroup_is_root(memcg) || !((val == 0) || (val == 1)))
                return -EINVAL;
 
        memcg->oom_kill_disable = val;
@@ -4531,22 +4505,6 @@ struct wb_domain *mem_cgroup_wb_domain(struct bdi_writeback *wb)
        return &memcg->cgwb_domain;
 }
 
-/*
- * idx can be of type enum memcg_stat_item or node_stat_item.
- * Keep in sync with memcg_exact_page().
- */
-static unsigned long memcg_exact_page_state(struct mem_cgroup *memcg, int idx)
-{
-       long x = atomic_long_read(&memcg->vmstats[idx]);
-       int cpu;
-
-       for_each_online_cpu(cpu)
-               x += per_cpu_ptr(memcg->vmstats_percpu, cpu)->stat[idx];
-       if (x < 0)
-               x = 0;
-       return x;
-}
-
 /**
  * mem_cgroup_wb_stats - retrieve writeback related stats from its memcg
  * @wb: bdi_writeback in question
@@ -4572,13 +4530,14 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
        struct mem_cgroup *memcg = mem_cgroup_from_css(wb->memcg_css);
        struct mem_cgroup *parent;
 
-       *pdirty = memcg_exact_page_state(memcg, NR_FILE_DIRTY);
+       cgroup_rstat_flush_irqsafe(memcg->css.cgroup);
 
-       *pwriteback = memcg_exact_page_state(memcg, NR_WRITEBACK);
-       *pfilepages = memcg_exact_page_state(memcg, NR_INACTIVE_FILE) +
-                       memcg_exact_page_state(memcg, NR_ACTIVE_FILE);
-       *pheadroom = PAGE_COUNTER_MAX;
+       *pdirty = memcg_page_state(memcg, NR_FILE_DIRTY);
+       *pwriteback = memcg_page_state(memcg, NR_WRITEBACK);
+       *pfilepages = memcg_page_state(memcg, NR_INACTIVE_FILE) +
+                       memcg_page_state(memcg, NR_ACTIVE_FILE);
 
+       *pheadroom = PAGE_COUNTER_MAX;
        while ((parent = parent_mem_cgroup(memcg))) {
                unsigned long ceiling = min(READ_ONCE(memcg->memory.max),
                                            READ_ONCE(memcg->memory.high));
@@ -4593,7 +4552,7 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
  * Foreign dirty flushing
  *
  * There's an inherent mismatch between memcg and writeback.  The former
- * trackes ownership per-page while the latter per-inode.  This was a
+ * tracks ownership per-page while the latter per-inode.  This was a
  * deliberate design decision because honoring per-page ownership in the
  * writeback path is complicated, may lead to higher CPU and IO overheads
  * and deemed unnecessary given that write-sharing an inode across
@@ -4608,9 +4567,9 @@ void mem_cgroup_wb_stats(struct bdi_writeback *wb, unsigned long *pfilepages,
  * triggering background writeback.  A will be slowed down without a way to
  * make writeback of the dirty pages happen.
  *
- * Conditions like the above can lead to a cgroup getting repatedly and
+ * Conditions like the above can lead to a cgroup getting repeatedly and
  * severely throttled after making some progress after each
- * dirty_expire_interval while the underyling IO device is almost
+ * dirty_expire_interval while the underlying IO device is almost
  * completely idle.
  *
  * Solving this problem completely requires matching the ownership tracking
@@ -5210,19 +5169,20 @@ static void __mem_cgroup_free(struct mem_cgroup *memcg)
        for_each_node(node)
                free_mem_cgroup_per_node_info(memcg, node);
        free_percpu(memcg->vmstats_percpu);
-       free_percpu(memcg->vmstats_local);
        kfree(memcg);
 }
 
 static void mem_cgroup_free(struct mem_cgroup *memcg)
 {
+       int cpu;
+
        memcg_wb_domain_exit(memcg);
        /*
-        * Flush percpu vmstats and vmevents to guarantee the value correctness
-        * on parent's and all ancestor levels.
+        * Flush percpu lruvec stats to guarantee the value
+        * correctness on parent's and all ancestor levels.
         */
-       memcg_flush_percpu_vmstats(memcg);
-       memcg_flush_percpu_vmevents(memcg);
+       for_each_online_cpu(cpu)
+               memcg_flush_lruvec_page_state(memcg, cpu);
        __mem_cgroup_free(memcg);
 }
 
@@ -5249,11 +5209,6 @@ static struct mem_cgroup *mem_cgroup_alloc(void)
                goto fail;
        }
 
-       memcg->vmstats_local = alloc_percpu_gfp(struct memcg_vmstats_percpu,
-                                               GFP_KERNEL_ACCOUNT);
-       if (!memcg->vmstats_local)
-               goto fail;
-
        memcg->vmstats_percpu = alloc_percpu_gfp(struct memcg_vmstats_percpu,
                                                 GFP_KERNEL_ACCOUNT);
        if (!memcg->vmstats_percpu)
@@ -5351,11 +5306,11 @@ static int mem_cgroup_css_online(struct cgroup_subsys_state *css)
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
        /*
-        * A memcg must be visible for memcg_expand_shrinker_maps()
+        * A memcg must be visible for expand_shrinker_info()
         * by the time the maps are allocated. So, we allocate maps
         * here, when for_each_mem_cgroup() can't skip it.
         */
-       if (memcg_alloc_shrinker_maps(memcg)) {
+       if (alloc_shrinker_info(memcg)) {
                mem_cgroup_id_remove(memcg);
                return -ENOMEM;
        }
@@ -5387,6 +5342,7 @@ static void mem_cgroup_css_offline(struct cgroup_subsys_state *css)
        page_counter_set_low(&memcg->memory, 0);
 
        memcg_offline_kmem(memcg);
+       reparent_shrinker_deferred(memcg);
        wb_memcg_offline(memcg);
 
        drain_all_stock(memcg);
@@ -5419,7 +5375,7 @@ static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
        vmpressure_cleanup(&memcg->vmpressure);
        cancel_work_sync(&memcg->high_work);
        mem_cgroup_remove_from_trees(memcg);
-       memcg_free_shrinker_maps(memcg);
+       free_shrinker_info(memcg);
        memcg_free_kmem(memcg);
        mem_cgroup_free(memcg);
 }
@@ -5453,6 +5409,62 @@ static void mem_cgroup_css_reset(struct cgroup_subsys_state *css)
        memcg_wb_domain_size_changed(memcg);
 }
 
+static void mem_cgroup_css_rstat_flush(struct cgroup_subsys_state *css, int cpu)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(css);
+       struct mem_cgroup *parent = parent_mem_cgroup(memcg);
+       struct memcg_vmstats_percpu *statc;
+       long delta, v;
+       int i;
+
+       statc = per_cpu_ptr(memcg->vmstats_percpu, cpu);
+
+       for (i = 0; i < MEMCG_NR_STAT; i++) {
+               /*
+                * Collect the aggregated propagation counts of groups
+                * below us. We're in a per-cpu loop here and this is
+                * a global counter, so the first cycle will get them.
+                */
+               delta = memcg->vmstats.state_pending[i];
+               if (delta)
+                       memcg->vmstats.state_pending[i] = 0;
+
+               /* Add CPU changes on this level since the last flush */
+               v = READ_ONCE(statc->state[i]);
+               if (v != statc->state_prev[i]) {
+                       delta += v - statc->state_prev[i];
+                       statc->state_prev[i] = v;
+               }
+
+               if (!delta)
+                       continue;
+
+               /* Aggregate counts on this level and propagate upwards */
+               memcg->vmstats.state[i] += delta;
+               if (parent)
+                       parent->vmstats.state_pending[i] += delta;
+       }
+
+       for (i = 0; i < NR_VM_EVENT_ITEMS; i++) {
+               delta = memcg->vmstats.events_pending[i];
+               if (delta)
+                       memcg->vmstats.events_pending[i] = 0;
+
+               v = READ_ONCE(statc->events[i]);
+               if (v != statc->events_prev[i]) {
+                       delta += v - statc->events_prev[i];
+                       statc->events_prev[i] = v;
+               }
+
+               if (!delta)
+                       continue;
+
+               memcg->vmstats.events[i] += delta;
+               if (parent)
+                       parent->vmstats.events_pending[i] += delta;
+       }
+}
+
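To make the bookkeeping in mem_cgroup_css_rstat_flush() above easier to follow, here is a compact userspace model of the same scheme (illustrative only; the field names echo state, state_prev and state_pending, but nothing below is kernel code): each flush takes the per-CPU delta since the previous flush on that CPU, folds in whatever the children have parked in the pending slot, applies the sum locally, and forwards it to the parent's pending slot. Because the rstat core flushes children before their parents, a child's pending delta reaches the parent within the same pass.

/*
 * Userspace sketch of the rstat-style flush above: per-CPU counters,
 * a snapshot of the value seen at the last flush, and a "pending"
 * slot used to hand deltas up to the parent. Illustrative only.
 */
#include <stdio.h>

#define NR_CPUS 2
#define NR_STAT 1

struct node {
        struct node *parent;
        long state[NR_STAT];                /* aggregated hierarchical value */
        long state_pending[NR_STAT];        /* deltas queued by children */
        long pcpu[NR_CPUS][NR_STAT];        /* live per-CPU counters */
        long pcpu_prev[NR_CPUS][NR_STAT];   /* value seen at the last flush */
};

static void node_flush(struct node *n, int cpu)
{
        for (int i = 0; i < NR_STAT; i++) {
                long delta = n->state_pending[i];

                if (delta)
                        n->state_pending[i] = 0;

                /* per-CPU changes since the previous flush of this CPU */
                long v = n->pcpu[cpu][i];
                if (v != n->pcpu_prev[cpu][i]) {
                        delta += v - n->pcpu_prev[cpu][i];
                        n->pcpu_prev[cpu][i] = v;
                }

                if (!delta)
                        continue;

                /* aggregate here and propagate upwards */
                n->state[i] += delta;
                if (n->parent)
                        n->parent->state_pending[i] += delta;
        }
}

int main(void)
{
        struct node root = { 0 }, child = { .parent = &root };

        child.pcpu[0][0] = 4;          /* updates landed on CPU 0 */
        child.pcpu[1][0] = 6;          /* ... and on CPU 1 */

        /* Flush bottom-up, once per CPU, as the rstat core would. */
        for (int cpu = 0; cpu < NR_CPUS; cpu++) {
                node_flush(&child, cpu);
                node_flush(&root, cpu);
        }
        printf("child=%ld root=%ld\n", child.state[0], root.state[0]);
        return 0;
}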
 #ifdef CONFIG_MMU
 /* Handlers for move charge at task migration. */
 static int mem_cgroup_do_precharge(unsigned long count)
@@ -5950,7 +5962,7 @@ static int mem_cgroup_can_attach(struct cgroup_taskset *tset)
                return 0;
 
        /*
-        * We are now commited to this value whatever it is. Changes in this
+        * We are now committed to this value whatever it is. Changes in this
         * tunable will only affect upcoming migrations, not the current one.
         * So we need to save it, and keep it going.
         */
@@ -6506,6 +6518,7 @@ struct cgroup_subsys memory_cgrp_subsys = {
        .css_released = mem_cgroup_css_released,
        .css_free = mem_cgroup_css_free,
        .css_reset = mem_cgroup_css_reset,
+       .css_rstat_flush = mem_cgroup_css_rstat_flush,
        .can_attach = mem_cgroup_can_attach,
        .cancel_attach = mem_cgroup_cancel_attach,
        .post_attach = mem_cgroup_move_task,
@@ -6688,6 +6701,27 @@ void mem_cgroup_calculate_protection(struct mem_cgroup *root,
                        atomic_long_read(&parent->memory.children_low_usage)));
 }
 
+static int __mem_cgroup_charge(struct page *page, struct mem_cgroup *memcg,
+                              gfp_t gfp)
+{
+       unsigned int nr_pages = thp_nr_pages(page);
+       int ret;
+
+       ret = try_charge(memcg, gfp, nr_pages);
+       if (ret)
+               goto out;
+
+       css_get(&memcg->css);
+       commit_charge(page, memcg);
+
+       local_irq_disable();
+       mem_cgroup_charge_statistics(memcg, page, nr_pages);
+       memcg_check_events(memcg, page);
+       local_irq_enable();
+out:
+       return ret;
+}
+
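__mem_cgroup_charge() above factors out the common sequence: reserve the pages against the group's counters with try_charge(), then bind the page with commit_charge() and account the statistics with interrupts disabled. As a rough sketch of what the reservation step looks like in a hierarchy (a userspace toy, not the kernel's page_counter; names and limits are made up), each level is charged on the way up and already-charged levels are rolled back if an ancestor refuses:

/*
 * Sketch of a hierarchical "try charge": walk up the ancestors,
 * reserving nr pages against each limit, and roll back on failure.
 * Illustrative only.
 */
#include <stdbool.h>
#include <stdio.h>

struct counter {
        struct counter *parent;
        long usage;
        long limit;
};

static bool counter_try_charge(struct counter *c, long nr)
{
        struct counter *fail = NULL;

        for (struct counter *i = c; i; i = i->parent) {
                if (i->usage + nr > i->limit) {
                        fail = i;
                        break;
                }
                i->usage += nr;
        }
        if (!fail)
                return true;

        /* Roll back the levels that were already charged. */
        for (struct counter *i = c; i != fail; i = i->parent)
                i->usage -= nr;
        return false;
}

int main(void)
{
        struct counter root = { .limit = 100 };
        struct counter child = { .parent = &root, .limit = 10 };

        printf("charge 8: %d\n", counter_try_charge(&child, 8)); /* succeeds */
        printf("charge 8: %d\n", counter_try_charge(&child, 8)); /* hits child limit */
        printf("usage child=%ld root=%ld\n", child.usage, root.usage);
        return 0;
}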
 /**
  * mem_cgroup_charge - charge a newly allocated page to a cgroup
  * @page: page to charge
@@ -6695,57 +6729,74 @@ void mem_cgroup_calculate_protection(struct mem_cgroup *root,
  * @gfp_mask: reclaim mode
  *
  * Try to charge @page to the memcg that @mm belongs to, reclaiming
- * pages according to @gfp_mask if necessary.
+ * pages according to @gfp_mask if necessary. If @mm is NULL, try to
+ * charge to the active memcg.
+ *
+ * Do not use this for pages allocated for swapin.
  *
  * Returns 0 on success. Otherwise, an error code is returned.
  */
 int mem_cgroup_charge(struct page *page, struct mm_struct *mm, gfp_t gfp_mask)
 {
-       unsigned int nr_pages = thp_nr_pages(page);
-       struct mem_cgroup *memcg = NULL;
-       int ret = 0;
+       struct mem_cgroup *memcg;
+       int ret;
 
        if (mem_cgroup_disabled())
-               goto out;
+               return 0;
 
-       if (PageSwapCache(page)) {
-               swp_entry_t ent = { .val = page_private(page), };
-               unsigned short id;
+       memcg = get_mem_cgroup_from_mm(mm);
+       ret = __mem_cgroup_charge(page, memcg, gfp_mask);
+       css_put(&memcg->css);
 
-               /*
-                * Every swap fault against a single page tries to charge the
-                * page, bail as early as possible.  shmem_unuse() encounters
-                * already charged pages, too.  page and memcg binding is
-                * protected by the page lock, which serializes swap cache
-                * removal, which in turn serializes uncharging.
-                */
-               VM_BUG_ON_PAGE(!PageLocked(page), page);
-               if (page_memcg(compound_head(page)))
-                       goto out;
+       return ret;
+}
 
-               id = lookup_swap_cgroup_id(ent);
-               rcu_read_lock();
-               memcg = mem_cgroup_from_id(id);
-               if (memcg && !css_tryget_online(&memcg->css))
-                       memcg = NULL;
-               rcu_read_unlock();
-       }
+/**
+ * mem_cgroup_swapin_charge_page - charge a newly allocated page for swapin
+ * @page: page to charge
+ * @mm: mm context of the victim
+ * @gfp: reclaim mode
+ * @entry: swap entry for which the page is allocated
+ *
+ * This function charges a page allocated for swapin. Please call this before
+ * adding the page to the swapcache.
+ *
+ * Returns 0 on success. Otherwise, an error code is returned.
+ */
+int mem_cgroup_swapin_charge_page(struct page *page, struct mm_struct *mm,
+                                 gfp_t gfp, swp_entry_t entry)
+{
+       struct mem_cgroup *memcg;
+       unsigned short id;
+       int ret;
 
-       if (!memcg)
-               memcg = get_mem_cgroup_from_mm(mm);
+       if (mem_cgroup_disabled())
+               return 0;
 
-       ret = try_charge(memcg, gfp_mask, nr_pages);
-       if (ret)
-               goto out_put;
+       id = lookup_swap_cgroup_id(entry);
+       rcu_read_lock();
+       memcg = mem_cgroup_from_id(id);
+       if (!memcg || !css_tryget_online(&memcg->css))
+               memcg = get_mem_cgroup_from_mm(mm);
+       rcu_read_unlock();
 
-       css_get(&memcg->css);
-       commit_charge(page, memcg);
+       ret = __mem_cgroup_charge(page, memcg, gfp);
 
-       local_irq_disable();
-       mem_cgroup_charge_statistics(memcg, page, nr_pages);
-       memcg_check_events(memcg, page);
-       local_irq_enable();
+       css_put(&memcg->css);
+       return ret;
+}
 
+/*
+ * mem_cgroup_swapin_uncharge_swap - uncharge swap slot
+ * @entry: swap entry for which the page is charged
+ *
+ * Call this function after successfully adding the charged page to swapcache.
+ *
+ * Note: This function assumes the page for which the swap slot is being
+ * uncharged is an order-0 page.
+ */
+void mem_cgroup_swapin_uncharge_swap(swp_entry_t entry)
+{
        /*
         * Cgroup1's unified memory+swap counter has been charged with the
         * new swapcache page, finish the transfer by uncharging the swap
@@ -6758,25 +6809,19 @@ int mem_cgroup_charge(struct page *page, struct mm_struct *mm, gfp_t gfp_mask)
         * correspond 1:1 to page and swap slot lifetimes: we charge the
         * page to memory here, and uncharge swap when the slot is freed.
         */
-       if (do_memsw_account() && PageSwapCache(page)) {
-               swp_entry_t entry = { .val = page_private(page) };
+       if (!mem_cgroup_disabled() && do_memsw_account()) {
                /*
                 * The swap entry might not get freed for a long time,
                 * let's not wait for it.  The page already received a
                 * memory+swap charge, drop the swap entry duplicate.
                 */
-               mem_cgroup_uncharge_swap(entry, nr_pages);
+               mem_cgroup_uncharge_swap(entry, 1);
        }
-
-out_put:
-       css_put(&memcg->css);
-out:
-       return ret;
 }
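The comment above describes a transfer rather than a plain uncharge: under cgroup1's combined memory+swap accounting, the swap slot still holds a memsw charge, the freshly charged page added another, and the duplicate on the swap side is dropped so memsw keeps counting each page or swap slot exactly once. A toy model of that invariant (counters and helper names invented for illustration):

/*
 * Toy model of the memory+swap (memsw) transfer on swapin: charging
 * the page adds to both memory and memsw, so the memsw charge still
 * held by the swap slot is dropped to avoid double counting.
 * Illustrative only.
 */
#include <assert.h>
#include <stdio.h>

static long memory, memsw;

static void charge_swap_slot(void)   { memsw += 1; }               /* at swapout */
static void charge_page(void)        { memory += 1; memsw += 1; }  /* at swapin */
static void uncharge_swap_slot(void) { memsw -= 1; }   /* drop the duplicate */

int main(void)
{
        charge_swap_slot();     /* page went to swap: 0 pages, 1 slot */
        charge_page();          /* swapin allocates and charges a page */
        uncharge_swap_slot();   /* finish the transfer */

        /* memsw stays 1:1 with "pages + swap slots" (here: one page). */
        assert(memory == 1 && memsw == 1);
        printf("memory=%ld memsw=%ld\n", memory, memsw);
        return 0;
}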
 
 struct uncharge_gather {
        struct mem_cgroup *memcg;
-       unsigned long nr_pages;
+       unsigned long nr_memory;
        unsigned long pgpgout;
        unsigned long nr_kmem;
        struct page *dummy_page;
@@ -6791,10 +6836,10 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 {
        unsigned long flags;
 
-       if (!mem_cgroup_is_root(ug->memcg)) {
-               page_counter_uncharge(&ug->memcg->memory, ug->nr_pages);
+       if (ug->nr_memory) {
+               page_counter_uncharge(&ug->memcg->memory, ug->nr_memory);
                if (do_memsw_account())
-                       page_counter_uncharge(&ug->memcg->memsw, ug->nr_pages);
+                       page_counter_uncharge(&ug->memcg->memsw, ug->nr_memory);
                if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && ug->nr_kmem)
                        page_counter_uncharge(&ug->memcg->kmem, ug->nr_kmem);
                memcg_oom_recover(ug->memcg);
@@ -6802,7 +6847,7 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 
        local_irq_save(flags);
        __count_memcg_events(ug->memcg, PGPGOUT, ug->pgpgout);
-       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, ug->nr_pages);
+       __this_cpu_add(ug->memcg->vmstats_percpu->nr_page_events, ug->nr_memory);
        memcg_check_events(ug->memcg, ug->dummy_page);
        local_irq_restore(flags);
 
@@ -6813,40 +6858,61 @@ static void uncharge_batch(const struct uncharge_gather *ug)
 static void uncharge_page(struct page *page, struct uncharge_gather *ug)
 {
        unsigned long nr_pages;
+       struct mem_cgroup *memcg;
+       struct obj_cgroup *objcg;
+       bool use_objcg = PageMemcgKmem(page);
 
        VM_BUG_ON_PAGE(PageLRU(page), page);
 
-       if (!page_memcg(page))
-               return;
-
        /*
         * Nobody should be changing or seriously looking at
-        * page_memcg(page) at this point, we have fully
+        * page memcg or objcg at this point, we have fully
         * exclusive access to the page.
         */
+       if (use_objcg) {
+               objcg = __page_objcg(page);
+               /*
+                * This get matches the put at the end of the function;
+                * kmem pages do not hold memcg references anymore.
+                */
+               memcg = get_mem_cgroup_from_objcg(objcg);
+       } else {
+               memcg = __page_memcg(page);
+       }
+
+       if (!memcg)
+               return;
 
-       if (ug->memcg != page_memcg(page)) {
+       if (ug->memcg != memcg) {
                if (ug->memcg) {
                        uncharge_batch(ug);
                        uncharge_gather_clear(ug);
                }
-               ug->memcg = page_memcg(page);
+               ug->memcg = memcg;
+               ug->dummy_page = page;
 
                /* pairs with css_put in uncharge_batch */
-               css_get(&ug->memcg->css);
+               css_get(&memcg->css);
        }
 
        nr_pages = compound_nr(page);
-       ug->nr_pages += nr_pages;
 
-       if (PageMemcgKmem(page))
+       if (use_objcg) {
+               ug->nr_memory += nr_pages;
                ug->nr_kmem += nr_pages;
-       else
+
+               page->memcg_data = 0;
+               obj_cgroup_put(objcg);
+       } else {
+               /* LRU pages aren't accounted at the root level */
+               if (!mem_cgroup_is_root(memcg))
+                       ug->nr_memory += nr_pages;
                ug->pgpgout++;
 
-       ug->dummy_page = page;
-       page->memcg_data = 0;
-       css_put(&ug->memcg->css);
+               page->memcg_data = 0;
+       }
+
+       css_put(&memcg->css);
 }
 
 /**
@@ -6930,9 +6996,11 @@ void mem_cgroup_migrate(struct page *oldpage, struct page *newpage)
        /* Force-charge the new page. The old one will be freed soon */
        nr_pages = thp_nr_pages(newpage);
 
-       page_counter_charge(&memcg->memory, nr_pages);
-       if (do_memsw_account())
-               page_counter_charge(&memcg->memsw, nr_pages);
+       if (!mem_cgroup_is_root(memcg)) {
+               page_counter_charge(&memcg->memory, nr_pages);
+               if (do_memsw_account())
+                       page_counter_charge(&memcg->memsw, nr_pages);
+       }
 
        css_get(&memcg->css);
        commit_charge(newpage, memcg);