Merge tag 'loadpin-security-next' of https://git.kernel.org/pub/scm/linux/kernel...
[linux-2.6-block.git] / mm / memcontrol.c
index 6a921890739f7e1c0acad44d5916c5f43487a2a1..4ead5a4817de3ffbf1477cfe77b8f4f4802281c8 100644 (file)
@@ -1776,6 +1776,62 @@ cleanup:
        return true;
 }
 
+/**
+ * mem_cgroup_get_oom_group - get a memory cgroup to clean up after OOM
+ * @victim: task to be killed by the OOM killer
+ * @oom_domain: memcg in case of memcg OOM, NULL in case of system-wide OOM
+ *
+ * Returns a pointer to a memory cgroup, which has to be cleaned up
+ * by killing all belonging OOM-killable tasks.
+ *
+ * Caller has to call mem_cgroup_put() on the returned non-NULL memcg.
+ */
+struct mem_cgroup *mem_cgroup_get_oom_group(struct task_struct *victim,
+                                           struct mem_cgroup *oom_domain)
+{
+       struct mem_cgroup *oom_group = NULL;
+       struct mem_cgroup *memcg;
+
+       if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+               return NULL;
+
+       if (!oom_domain)
+               oom_domain = root_mem_cgroup;
+
+       rcu_read_lock();
+
+       memcg = mem_cgroup_from_task(victim);
+       if (memcg == root_mem_cgroup)
+               goto out;
+
+       /*
+        * Traverse the memory cgroup hierarchy from the victim task's
+        * cgroup up to the OOMing cgroup (or root) to find the
+        * highest-level memory cgroup with oom.group set.
+        */
+       for (; memcg; memcg = parent_mem_cgroup(memcg)) {
+               if (memcg->oom_group)
+                       oom_group = memcg;
+
+               if (memcg == oom_domain)
+                       break;
+       }
+
+       if (oom_group)
+               css_get(&oom_group->css);
+out:
+       rcu_read_unlock();
+
+       return oom_group;
+}
+
+void mem_cgroup_print_oom_group(struct mem_cgroup *memcg)
+{
+       pr_info("Tasks in ");
+       pr_cont_cgroup_path(memcg->css.cgroup);
+       pr_cont(" are going to be killed due to memory.oom.group set\n");
+}
+
 /**
  * lock_page_memcg - lock a page->mem_cgroup binding
  * @page: the page
@@ -2899,29 +2955,34 @@ static int mem_cgroup_hierarchy_write(struct cgroup_subsys_state *css,
        return retval;
 }
 
-static void tree_stat(struct mem_cgroup *memcg, unsigned long *stat)
-{
-       struct mem_cgroup *iter;
-       int i;
-
-       memset(stat, 0, sizeof(*stat) * MEMCG_NR_STAT);
-
-       for_each_mem_cgroup_tree(iter, memcg) {
-               for (i = 0; i < MEMCG_NR_STAT; i++)
-                       stat[i] += memcg_page_state(iter, i);
-       }
-}
+struct accumulated_stats {
+       unsigned long stat[MEMCG_NR_STAT];
+       unsigned long events[NR_VM_EVENT_ITEMS];
+       unsigned long lru_pages[NR_LRU_LISTS];
+       const unsigned int *stats_array;
+       const unsigned int *events_array;
+       int stats_size;
+       int events_size;
+};
 
-static void tree_events(struct mem_cgroup *memcg, unsigned long *events)
+static void accumulate_memcg_tree(struct mem_cgroup *memcg,
+                                 struct accumulated_stats *acc)
 {
-       struct mem_cgroup *iter;
+       struct mem_cgroup *mi;
        int i;
 
-       memset(events, 0, sizeof(*events) * NR_VM_EVENT_ITEMS);
+       for_each_mem_cgroup_tree(mi, memcg) {
+               for (i = 0; i < acc->stats_size; i++)
+                       acc->stat[i] += memcg_page_state(mi,
+                               acc->stats_array ? acc->stats_array[i] : i);
 
-       for_each_mem_cgroup_tree(iter, memcg) {
-               for (i = 0; i < NR_VM_EVENT_ITEMS; i++)
-                       events[i] += memcg_sum_events(iter, i);
+               for (i = 0; i < acc->events_size; i++)
+                       acc->events[i] += memcg_sum_events(mi,
+                               acc->events_array ? acc->events_array[i] : i);
+
+               for (i = 0; i < NR_LRU_LISTS; i++)
+                       acc->lru_pages[i] +=
+                               mem_cgroup_nr_lru_pages(mi, BIT(i));
        }
 }
 
@@ -3332,6 +3393,7 @@ static int memcg_stat_show(struct seq_file *m, void *v)
        unsigned long memory, memsw;
        struct mem_cgroup *mi;
        unsigned int i;
+       struct accumulated_stats acc;
 
        BUILD_BUG_ON(ARRAY_SIZE(memcg1_stat_names) != ARRAY_SIZE(memcg1_stats));
        BUILD_BUG_ON(ARRAY_SIZE(mem_cgroup_lru_names) != NR_LRU_LISTS);
@@ -3364,32 +3426,27 @@ static int memcg_stat_show(struct seq_file *m, void *v)
                seq_printf(m, "hierarchical_memsw_limit %llu\n",
                           (u64)memsw * PAGE_SIZE);
 
-       for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
-               unsigned long long val = 0;
+       memset(&acc, 0, sizeof(acc));
+       acc.stats_size = ARRAY_SIZE(memcg1_stats);
+       acc.stats_array = memcg1_stats;
+       acc.events_size = ARRAY_SIZE(memcg1_events);
+       acc.events_array = memcg1_events;
+       accumulate_memcg_tree(memcg, &acc);
 
+       for (i = 0; i < ARRAY_SIZE(memcg1_stats); i++) {
                if (memcg1_stats[i] == MEMCG_SWAP && !do_memsw_account())
                        continue;
-               for_each_mem_cgroup_tree(mi, memcg)
-                       val += memcg_page_state(mi, memcg1_stats[i]) *
-                       PAGE_SIZE;
-               seq_printf(m, "total_%s %llu\n", memcg1_stat_names[i], val);
+               seq_printf(m, "total_%s %llu\n", memcg1_stat_names[i],
+                          (u64)acc.stat[i] * PAGE_SIZE);
        }
 
-       for (i = 0; i < ARRAY_SIZE(memcg1_events); i++) {
-               unsigned long long val = 0;
-
-               for_each_mem_cgroup_tree(mi, memcg)
-                       val += memcg_sum_events(mi, memcg1_events[i]);
-               seq_printf(m, "total_%s %llu\n", memcg1_event_names[i], val);
-       }
-
-       for (i = 0; i < NR_LRU_LISTS; i++) {
-               unsigned long long val = 0;
+       for (i = 0; i < ARRAY_SIZE(memcg1_events); i++)
+               seq_printf(m, "total_%s %llu\n", memcg1_event_names[i],
+                          (u64)acc.events[i]);
 
-               for_each_mem_cgroup_tree(mi, memcg)
-                       val += mem_cgroup_nr_lru_pages(mi, BIT(i)) * PAGE_SIZE;
-               seq_printf(m, "total_%s %llu\n", mem_cgroup_lru_names[i], val);
-       }
+       for (i = 0; i < NR_LRU_LISTS; i++)
+               seq_printf(m, "total_%s %llu\n", mem_cgroup_lru_names[i],
+                          (u64)acc.lru_pages[i] * PAGE_SIZE);
 
 #ifdef CONFIG_DEBUG_VM
        {
@@ -5486,8 +5543,7 @@ static int memory_events_show(struct seq_file *m, void *v)
 static int memory_stat_show(struct seq_file *m, void *v)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(m));
-       unsigned long stat[MEMCG_NR_STAT];
-       unsigned long events[NR_VM_EVENT_ITEMS];
+       struct accumulated_stats acc;
        int i;
 
        /*
@@ -5501,70 +5557,97 @@ static int memory_stat_show(struct seq_file *m, void *v)
         * Current memory state:
         */
 
-       tree_stat(memcg, stat);
-       tree_events(memcg, events);
+       memset(&acc, 0, sizeof(acc));
+       acc.stats_size = MEMCG_NR_STAT;
+       acc.events_size = NR_VM_EVENT_ITEMS;
+       accumulate_memcg_tree(memcg, &acc);
 
        seq_printf(m, "anon %llu\n",
-                  (u64)stat[MEMCG_RSS] * PAGE_SIZE);
+                  (u64)acc.stat[MEMCG_RSS] * PAGE_SIZE);
        seq_printf(m, "file %llu\n",
-                  (u64)stat[MEMCG_CACHE] * PAGE_SIZE);
+                  (u64)acc.stat[MEMCG_CACHE] * PAGE_SIZE);
        seq_printf(m, "kernel_stack %llu\n",
-                  (u64)stat[MEMCG_KERNEL_STACK_KB] * 1024);
+                  (u64)acc.stat[MEMCG_KERNEL_STACK_KB] * 1024);
        seq_printf(m, "slab %llu\n",
-                  (u64)(stat[NR_SLAB_RECLAIMABLE] +
-                        stat[NR_SLAB_UNRECLAIMABLE]) * PAGE_SIZE);
+                  (u64)(acc.stat[NR_SLAB_RECLAIMABLE] +
+                        acc.stat[NR_SLAB_UNRECLAIMABLE]) * PAGE_SIZE);
        seq_printf(m, "sock %llu\n",
-                  (u64)stat[MEMCG_SOCK] * PAGE_SIZE);
+                  (u64)acc.stat[MEMCG_SOCK] * PAGE_SIZE);
 
        seq_printf(m, "shmem %llu\n",
-                  (u64)stat[NR_SHMEM] * PAGE_SIZE);
+                  (u64)acc.stat[NR_SHMEM] * PAGE_SIZE);
        seq_printf(m, "file_mapped %llu\n",
-                  (u64)stat[NR_FILE_MAPPED] * PAGE_SIZE);
+                  (u64)acc.stat[NR_FILE_MAPPED] * PAGE_SIZE);
        seq_printf(m, "file_dirty %llu\n",
-                  (u64)stat[NR_FILE_DIRTY] * PAGE_SIZE);
+                  (u64)acc.stat[NR_FILE_DIRTY] * PAGE_SIZE);
        seq_printf(m, "file_writeback %llu\n",
-                  (u64)stat[NR_WRITEBACK] * PAGE_SIZE);
-
-       for (i = 0; i < NR_LRU_LISTS; i++) {
-               struct mem_cgroup *mi;
-               unsigned long val = 0;
+                  (u64)acc.stat[NR_WRITEBACK] * PAGE_SIZE);
 
-               for_each_mem_cgroup_tree(mi, memcg)
-                       val += mem_cgroup_nr_lru_pages(mi, BIT(i));
-               seq_printf(m, "%s %llu\n",
-                          mem_cgroup_lru_names[i], (u64)val * PAGE_SIZE);
-       }
+       for (i = 0; i < NR_LRU_LISTS; i++)
+               seq_printf(m, "%s %llu\n", mem_cgroup_lru_names[i],
+                          (u64)acc.lru_pages[i] * PAGE_SIZE);
 
        seq_printf(m, "slab_reclaimable %llu\n",
-                  (u64)stat[NR_SLAB_RECLAIMABLE] * PAGE_SIZE);
+                  (u64)acc.stat[NR_SLAB_RECLAIMABLE] * PAGE_SIZE);
        seq_printf(m, "slab_unreclaimable %llu\n",
-                  (u64)stat[NR_SLAB_UNRECLAIMABLE] * PAGE_SIZE);
+                  (u64)acc.stat[NR_SLAB_UNRECLAIMABLE] * PAGE_SIZE);
 
        /* Accumulated memory events */
 
-       seq_printf(m, "pgfault %lu\n", events[PGFAULT]);
-       seq_printf(m, "pgmajfault %lu\n", events[PGMAJFAULT]);
+       seq_printf(m, "pgfault %lu\n", acc.events[PGFAULT]);
+       seq_printf(m, "pgmajfault %lu\n", acc.events[PGMAJFAULT]);
 
-       seq_printf(m, "pgrefill %lu\n", events[PGREFILL]);
-       seq_printf(m, "pgscan %lu\n", events[PGSCAN_KSWAPD] +
-                  events[PGSCAN_DIRECT]);
-       seq_printf(m, "pgsteal %lu\n", events[PGSTEAL_KSWAPD] +
-                  events[PGSTEAL_DIRECT]);
-       seq_printf(m, "pgactivate %lu\n", events[PGACTIVATE]);
-       seq_printf(m, "pgdeactivate %lu\n", events[PGDEACTIVATE]);
-       seq_printf(m, "pglazyfree %lu\n", events[PGLAZYFREE]);
-       seq_printf(m, "pglazyfreed %lu\n", events[PGLAZYFREED]);
+       seq_printf(m, "pgrefill %lu\n", acc.events[PGREFILL]);
+       seq_printf(m, "pgscan %lu\n", acc.events[PGSCAN_KSWAPD] +
+                  acc.events[PGSCAN_DIRECT]);
+       seq_printf(m, "pgsteal %lu\n", acc.events[PGSTEAL_KSWAPD] +
+                  acc.events[PGSTEAL_DIRECT]);
+       seq_printf(m, "pgactivate %lu\n", acc.events[PGACTIVATE]);
+       seq_printf(m, "pgdeactivate %lu\n", acc.events[PGDEACTIVATE]);
+       seq_printf(m, "pglazyfree %lu\n", acc.events[PGLAZYFREE]);
+       seq_printf(m, "pglazyfreed %lu\n", acc.events[PGLAZYFREED]);
 
        seq_printf(m, "workingset_refault %lu\n",
-                  stat[WORKINGSET_REFAULT]);
+                  acc.stat[WORKINGSET_REFAULT]);
        seq_printf(m, "workingset_activate %lu\n",
-                  stat[WORKINGSET_ACTIVATE]);
+                  acc.stat[WORKINGSET_ACTIVATE]);
        seq_printf(m, "workingset_nodereclaim %lu\n",
-                  stat[WORKINGSET_NODERECLAIM]);
+                  acc.stat[WORKINGSET_NODERECLAIM]);
+
+       return 0;
+}
+
+static int memory_oom_group_show(struct seq_file *m, void *v)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(seq_css(m));
+
+       seq_printf(m, "%d\n", memcg->oom_group);
 
        return 0;
 }
 
+static ssize_t memory_oom_group_write(struct kernfs_open_file *of,
+                                     char *buf, size_t nbytes, loff_t off)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_css(of_css(of));
+       int ret, oom_group;
+
+       buf = strstrip(buf);
+       if (!buf)
+               return -EINVAL;
+
+       ret = kstrtoint(buf, 0, &oom_group);
+       if (ret)
+               return ret;
+
+       if (oom_group != 0 && oom_group != 1)
+               return -EINVAL;
+
+       memcg->oom_group = oom_group;
+
+       return nbytes;
+}
+
 static struct cftype memory_files[] = {
        {
                .name = "current",
@@ -5606,6 +5689,12 @@ static struct cftype memory_files[] = {
                .flags = CFTYPE_NOT_ON_ROOT,
                .seq_show = memory_stat_show,
        },
+       {
+               .name = "oom.group",
+               .flags = CFTYPE_NOT_ON_ROOT | CFTYPE_NS_DELEGATABLE,
+               .seq_show = memory_oom_group_show,
+               .write = memory_oom_group_write,
+       },
        { }     /* terminate */
 };