mm/mempolicy: protect task interleave functions with tsk->mems_allowed_seq
[linux-2.6-block.git] / mm / mempolicy.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Simple NUMA memory policy for the Linux kernel.
4  *
5  * Copyright 2003,2004 Andi Kleen, SuSE Labs.
6  * (C) Copyright 2005 Christoph Lameter, Silicon Graphics, Inc.
7  *
8  * NUMA policy allows the user to give hints in which node(s) memory should
9  * be allocated.
10  *
11  * Support four policies per VMA and per process:
12  *
13  * The VMA policy has priority over the process policy for a page fault.
14  *
15  * interleave     Allocate memory interleaved over a set of nodes,
16  *                with normal fallback if it fails.
17  *                For VMA based allocations this interleaves based on the
18  *                offset into the backing object or offset into the mapping
19  *                for anonymous memory. For process policy an process counter
20  *                is used.
21  *
22  * weighted interleave
23  *                Allocate memory interleaved over a set of nodes based on
24  *                a set of weights (per-node), with normal fallback if it
25  *                fails.  Otherwise operates the same as interleave.
26  *                Example: nodeset(0,1) & weights (2,1) - 2 pages allocated
27  *                on node 0 for every 1 page allocated on node 1.
28  *
29  * bind           Only allocate memory on a specific set of nodes,
30  *                no fallback.
31  *                FIXME: memory is allocated starting with the first node
32  *                to the last. It would be better if bind would truly restrict
33  *                the allocation to memory nodes instead
34  *
35  * preferred      Try a specific node first before normal fallback.
36  *                As a special case NUMA_NO_NODE here means do the allocation
37  *                on the local CPU. This is normally identical to default,
38  *                but useful to set in a VMA when you have a non default
39  *                process policy.
40  *
41  * preferred many Try a set of nodes first before normal fallback. This is
42  *                similar to preferred without the special case.
43  *
44  * default        Allocate on the local node first, or when on a VMA
45  *                use the process policy. This is what Linux always did
46  *                in a NUMA aware kernel and still does by, ahem, default.
47  *
48  * The process policy is applied for most non interrupt memory allocations
49  * in that process' context. Interrupts ignore the policies and always
50  * try to allocate on the local CPU. The VMA policy is only applied for memory
51  * allocations for a VMA in the VM.
52  *
53  * Currently there are a few corner cases in swapping where the policy
54  * is not applied, but the majority should be handled. When process policy
55  * is used it is not remembered over swap outs/swap ins.
56  *
57  * Only the highest zone in the zone hierarchy gets policied. Allocations
58  * requesting a lower zone just use default policy. This implies that
59  * on systems with highmem kernel lowmem allocation don't get policied.
60  * Same with GFP_DMA allocations.
61  *
62  * For shmem/tmpfs shared memory the policy is shared between
63  * all users and remembered even when nobody has memory mapped.
64  */
65
66 /* Notebook:
67    fix mmap readahead to honour policy and enable policy for any page cache
68    object
69    statistics for bigpages
70    global policy for page cache? currently it uses process policy. Requires
71    first item above.
72    handle mremap for shared memory (currently ignored for the policy)
73    grows down?
74    make bind policy root only? It can trigger oom much faster and the
75    kernel is not always grateful with that.
76 */
77
78 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
79
80 #include <linux/mempolicy.h>
81 #include <linux/pagewalk.h>
82 #include <linux/highmem.h>
83 #include <linux/hugetlb.h>
84 #include <linux/kernel.h>
85 #include <linux/sched.h>
86 #include <linux/sched/mm.h>
87 #include <linux/sched/numa_balancing.h>
88 #include <linux/sched/task.h>
89 #include <linux/nodemask.h>
90 #include <linux/cpuset.h>
91 #include <linux/slab.h>
92 #include <linux/string.h>
93 #include <linux/export.h>
94 #include <linux/nsproxy.h>
95 #include <linux/interrupt.h>
96 #include <linux/init.h>
97 #include <linux/compat.h>
98 #include <linux/ptrace.h>
99 #include <linux/swap.h>
100 #include <linux/seq_file.h>
101 #include <linux/proc_fs.h>
102 #include <linux/migrate.h>
103 #include <linux/ksm.h>
104 #include <linux/rmap.h>
105 #include <linux/security.h>
106 #include <linux/syscalls.h>
107 #include <linux/ctype.h>
108 #include <linux/mm_inline.h>
109 #include <linux/mmu_notifier.h>
110 #include <linux/printk.h>
111 #include <linux/swapops.h>
112
113 #include <asm/tlbflush.h>
114 #include <asm/tlb.h>
115 #include <linux/uaccess.h>
116
117 #include "internal.h"
118
119 /* Internal flags */
120 #define MPOL_MF_DISCONTIG_OK (MPOL_MF_INTERNAL << 0)    /* Skip checks for continuous vmas */
121 #define MPOL_MF_INVERT       (MPOL_MF_INTERNAL << 1)    /* Invert check for nodemask */
122 #define MPOL_MF_WRLOCK       (MPOL_MF_INTERNAL << 2)    /* Write-lock walked vmas */
123
124 static struct kmem_cache *policy_cache;
125 static struct kmem_cache *sn_cache;
126
127 /* Highest zone. An specific allocation for a zone below that is not
128    policied. */
129 enum zone_type policy_zone = 0;
130
131 /*
132  * run-time system-wide default policy => local allocation
133  */
134 static struct mempolicy default_policy = {
135         .refcnt = ATOMIC_INIT(1), /* never free it */
136         .mode = MPOL_LOCAL,
137 };
138
139 static struct mempolicy preferred_node_policy[MAX_NUMNODES];
140
141 /*
142  * iw_table is the sysfs-set interleave weight table, a value of 0 denotes
143  * system-default value should be used. A NULL iw_table also denotes that
144  * system-default values should be used. Until the system-default table
145  * is implemented, the system-default is always 1.
146  *
147  * iw_table is RCU protected
148  */
149 static u8 __rcu *iw_table;
150 static DEFINE_MUTEX(iw_table_lock);
151
152 static u8 get_il_weight(int node)
153 {
154         u8 *table;
155         u8 weight;
156
157         rcu_read_lock();
158         table = rcu_dereference(iw_table);
159         /* if no iw_table, use system default */
160         weight = table ? table[node] : 1;
161         /* if value in iw_table is 0, use system default */
162         weight = weight ? weight : 1;
163         rcu_read_unlock();
164         return weight;
165 }
166
167 /**
168  * numa_nearest_node - Find nearest node by state
169  * @node: Node id to start the search
170  * @state: State to filter the search
171  *
172  * Lookup the closest node by distance if @nid is not in state.
173  *
174  * Return: this @node if it is in state, otherwise the closest node by distance
175  */
176 int numa_nearest_node(int node, unsigned int state)
177 {
178         int min_dist = INT_MAX, dist, n, min_node;
179
180         if (state >= NR_NODE_STATES)
181                 return -EINVAL;
182
183         if (node == NUMA_NO_NODE || node_state(node, state))
184                 return node;
185
186         min_node = node;
187         for_each_node_state(n, state) {
188                 dist = node_distance(node, n);
189                 if (dist < min_dist) {
190                         min_dist = dist;
191                         min_node = n;
192                 }
193         }
194
195         return min_node;
196 }
197 EXPORT_SYMBOL_GPL(numa_nearest_node);
198
199 struct mempolicy *get_task_policy(struct task_struct *p)
200 {
201         struct mempolicy *pol = p->mempolicy;
202         int node;
203
204         if (pol)
205                 return pol;
206
207         node = numa_node_id();
208         if (node != NUMA_NO_NODE) {
209                 pol = &preferred_node_policy[node];
210                 /* preferred_node_policy is not initialised early in boot */
211                 if (pol->mode)
212                         return pol;
213         }
214
215         return &default_policy;
216 }
217
218 static const struct mempolicy_operations {
219         int (*create)(struct mempolicy *pol, const nodemask_t *nodes);
220         void (*rebind)(struct mempolicy *pol, const nodemask_t *nodes);
221 } mpol_ops[MPOL_MAX];
222
223 static inline int mpol_store_user_nodemask(const struct mempolicy *pol)
224 {
225         return pol->flags & MPOL_MODE_FLAGS;
226 }
227
228 static void mpol_relative_nodemask(nodemask_t *ret, const nodemask_t *orig,
229                                    const nodemask_t *rel)
230 {
231         nodemask_t tmp;
232         nodes_fold(tmp, *orig, nodes_weight(*rel));
233         nodes_onto(*ret, tmp, *rel);
234 }
235
236 static int mpol_new_nodemask(struct mempolicy *pol, const nodemask_t *nodes)
237 {
238         if (nodes_empty(*nodes))
239                 return -EINVAL;
240         pol->nodes = *nodes;
241         return 0;
242 }
243
244 static int mpol_new_preferred(struct mempolicy *pol, const nodemask_t *nodes)
245 {
246         if (nodes_empty(*nodes))
247                 return -EINVAL;
248
249         nodes_clear(pol->nodes);
250         node_set(first_node(*nodes), pol->nodes);
251         return 0;
252 }
253
254 /*
255  * mpol_set_nodemask is called after mpol_new() to set up the nodemask, if
256  * any, for the new policy.  mpol_new() has already validated the nodes
257  * parameter with respect to the policy mode and flags.
258  *
259  * Must be called holding task's alloc_lock to protect task's mems_allowed
260  * and mempolicy.  May also be called holding the mmap_lock for write.
261  */
262 static int mpol_set_nodemask(struct mempolicy *pol,
263                      const nodemask_t *nodes, struct nodemask_scratch *nsc)
264 {
265         int ret;
266
267         /*
268          * Default (pol==NULL) resp. local memory policies are not a
269          * subject of any remapping. They also do not need any special
270          * constructor.
271          */
272         if (!pol || pol->mode == MPOL_LOCAL)
273                 return 0;
274
275         /* Check N_MEMORY */
276         nodes_and(nsc->mask1,
277                   cpuset_current_mems_allowed, node_states[N_MEMORY]);
278
279         VM_BUG_ON(!nodes);
280
281         if (pol->flags & MPOL_F_RELATIVE_NODES)
282                 mpol_relative_nodemask(&nsc->mask2, nodes, &nsc->mask1);
283         else
284                 nodes_and(nsc->mask2, *nodes, nsc->mask1);
285
286         if (mpol_store_user_nodemask(pol))
287                 pol->w.user_nodemask = *nodes;
288         else
289                 pol->w.cpuset_mems_allowed = cpuset_current_mems_allowed;
290
291         ret = mpol_ops[pol->mode].create(pol, &nsc->mask2);
292         return ret;
293 }
294
295 /*
296  * This function just creates a new policy, does some check and simple
297  * initialization. You must invoke mpol_set_nodemask() to set nodes.
298  */
299 static struct mempolicy *mpol_new(unsigned short mode, unsigned short flags,
300                                   nodemask_t *nodes)
301 {
302         struct mempolicy *policy;
303
304         if (mode == MPOL_DEFAULT) {
305                 if (nodes && !nodes_empty(*nodes))
306                         return ERR_PTR(-EINVAL);
307                 return NULL;
308         }
309         VM_BUG_ON(!nodes);
310
311         /*
312          * MPOL_PREFERRED cannot be used with MPOL_F_STATIC_NODES or
313          * MPOL_F_RELATIVE_NODES if the nodemask is empty (local allocation).
314          * All other modes require a valid pointer to a non-empty nodemask.
315          */
316         if (mode == MPOL_PREFERRED) {
317                 if (nodes_empty(*nodes)) {
318                         if (((flags & MPOL_F_STATIC_NODES) ||
319                              (flags & MPOL_F_RELATIVE_NODES)))
320                                 return ERR_PTR(-EINVAL);
321
322                         mode = MPOL_LOCAL;
323                 }
324         } else if (mode == MPOL_LOCAL) {
325                 if (!nodes_empty(*nodes) ||
326                     (flags & MPOL_F_STATIC_NODES) ||
327                     (flags & MPOL_F_RELATIVE_NODES))
328                         return ERR_PTR(-EINVAL);
329         } else if (nodes_empty(*nodes))
330                 return ERR_PTR(-EINVAL);
331
332         policy = kmem_cache_alloc(policy_cache, GFP_KERNEL);
333         if (!policy)
334                 return ERR_PTR(-ENOMEM);
335         atomic_set(&policy->refcnt, 1);
336         policy->mode = mode;
337         policy->flags = flags;
338         policy->home_node = NUMA_NO_NODE;
339
340         return policy;
341 }
342
343 /* Slow path of a mpol destructor. */
344 void __mpol_put(struct mempolicy *pol)
345 {
346         if (!atomic_dec_and_test(&pol->refcnt))
347                 return;
348         kmem_cache_free(policy_cache, pol);
349 }
350
351 static void mpol_rebind_default(struct mempolicy *pol, const nodemask_t *nodes)
352 {
353 }
354
355 static void mpol_rebind_nodemask(struct mempolicy *pol, const nodemask_t *nodes)
356 {
357         nodemask_t tmp;
358
359         if (pol->flags & MPOL_F_STATIC_NODES)
360                 nodes_and(tmp, pol->w.user_nodemask, *nodes);
361         else if (pol->flags & MPOL_F_RELATIVE_NODES)
362                 mpol_relative_nodemask(&tmp, &pol->w.user_nodemask, nodes);
363         else {
364                 nodes_remap(tmp, pol->nodes, pol->w.cpuset_mems_allowed,
365                                                                 *nodes);
366                 pol->w.cpuset_mems_allowed = *nodes;
367         }
368
369         if (nodes_empty(tmp))
370                 tmp = *nodes;
371
372         pol->nodes = tmp;
373 }
374
375 static void mpol_rebind_preferred(struct mempolicy *pol,
376                                                 const nodemask_t *nodes)
377 {
378         pol->w.cpuset_mems_allowed = *nodes;
379 }
380
381 /*
382  * mpol_rebind_policy - Migrate a policy to a different set of nodes
383  *
384  * Per-vma policies are protected by mmap_lock. Allocations using per-task
385  * policies are protected by task->mems_allowed_seq to prevent a premature
386  * OOM/allocation failure due to parallel nodemask modification.
387  */
388 static void mpol_rebind_policy(struct mempolicy *pol, const nodemask_t *newmask)
389 {
390         if (!pol || pol->mode == MPOL_LOCAL)
391                 return;
392         if (!mpol_store_user_nodemask(pol) &&
393             nodes_equal(pol->w.cpuset_mems_allowed, *newmask))
394                 return;
395
396         mpol_ops[pol->mode].rebind(pol, newmask);
397 }
398
399 /*
400  * Wrapper for mpol_rebind_policy() that just requires task
401  * pointer, and updates task mempolicy.
402  *
403  * Called with task's alloc_lock held.
404  */
405 void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new)
406 {
407         mpol_rebind_policy(tsk->mempolicy, new);
408 }
409
410 /*
411  * Rebind each vma in mm to new nodemask.
412  *
413  * Call holding a reference to mm.  Takes mm->mmap_lock during call.
414  */
415 void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new)
416 {
417         struct vm_area_struct *vma;
418         VMA_ITERATOR(vmi, mm, 0);
419
420         mmap_write_lock(mm);
421         for_each_vma(vmi, vma) {
422                 vma_start_write(vma);
423                 mpol_rebind_policy(vma->vm_policy, new);
424         }
425         mmap_write_unlock(mm);
426 }
427
428 static const struct mempolicy_operations mpol_ops[MPOL_MAX] = {
429         [MPOL_DEFAULT] = {
430                 .rebind = mpol_rebind_default,
431         },
432         [MPOL_INTERLEAVE] = {
433                 .create = mpol_new_nodemask,
434                 .rebind = mpol_rebind_nodemask,
435         },
436         [MPOL_PREFERRED] = {
437                 .create = mpol_new_preferred,
438                 .rebind = mpol_rebind_preferred,
439         },
440         [MPOL_BIND] = {
441                 .create = mpol_new_nodemask,
442                 .rebind = mpol_rebind_nodemask,
443         },
444         [MPOL_LOCAL] = {
445                 .rebind = mpol_rebind_default,
446         },
447         [MPOL_PREFERRED_MANY] = {
448                 .create = mpol_new_nodemask,
449                 .rebind = mpol_rebind_preferred,
450         },
451         [MPOL_WEIGHTED_INTERLEAVE] = {
452                 .create = mpol_new_nodemask,
453                 .rebind = mpol_rebind_nodemask,
454         },
455 };
456
457 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
458                                 unsigned long flags);
459 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
460                                 pgoff_t ilx, int *nid);
461
462 static bool strictly_unmovable(unsigned long flags)
463 {
464         /*
465          * STRICT without MOVE flags lets do_mbind() fail immediately with -EIO
466          * if any misplaced page is found.
467          */
468         return (flags & (MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ==
469                          MPOL_MF_STRICT;
470 }
471
472 struct migration_mpol {         /* for alloc_migration_target_by_mpol() */
473         struct mempolicy *pol;
474         pgoff_t ilx;
475 };
476
477 struct queue_pages {
478         struct list_head *pagelist;
479         unsigned long flags;
480         nodemask_t *nmask;
481         unsigned long start;
482         unsigned long end;
483         struct vm_area_struct *first;
484         struct folio *large;            /* note last large folio encountered */
485         long nr_failed;                 /* could not be isolated at this time */
486 };
487
488 /*
489  * Check if the folio's nid is in qp->nmask.
490  *
491  * If MPOL_MF_INVERT is set in qp->flags, check if the nid is
492  * in the invert of qp->nmask.
493  */
494 static inline bool queue_folio_required(struct folio *folio,
495                                         struct queue_pages *qp)
496 {
497         int nid = folio_nid(folio);
498         unsigned long flags = qp->flags;
499
500         return node_isset(nid, *qp->nmask) == !(flags & MPOL_MF_INVERT);
501 }
502
503 static void queue_folios_pmd(pmd_t *pmd, struct mm_walk *walk)
504 {
505         struct folio *folio;
506         struct queue_pages *qp = walk->private;
507
508         if (unlikely(is_pmd_migration_entry(*pmd))) {
509                 qp->nr_failed++;
510                 return;
511         }
512         folio = pfn_folio(pmd_pfn(*pmd));
513         if (is_huge_zero_page(&folio->page)) {
514                 walk->action = ACTION_CONTINUE;
515                 return;
516         }
517         if (!queue_folio_required(folio, qp))
518                 return;
519         if (!(qp->flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
520             !vma_migratable(walk->vma) ||
521             !migrate_folio_add(folio, qp->pagelist, qp->flags))
522                 qp->nr_failed++;
523 }
524
525 /*
526  * Scan through folios, checking if they satisfy the required conditions,
527  * moving them from LRU to local pagelist for migration if they do (or not).
528  *
529  * queue_folios_pte_range() has two possible return values:
530  * 0 - continue walking to scan for more, even if an existing folio on the
531  *     wrong node could not be isolated and queued for migration.
532  * -EIO - only MPOL_MF_STRICT was specified, without MPOL_MF_MOVE or ..._ALL,
533  *        and an existing folio was on a node that does not follow the policy.
534  */
535 static int queue_folios_pte_range(pmd_t *pmd, unsigned long addr,
536                         unsigned long end, struct mm_walk *walk)
537 {
538         struct vm_area_struct *vma = walk->vma;
539         struct folio *folio;
540         struct queue_pages *qp = walk->private;
541         unsigned long flags = qp->flags;
542         pte_t *pte, *mapped_pte;
543         pte_t ptent;
544         spinlock_t *ptl;
545
546         ptl = pmd_trans_huge_lock(pmd, vma);
547         if (ptl) {
548                 queue_folios_pmd(pmd, walk);
549                 spin_unlock(ptl);
550                 goto out;
551         }
552
553         mapped_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
554         if (!pte) {
555                 walk->action = ACTION_AGAIN;
556                 return 0;
557         }
558         for (; addr != end; pte++, addr += PAGE_SIZE) {
559                 ptent = ptep_get(pte);
560                 if (pte_none(ptent))
561                         continue;
562                 if (!pte_present(ptent)) {
563                         if (is_migration_entry(pte_to_swp_entry(ptent)))
564                                 qp->nr_failed++;
565                         continue;
566                 }
567                 folio = vm_normal_folio(vma, addr, ptent);
568                 if (!folio || folio_is_zone_device(folio))
569                         continue;
570                 /*
571                  * vm_normal_folio() filters out zero pages, but there might
572                  * still be reserved folios to skip, perhaps in a VDSO.
573                  */
574                 if (folio_test_reserved(folio))
575                         continue;
576                 if (!queue_folio_required(folio, qp))
577                         continue;
578                 if (folio_test_large(folio)) {
579                         /*
580                          * A large folio can only be isolated from LRU once,
581                          * but may be mapped by many PTEs (and Copy-On-Write may
582                          * intersperse PTEs of other, order 0, folios).  This is
583                          * a common case, so don't mistake it for failure (but
584                          * there can be other cases of multi-mapped pages which
585                          * this quick check does not help to filter out - and a
586                          * search of the pagelist might grow to be prohibitive).
587                          *
588                          * migrate_pages(&pagelist) returns nr_failed folios, so
589                          * check "large" now so that queue_pages_range() returns
590                          * a comparable nr_failed folios.  This does imply that
591                          * if folio could not be isolated for some racy reason
592                          * at its first PTE, later PTEs will not give it another
593                          * chance of isolation; but keeps the accounting simple.
594                          */
595                         if (folio == qp->large)
596                                 continue;
597                         qp->large = folio;
598                 }
599                 if (!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
600                     !vma_migratable(vma) ||
601                     !migrate_folio_add(folio, qp->pagelist, flags)) {
602                         qp->nr_failed++;
603                         if (strictly_unmovable(flags))
604                                 break;
605                 }
606         }
607         pte_unmap_unlock(mapped_pte, ptl);
608         cond_resched();
609 out:
610         if (qp->nr_failed && strictly_unmovable(flags))
611                 return -EIO;
612         return 0;
613 }
614
615 static int queue_folios_hugetlb(pte_t *pte, unsigned long hmask,
616                                unsigned long addr, unsigned long end,
617                                struct mm_walk *walk)
618 {
619 #ifdef CONFIG_HUGETLB_PAGE
620         struct queue_pages *qp = walk->private;
621         unsigned long flags = qp->flags;
622         struct folio *folio;
623         spinlock_t *ptl;
624         pte_t entry;
625
626         ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
627         entry = huge_ptep_get(pte);
628         if (!pte_present(entry)) {
629                 if (unlikely(is_hugetlb_entry_migration(entry)))
630                         qp->nr_failed++;
631                 goto unlock;
632         }
633         folio = pfn_folio(pte_pfn(entry));
634         if (!queue_folio_required(folio, qp))
635                 goto unlock;
636         if (!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
637             !vma_migratable(walk->vma)) {
638                 qp->nr_failed++;
639                 goto unlock;
640         }
641         /*
642          * Unless MPOL_MF_MOVE_ALL, we try to avoid migrating a shared folio.
643          * Choosing not to migrate a shared folio is not counted as a failure.
644          *
645          * To check if the folio is shared, ideally we want to make sure
646          * every page is mapped to the same process. Doing that is very
647          * expensive, so check the estimated sharers of the folio instead.
648          */
649         if ((flags & MPOL_MF_MOVE_ALL) ||
650             (folio_estimated_sharers(folio) == 1 && !hugetlb_pmd_shared(pte)))
651                 if (!isolate_hugetlb(folio, qp->pagelist))
652                         qp->nr_failed++;
653 unlock:
654         spin_unlock(ptl);
655         if (qp->nr_failed && strictly_unmovable(flags))
656                 return -EIO;
657 #endif
658         return 0;
659 }
660
661 #ifdef CONFIG_NUMA_BALANCING
662 /*
663  * This is used to mark a range of virtual addresses to be inaccessible.
664  * These are later cleared by a NUMA hinting fault. Depending on these
665  * faults, pages may be migrated for better NUMA placement.
666  *
667  * This is assuming that NUMA faults are handled using PROT_NONE. If
668  * an architecture makes a different choice, it will need further
669  * changes to the core.
670  */
671 unsigned long change_prot_numa(struct vm_area_struct *vma,
672                         unsigned long addr, unsigned long end)
673 {
674         struct mmu_gather tlb;
675         long nr_updated;
676
677         tlb_gather_mmu(&tlb, vma->vm_mm);
678
679         nr_updated = change_protection(&tlb, vma, addr, end, MM_CP_PROT_NUMA);
680         if (nr_updated > 0)
681                 count_vm_numa_events(NUMA_PTE_UPDATES, nr_updated);
682
683         tlb_finish_mmu(&tlb);
684
685         return nr_updated;
686 }
687 #endif /* CONFIG_NUMA_BALANCING */
688
689 static int queue_pages_test_walk(unsigned long start, unsigned long end,
690                                 struct mm_walk *walk)
691 {
692         struct vm_area_struct *next, *vma = walk->vma;
693         struct queue_pages *qp = walk->private;
694         unsigned long flags = qp->flags;
695
696         /* range check first */
697         VM_BUG_ON_VMA(!range_in_vma(vma, start, end), vma);
698
699         if (!qp->first) {
700                 qp->first = vma;
701                 if (!(flags & MPOL_MF_DISCONTIG_OK) &&
702                         (qp->start < vma->vm_start))
703                         /* hole at head side of range */
704                         return -EFAULT;
705         }
706         next = find_vma(vma->vm_mm, vma->vm_end);
707         if (!(flags & MPOL_MF_DISCONTIG_OK) &&
708                 ((vma->vm_end < qp->end) &&
709                 (!next || vma->vm_end < next->vm_start)))
710                 /* hole at middle or tail of range */
711                 return -EFAULT;
712
713         /*
714          * Need check MPOL_MF_STRICT to return -EIO if possible
715          * regardless of vma_migratable
716          */
717         if (!vma_migratable(vma) &&
718             !(flags & MPOL_MF_STRICT))
719                 return 1;
720
721         /*
722          * Check page nodes, and queue pages to move, in the current vma.
723          * But if no moving, and no strict checking, the scan can be skipped.
724          */
725         if (flags & (MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
726                 return 0;
727         return 1;
728 }
729
730 static const struct mm_walk_ops queue_pages_walk_ops = {
731         .hugetlb_entry          = queue_folios_hugetlb,
732         .pmd_entry              = queue_folios_pte_range,
733         .test_walk              = queue_pages_test_walk,
734         .walk_lock              = PGWALK_RDLOCK,
735 };
736
737 static const struct mm_walk_ops queue_pages_lock_vma_walk_ops = {
738         .hugetlb_entry          = queue_folios_hugetlb,
739         .pmd_entry              = queue_folios_pte_range,
740         .test_walk              = queue_pages_test_walk,
741         .walk_lock              = PGWALK_WRLOCK,
742 };
743
744 /*
745  * Walk through page tables and collect pages to be migrated.
746  *
747  * If pages found in a given range are not on the required set of @nodes,
748  * and migration is allowed, they are isolated and queued to @pagelist.
749  *
750  * queue_pages_range() may return:
751  * 0 - all pages already on the right node, or successfully queued for moving
752  *     (or neither strict checking nor moving requested: only range checking).
753  * >0 - this number of misplaced folios could not be queued for moving
754  *      (a hugetlbfs page or a transparent huge page being counted as 1).
755  * -EIO - a misplaced page found, when MPOL_MF_STRICT specified without MOVEs.
756  * -EFAULT - a hole in the memory range, when MPOL_MF_DISCONTIG_OK unspecified.
757  */
758 static long
759 queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
760                 nodemask_t *nodes, unsigned long flags,
761                 struct list_head *pagelist)
762 {
763         int err;
764         struct queue_pages qp = {
765                 .pagelist = pagelist,
766                 .flags = flags,
767                 .nmask = nodes,
768                 .start = start,
769                 .end = end,
770                 .first = NULL,
771         };
772         const struct mm_walk_ops *ops = (flags & MPOL_MF_WRLOCK) ?
773                         &queue_pages_lock_vma_walk_ops : &queue_pages_walk_ops;
774
775         err = walk_page_range(mm, start, end, ops, &qp);
776
777         if (!qp.first)
778                 /* whole range in hole */
779                 err = -EFAULT;
780
781         return err ? : qp.nr_failed;
782 }
783
784 /*
785  * Apply policy to a single VMA
786  * This must be called with the mmap_lock held for writing.
787  */
788 static int vma_replace_policy(struct vm_area_struct *vma,
789                                 struct mempolicy *pol)
790 {
791         int err;
792         struct mempolicy *old;
793         struct mempolicy *new;
794
795         vma_assert_write_locked(vma);
796
797         new = mpol_dup(pol);
798         if (IS_ERR(new))
799                 return PTR_ERR(new);
800
801         if (vma->vm_ops && vma->vm_ops->set_policy) {
802                 err = vma->vm_ops->set_policy(vma, new);
803                 if (err)
804                         goto err_out;
805         }
806
807         old = vma->vm_policy;
808         vma->vm_policy = new; /* protected by mmap_lock */
809         mpol_put(old);
810
811         return 0;
812  err_out:
813         mpol_put(new);
814         return err;
815 }
816
817 /* Split or merge the VMA (if required) and apply the new policy */
818 static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
819                 struct vm_area_struct **prev, unsigned long start,
820                 unsigned long end, struct mempolicy *new_pol)
821 {
822         unsigned long vmstart, vmend;
823
824         vmend = min(end, vma->vm_end);
825         if (start > vma->vm_start) {
826                 *prev = vma;
827                 vmstart = start;
828         } else {
829                 vmstart = vma->vm_start;
830         }
831
832         if (mpol_equal(vma->vm_policy, new_pol)) {
833                 *prev = vma;
834                 return 0;
835         }
836
837         vma =  vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
838         if (IS_ERR(vma))
839                 return PTR_ERR(vma);
840
841         *prev = vma;
842         return vma_replace_policy(vma, new_pol);
843 }
844
845 /* Set the process memory policy */
846 static long do_set_mempolicy(unsigned short mode, unsigned short flags,
847                              nodemask_t *nodes)
848 {
849         struct mempolicy *new, *old;
850         NODEMASK_SCRATCH(scratch);
851         int ret;
852
853         if (!scratch)
854                 return -ENOMEM;
855
856         new = mpol_new(mode, flags, nodes);
857         if (IS_ERR(new)) {
858                 ret = PTR_ERR(new);
859                 goto out;
860         }
861
862         task_lock(current);
863         ret = mpol_set_nodemask(new, nodes, scratch);
864         if (ret) {
865                 task_unlock(current);
866                 mpol_put(new);
867                 goto out;
868         }
869
870         old = current->mempolicy;
871         current->mempolicy = new;
872         if (new && (new->mode == MPOL_INTERLEAVE ||
873                     new->mode == MPOL_WEIGHTED_INTERLEAVE)) {
874                 current->il_prev = MAX_NUMNODES-1;
875                 current->il_weight = 0;
876         }
877         task_unlock(current);
878         mpol_put(old);
879         ret = 0;
880 out:
881         NODEMASK_SCRATCH_FREE(scratch);
882         return ret;
883 }
884
885 /*
886  * Return nodemask for policy for get_mempolicy() query
887  *
888  * Called with task's alloc_lock held
889  */
890 static void get_policy_nodemask(struct mempolicy *pol, nodemask_t *nodes)
891 {
892         nodes_clear(*nodes);
893         if (pol == &default_policy)
894                 return;
895
896         switch (pol->mode) {
897         case MPOL_BIND:
898         case MPOL_INTERLEAVE:
899         case MPOL_PREFERRED:
900         case MPOL_PREFERRED_MANY:
901         case MPOL_WEIGHTED_INTERLEAVE:
902                 *nodes = pol->nodes;
903                 break;
904         case MPOL_LOCAL:
905                 /* return empty node mask for local allocation */
906                 break;
907         default:
908                 BUG();
909         }
910 }
911
912 static int lookup_node(struct mm_struct *mm, unsigned long addr)
913 {
914         struct page *p = NULL;
915         int ret;
916
917         ret = get_user_pages_fast(addr & PAGE_MASK, 1, 0, &p);
918         if (ret > 0) {
919                 ret = page_to_nid(p);
920                 put_page(p);
921         }
922         return ret;
923 }
924
925 /* Retrieve NUMA policy */
926 static long do_get_mempolicy(int *policy, nodemask_t *nmask,
927                              unsigned long addr, unsigned long flags)
928 {
929         int err;
930         struct mm_struct *mm = current->mm;
931         struct vm_area_struct *vma = NULL;
932         struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
933
934         if (flags &
935                 ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
936                 return -EINVAL;
937
938         if (flags & MPOL_F_MEMS_ALLOWED) {
939                 if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
940                         return -EINVAL;
941                 *policy = 0;    /* just so it's initialized */
942                 task_lock(current);
943                 *nmask  = cpuset_current_mems_allowed;
944                 task_unlock(current);
945                 return 0;
946         }
947
948         if (flags & MPOL_F_ADDR) {
949                 pgoff_t ilx;            /* ignored here */
950                 /*
951                  * Do NOT fall back to task policy if the
952                  * vma/shared policy at addr is NULL.  We
953                  * want to return MPOL_DEFAULT in this case.
954                  */
955                 mmap_read_lock(mm);
956                 vma = vma_lookup(mm, addr);
957                 if (!vma) {
958                         mmap_read_unlock(mm);
959                         return -EFAULT;
960                 }
961                 pol = __get_vma_policy(vma, addr, &ilx);
962         } else if (addr)
963                 return -EINVAL;
964
965         if (!pol)
966                 pol = &default_policy;  /* indicates default behavior */
967
968         if (flags & MPOL_F_NODE) {
969                 if (flags & MPOL_F_ADDR) {
970                         /*
971                          * Take a refcount on the mpol, because we are about to
972                          * drop the mmap_lock, after which only "pol" remains
973                          * valid, "vma" is stale.
974                          */
975                         pol_refcount = pol;
976                         vma = NULL;
977                         mpol_get(pol);
978                         mmap_read_unlock(mm);
979                         err = lookup_node(mm, addr);
980                         if (err < 0)
981                                 goto out;
982                         *policy = err;
983                 } else if (pol == current->mempolicy &&
984                                 pol->mode == MPOL_INTERLEAVE) {
985                         *policy = next_node_in(current->il_prev, pol->nodes);
986                 } else if (pol == current->mempolicy &&
987                                 pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
988                         if (current->il_weight)
989                                 *policy = current->il_prev;
990                         else
991                                 *policy = next_node_in(current->il_prev,
992                                                        pol->nodes);
993                 } else {
994                         err = -EINVAL;
995                         goto out;
996                 }
997         } else {
998                 *policy = pol == &default_policy ? MPOL_DEFAULT :
999                                                 pol->mode;
1000                 /*
1001                  * Internal mempolicy flags must be masked off before exposing
1002                  * the policy to userspace.
1003                  */
1004                 *policy |= (pol->flags & MPOL_MODE_FLAGS);
1005         }
1006
1007         err = 0;
1008         if (nmask) {
1009                 if (mpol_store_user_nodemask(pol)) {
1010                         *nmask = pol->w.user_nodemask;
1011                 } else {
1012                         task_lock(current);
1013                         get_policy_nodemask(pol, nmask);
1014                         task_unlock(current);
1015                 }
1016         }
1017
1018  out:
1019         mpol_cond_put(pol);
1020         if (vma)
1021                 mmap_read_unlock(mm);
1022         if (pol_refcount)
1023                 mpol_put(pol_refcount);
1024         return err;
1025 }
1026
1027 #ifdef CONFIG_MIGRATION
1028 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
1029                                 unsigned long flags)
1030 {
1031         /*
1032          * Unless MPOL_MF_MOVE_ALL, we try to avoid migrating a shared folio.
1033          * Choosing not to migrate a shared folio is not counted as a failure.
1034          *
1035          * To check if the folio is shared, ideally we want to make sure
1036          * every page is mapped to the same process. Doing that is very
1037          * expensive, so check the estimated sharers of the folio instead.
1038          */
1039         if ((flags & MPOL_MF_MOVE_ALL) || folio_estimated_sharers(folio) == 1) {
1040                 if (folio_isolate_lru(folio)) {
1041                         list_add_tail(&folio->lru, foliolist);
1042                         node_stat_mod_folio(folio,
1043                                 NR_ISOLATED_ANON + folio_is_file_lru(folio),
1044                                 folio_nr_pages(folio));
1045                 } else {
1046                         /*
1047                          * Non-movable folio may reach here.  And, there may be
1048                          * temporary off LRU folios or non-LRU movable folios.
1049                          * Treat them as unmovable folios since they can't be
1050                          * isolated, so they can't be moved at the moment.
1051                          */
1052                         return false;
1053                 }
1054         }
1055         return true;
1056 }
1057
1058 /*
1059  * Migrate pages from one node to a target node.
1060  * Returns error or the number of pages not migrated.
1061  */
1062 static long migrate_to_node(struct mm_struct *mm, int source, int dest,
1063                             int flags)
1064 {
1065         nodemask_t nmask;
1066         struct vm_area_struct *vma;
1067         LIST_HEAD(pagelist);
1068         long nr_failed;
1069         long err = 0;
1070         struct migration_target_control mtc = {
1071                 .nid = dest,
1072                 .gfp_mask = GFP_HIGHUSER_MOVABLE | __GFP_THISNODE,
1073         };
1074
1075         nodes_clear(nmask);
1076         node_set(source, nmask);
1077
1078         VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)));
1079
1080         mmap_read_lock(mm);
1081         vma = find_vma(mm, 0);
1082
1083         /*
1084          * This does not migrate the range, but isolates all pages that
1085          * need migration.  Between passing in the full user address
1086          * space range and MPOL_MF_DISCONTIG_OK, this call cannot fail,
1087          * but passes back the count of pages which could not be isolated.
1088          */
1089         nr_failed = queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask,
1090                                       flags | MPOL_MF_DISCONTIG_OK, &pagelist);
1091         mmap_read_unlock(mm);
1092
1093         if (!list_empty(&pagelist)) {
1094                 err = migrate_pages(&pagelist, alloc_migration_target, NULL,
1095                         (unsigned long)&mtc, MIGRATE_SYNC, MR_SYSCALL, NULL);
1096                 if (err)
1097                         putback_movable_pages(&pagelist);
1098         }
1099
1100         if (err >= 0)
1101                 err += nr_failed;
1102         return err;
1103 }
1104
1105 /*
1106  * Move pages between the two nodesets so as to preserve the physical
1107  * layout as much as possible.
1108  *
1109  * Returns the number of page that could not be moved.
1110  */
1111 int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
1112                      const nodemask_t *to, int flags)
1113 {
1114         long nr_failed = 0;
1115         long err = 0;
1116         nodemask_t tmp;
1117
1118         lru_cache_disable();
1119
1120         /*
1121          * Find a 'source' bit set in 'tmp' whose corresponding 'dest'
1122          * bit in 'to' is not also set in 'tmp'.  Clear the found 'source'
1123          * bit in 'tmp', and return that <source, dest> pair for migration.
1124          * The pair of nodemasks 'to' and 'from' define the map.
1125          *
1126          * If no pair of bits is found that way, fallback to picking some
1127          * pair of 'source' and 'dest' bits that are not the same.  If the
1128          * 'source' and 'dest' bits are the same, this represents a node
1129          * that will be migrating to itself, so no pages need move.
1130          *
1131          * If no bits are left in 'tmp', or if all remaining bits left
1132          * in 'tmp' correspond to the same bit in 'to', return false
1133          * (nothing left to migrate).
1134          *
1135          * This lets us pick a pair of nodes to migrate between, such that
1136          * if possible the dest node is not already occupied by some other
1137          * source node, minimizing the risk of overloading the memory on a
1138          * node that would happen if we migrated incoming memory to a node
1139          * before migrating outgoing memory source that same node.
1140          *
1141          * A single scan of tmp is sufficient.  As we go, we remember the
1142          * most recent <s, d> pair that moved (s != d).  If we find a pair
1143          * that not only moved, but what's better, moved to an empty slot
1144          * (d is not set in tmp), then we break out then, with that pair.
1145          * Otherwise when we finish scanning from_tmp, we at least have the
1146          * most recent <s, d> pair that moved.  If we get all the way through
1147          * the scan of tmp without finding any node that moved, much less
1148          * moved to an empty node, then there is nothing left worth migrating.
1149          */
1150
1151         tmp = *from;
1152         while (!nodes_empty(tmp)) {
1153                 int s, d;
1154                 int source = NUMA_NO_NODE;
1155                 int dest = 0;
1156
1157                 for_each_node_mask(s, tmp) {
1158
1159                         /*
1160                          * do_migrate_pages() tries to maintain the relative
1161                          * node relationship of the pages established between
1162                          * threads and memory areas.
1163                          *
1164                          * However if the number of source nodes is not equal to
1165                          * the number of destination nodes we can not preserve
1166                          * this node relative relationship.  In that case, skip
1167                          * copying memory from a node that is in the destination
1168                          * mask.
1169                          *
1170                          * Example: [2,3,4] -> [3,4,5] moves everything.
1171                          *          [0-7] - > [3,4,5] moves only 0,1,2,6,7.
1172                          */
1173
1174                         if ((nodes_weight(*from) != nodes_weight(*to)) &&
1175                                                 (node_isset(s, *to)))
1176                                 continue;
1177
1178                         d = node_remap(s, *from, *to);
1179                         if (s == d)
1180                                 continue;
1181
1182                         source = s;     /* Node moved. Memorize */
1183                         dest = d;
1184
1185                         /* dest not in remaining from nodes? */
1186                         if (!node_isset(dest, tmp))
1187                                 break;
1188                 }
1189                 if (source == NUMA_NO_NODE)
1190                         break;
1191
1192                 node_clear(source, tmp);
1193                 err = migrate_to_node(mm, source, dest, flags);
1194                 if (err > 0)
1195                         nr_failed += err;
1196                 if (err < 0)
1197                         break;
1198         }
1199
1200         lru_cache_enable();
1201         if (err < 0)
1202                 return err;
1203         return (nr_failed < INT_MAX) ? nr_failed : INT_MAX;
1204 }
1205
1206 /*
1207  * Allocate a new folio for page migration, according to NUMA mempolicy.
1208  */
1209 static struct folio *alloc_migration_target_by_mpol(struct folio *src,
1210                                                     unsigned long private)
1211 {
1212         struct migration_mpol *mmpol = (struct migration_mpol *)private;
1213         struct mempolicy *pol = mmpol->pol;
1214         pgoff_t ilx = mmpol->ilx;
1215         struct page *page;
1216         unsigned int order;
1217         int nid = numa_node_id();
1218         gfp_t gfp;
1219
1220         order = folio_order(src);
1221         ilx += src->index >> order;
1222
1223         if (folio_test_hugetlb(src)) {
1224                 nodemask_t *nodemask;
1225                 struct hstate *h;
1226
1227                 h = folio_hstate(src);
1228                 gfp = htlb_alloc_mask(h);
1229                 nodemask = policy_nodemask(gfp, pol, ilx, &nid);
1230                 return alloc_hugetlb_folio_nodemask(h, nid, nodemask, gfp);
1231         }
1232
1233         if (folio_test_large(src))
1234                 gfp = GFP_TRANSHUGE;
1235         else
1236                 gfp = GFP_HIGHUSER_MOVABLE | __GFP_RETRY_MAYFAIL | __GFP_COMP;
1237
1238         page = alloc_pages_mpol(gfp, order, pol, ilx, nid);
1239         return page_rmappable_folio(page);
1240 }
1241 #else
1242
1243 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
1244                                 unsigned long flags)
1245 {
1246         return false;
1247 }
1248
1249 int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
1250                      const nodemask_t *to, int flags)
1251 {
1252         return -ENOSYS;
1253 }
1254
1255 static struct folio *alloc_migration_target_by_mpol(struct folio *src,
1256                                                     unsigned long private)
1257 {
1258         return NULL;
1259 }
1260 #endif
1261
1262 static long do_mbind(unsigned long start, unsigned long len,
1263                      unsigned short mode, unsigned short mode_flags,
1264                      nodemask_t *nmask, unsigned long flags)
1265 {
1266         struct mm_struct *mm = current->mm;
1267         struct vm_area_struct *vma, *prev;
1268         struct vma_iterator vmi;
1269         struct migration_mpol mmpol;
1270         struct mempolicy *new;
1271         unsigned long end;
1272         long err;
1273         long nr_failed;
1274         LIST_HEAD(pagelist);
1275
1276         if (flags & ~(unsigned long)MPOL_MF_VALID)
1277                 return -EINVAL;
1278         if ((flags & MPOL_MF_MOVE_ALL) && !capable(CAP_SYS_NICE))
1279                 return -EPERM;
1280
1281         if (start & ~PAGE_MASK)
1282                 return -EINVAL;
1283
1284         if (mode == MPOL_DEFAULT)
1285                 flags &= ~MPOL_MF_STRICT;
1286
1287         len = PAGE_ALIGN(len);
1288         end = start + len;
1289
1290         if (end < start)
1291                 return -EINVAL;
1292         if (end == start)
1293                 return 0;
1294
1295         new = mpol_new(mode, mode_flags, nmask);
1296         if (IS_ERR(new))
1297                 return PTR_ERR(new);
1298
1299         /*
1300          * If we are using the default policy then operation
1301          * on discontinuous address spaces is okay after all
1302          */
1303         if (!new)
1304                 flags |= MPOL_MF_DISCONTIG_OK;
1305
1306         if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
1307                 lru_cache_disable();
1308         {
1309                 NODEMASK_SCRATCH(scratch);
1310                 if (scratch) {
1311                         mmap_write_lock(mm);
1312                         err = mpol_set_nodemask(new, nmask, scratch);
1313                         if (err)
1314                                 mmap_write_unlock(mm);
1315                 } else
1316                         err = -ENOMEM;
1317                 NODEMASK_SCRATCH_FREE(scratch);
1318         }
1319         if (err)
1320                 goto mpol_out;
1321
1322         /*
1323          * Lock the VMAs before scanning for pages to migrate,
1324          * to ensure we don't miss a concurrently inserted page.
1325          */
1326         nr_failed = queue_pages_range(mm, start, end, nmask,
1327                         flags | MPOL_MF_INVERT | MPOL_MF_WRLOCK, &pagelist);
1328
1329         if (nr_failed < 0) {
1330                 err = nr_failed;
1331                 nr_failed = 0;
1332         } else {
1333                 vma_iter_init(&vmi, mm, start);
1334                 prev = vma_prev(&vmi);
1335                 for_each_vma_range(vmi, vma, end) {
1336                         err = mbind_range(&vmi, vma, &prev, start, end, new);
1337                         if (err)
1338                                 break;
1339                 }
1340         }
1341
1342         if (!err && !list_empty(&pagelist)) {
1343                 /* Convert MPOL_DEFAULT's NULL to task or default policy */
1344                 if (!new) {
1345                         new = get_task_policy(current);
1346                         mpol_get(new);
1347                 }
1348                 mmpol.pol = new;
1349                 mmpol.ilx = 0;
1350
1351                 /*
1352                  * In the interleaved case, attempt to allocate on exactly the
1353                  * targeted nodes, for the first VMA to be migrated; for later
1354                  * VMAs, the nodes will still be interleaved from the targeted
1355                  * nodemask, but one by one may be selected differently.
1356                  */
1357                 if (new->mode == MPOL_INTERLEAVE ||
1358                     new->mode == MPOL_WEIGHTED_INTERLEAVE) {
1359                         struct page *page;
1360                         unsigned int order;
1361                         unsigned long addr = -EFAULT;
1362
1363                         list_for_each_entry(page, &pagelist, lru) {
1364                                 if (!PageKsm(page))
1365                                         break;
1366                         }
1367                         if (!list_entry_is_head(page, &pagelist, lru)) {
1368                                 vma_iter_init(&vmi, mm, start);
1369                                 for_each_vma_range(vmi, vma, end) {
1370                                         addr = page_address_in_vma(page, vma);
1371                                         if (addr != -EFAULT)
1372                                                 break;
1373                                 }
1374                         }
1375                         if (addr != -EFAULT) {
1376                                 order = compound_order(page);
1377                                 /* We already know the pol, but not the ilx */
1378                                 mpol_cond_put(get_vma_policy(vma, addr, order,
1379                                                              &mmpol.ilx));
1380                                 /* Set base from which to increment by index */
1381                                 mmpol.ilx -= page->index >> order;
1382                         }
1383                 }
1384         }
1385
1386         mmap_write_unlock(mm);
1387
1388         if (!err && !list_empty(&pagelist)) {
1389                 nr_failed |= migrate_pages(&pagelist,
1390                                 alloc_migration_target_by_mpol, NULL,
1391                                 (unsigned long)&mmpol, MIGRATE_SYNC,
1392                                 MR_MEMPOLICY_MBIND, NULL);
1393         }
1394
1395         if (nr_failed && (flags & MPOL_MF_STRICT))
1396                 err = -EIO;
1397         if (!list_empty(&pagelist))
1398                 putback_movable_pages(&pagelist);
1399 mpol_out:
1400         mpol_put(new);
1401         if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
1402                 lru_cache_enable();
1403         return err;
1404 }
1405
1406 /*
1407  * User space interface with variable sized bitmaps for nodelists.
1408  */
1409 static int get_bitmap(unsigned long *mask, const unsigned long __user *nmask,
1410                       unsigned long maxnode)
1411 {
1412         unsigned long nlongs = BITS_TO_LONGS(maxnode);
1413         int ret;
1414
1415         if (in_compat_syscall())
1416                 ret = compat_get_bitmap(mask,
1417                                         (const compat_ulong_t __user *)nmask,
1418                                         maxnode);
1419         else
1420                 ret = copy_from_user(mask, nmask,
1421                                      nlongs * sizeof(unsigned long));
1422
1423         if (ret)
1424                 return -EFAULT;
1425
1426         if (maxnode % BITS_PER_LONG)
1427                 mask[nlongs - 1] &= (1UL << (maxnode % BITS_PER_LONG)) - 1;
1428
1429         return 0;
1430 }
1431
1432 /* Copy a node mask from user space. */
1433 static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
1434                      unsigned long maxnode)
1435 {
1436         --maxnode;
1437         nodes_clear(*nodes);
1438         if (maxnode == 0 || !nmask)
1439                 return 0;
1440         if (maxnode > PAGE_SIZE*BITS_PER_BYTE)
1441                 return -EINVAL;
1442
1443         /*
1444          * When the user specified more nodes than supported just check
1445          * if the non supported part is all zero, one word at a time,
1446          * starting at the end.
1447          */
1448         while (maxnode > MAX_NUMNODES) {
1449                 unsigned long bits = min_t(unsigned long, maxnode, BITS_PER_LONG);
1450                 unsigned long t;
1451
1452                 if (get_bitmap(&t, &nmask[(maxnode - 1) / BITS_PER_LONG], bits))
1453                         return -EFAULT;
1454
1455                 if (maxnode - bits >= MAX_NUMNODES) {
1456                         maxnode -= bits;
1457                 } else {
1458                         maxnode = MAX_NUMNODES;
1459                         t &= ~((1UL << (MAX_NUMNODES % BITS_PER_LONG)) - 1);
1460                 }
1461                 if (t)
1462                         return -EINVAL;
1463         }
1464
1465         return get_bitmap(nodes_addr(*nodes), nmask, maxnode);
1466 }
1467
1468 /* Copy a kernel node mask to user space */
1469 static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
1470                               nodemask_t *nodes)
1471 {
1472         unsigned long copy = ALIGN(maxnode-1, 64) / 8;
1473         unsigned int nbytes = BITS_TO_LONGS(nr_node_ids) * sizeof(long);
1474         bool compat = in_compat_syscall();
1475
1476         if (compat)
1477                 nbytes = BITS_TO_COMPAT_LONGS(nr_node_ids) * sizeof(compat_long_t);
1478
1479         if (copy > nbytes) {
1480                 if (copy > PAGE_SIZE)
1481                         return -EINVAL;
1482                 if (clear_user((char __user *)mask + nbytes, copy - nbytes))
1483                         return -EFAULT;
1484                 copy = nbytes;
1485                 maxnode = nr_node_ids;
1486         }
1487
1488         if (compat)
1489                 return compat_put_bitmap((compat_ulong_t __user *)mask,
1490                                          nodes_addr(*nodes), maxnode);
1491
1492         return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
1493 }
1494
1495 /* Basic parameter sanity check used by both mbind() and set_mempolicy() */
1496 static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
1497 {
1498         *flags = *mode & MPOL_MODE_FLAGS;
1499         *mode &= ~MPOL_MODE_FLAGS;
1500
1501         if ((unsigned int)(*mode) >=  MPOL_MAX)
1502                 return -EINVAL;
1503         if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
1504                 return -EINVAL;
1505         if (*flags & MPOL_F_NUMA_BALANCING) {
1506                 if (*mode != MPOL_BIND)
1507                         return -EINVAL;
1508                 *flags |= (MPOL_F_MOF | MPOL_F_MORON);
1509         }
1510         return 0;
1511 }
1512
1513 static long kernel_mbind(unsigned long start, unsigned long len,
1514                          unsigned long mode, const unsigned long __user *nmask,
1515                          unsigned long maxnode, unsigned int flags)
1516 {
1517         unsigned short mode_flags;
1518         nodemask_t nodes;
1519         int lmode = mode;
1520         int err;
1521
1522         start = untagged_addr(start);
1523         err = sanitize_mpol_flags(&lmode, &mode_flags);
1524         if (err)
1525                 return err;
1526
1527         err = get_nodes(&nodes, nmask, maxnode);
1528         if (err)
1529                 return err;
1530
1531         return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
1532 }
1533
1534 SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, len,
1535                 unsigned long, home_node, unsigned long, flags)
1536 {
1537         struct mm_struct *mm = current->mm;
1538         struct vm_area_struct *vma, *prev;
1539         struct mempolicy *new, *old;
1540         unsigned long end;
1541         int err = -ENOENT;
1542         VMA_ITERATOR(vmi, mm, start);
1543
1544         start = untagged_addr(start);
1545         if (start & ~PAGE_MASK)
1546                 return -EINVAL;
1547         /*
1548          * flags is used for future extension if any.
1549          */
1550         if (flags != 0)
1551                 return -EINVAL;
1552
1553         /*
1554          * Check home_node is online to avoid accessing uninitialized
1555          * NODE_DATA.
1556          */
1557         if (home_node >= MAX_NUMNODES || !node_online(home_node))
1558                 return -EINVAL;
1559
1560         len = PAGE_ALIGN(len);
1561         end = start + len;
1562
1563         if (end < start)
1564                 return -EINVAL;
1565         if (end == start)
1566                 return 0;
1567         mmap_write_lock(mm);
1568         prev = vma_prev(&vmi);
1569         for_each_vma_range(vmi, vma, end) {
1570                 /*
1571                  * If any vma in the range got policy other than MPOL_BIND
1572                  * or MPOL_PREFERRED_MANY we return error. We don't reset
1573                  * the home node for vmas we already updated before.
1574                  */
1575                 old = vma_policy(vma);
1576                 if (!old) {
1577                         prev = vma;
1578                         continue;
1579                 }
1580                 if (old->mode != MPOL_BIND && old->mode != MPOL_PREFERRED_MANY) {
1581                         err = -EOPNOTSUPP;
1582                         break;
1583                 }
1584                 new = mpol_dup(old);
1585                 if (IS_ERR(new)) {
1586                         err = PTR_ERR(new);
1587                         break;
1588                 }
1589
1590                 vma_start_write(vma);
1591                 new->home_node = home_node;
1592                 err = mbind_range(&vmi, vma, &prev, start, end, new);
1593                 mpol_put(new);
1594                 if (err)
1595                         break;
1596         }
1597         mmap_write_unlock(mm);
1598         return err;
1599 }
1600
1601 SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
1602                 unsigned long, mode, const unsigned long __user *, nmask,
1603                 unsigned long, maxnode, unsigned int, flags)
1604 {
1605         return kernel_mbind(start, len, mode, nmask, maxnode, flags);
1606 }
1607
1608 /* Set the process memory policy */
1609 static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
1610                                  unsigned long maxnode)
1611 {
1612         unsigned short mode_flags;
1613         nodemask_t nodes;
1614         int lmode = mode;
1615         int err;
1616
1617         err = sanitize_mpol_flags(&lmode, &mode_flags);
1618         if (err)
1619                 return err;
1620
1621         err = get_nodes(&nodes, nmask, maxnode);
1622         if (err)
1623                 return err;
1624
1625         return do_set_mempolicy(lmode, mode_flags, &nodes);
1626 }
1627
1628 SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
1629                 unsigned long, maxnode)
1630 {
1631         return kernel_set_mempolicy(mode, nmask, maxnode);
1632 }
1633
1634 static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
1635                                 const unsigned long __user *old_nodes,
1636                                 const unsigned long __user *new_nodes)
1637 {
1638         struct mm_struct *mm = NULL;
1639         struct task_struct *task;
1640         nodemask_t task_nodes;
1641         int err;
1642         nodemask_t *old;
1643         nodemask_t *new;
1644         NODEMASK_SCRATCH(scratch);
1645
1646         if (!scratch)
1647                 return -ENOMEM;
1648
1649         old = &scratch->mask1;
1650         new = &scratch->mask2;
1651
1652         err = get_nodes(old, old_nodes, maxnode);
1653         if (err)
1654                 goto out;
1655
1656         err = get_nodes(new, new_nodes, maxnode);
1657         if (err)
1658                 goto out;
1659
1660         /* Find the mm_struct */
1661         rcu_read_lock();
1662         task = pid ? find_task_by_vpid(pid) : current;
1663         if (!task) {
1664                 rcu_read_unlock();
1665                 err = -ESRCH;
1666                 goto out;
1667         }
1668         get_task_struct(task);
1669
1670         err = -EINVAL;
1671
1672         /*
1673          * Check if this process has the right to modify the specified process.
1674          * Use the regular "ptrace_may_access()" checks.
1675          */
1676         if (!ptrace_may_access(task, PTRACE_MODE_READ_REALCREDS)) {
1677                 rcu_read_unlock();
1678                 err = -EPERM;
1679                 goto out_put;
1680         }
1681         rcu_read_unlock();
1682
1683         task_nodes = cpuset_mems_allowed(task);
1684         /* Is the user allowed to access the target nodes? */
1685         if (!nodes_subset(*new, task_nodes) && !capable(CAP_SYS_NICE)) {
1686                 err = -EPERM;
1687                 goto out_put;
1688         }
1689
1690         task_nodes = cpuset_mems_allowed(current);
1691         nodes_and(*new, *new, task_nodes);
1692         if (nodes_empty(*new))
1693                 goto out_put;
1694
1695         err = security_task_movememory(task);
1696         if (err)
1697                 goto out_put;
1698
1699         mm = get_task_mm(task);
1700         put_task_struct(task);
1701
1702         if (!mm) {
1703                 err = -EINVAL;
1704                 goto out;
1705         }
1706
1707         err = do_migrate_pages(mm, old, new,
1708                 capable(CAP_SYS_NICE) ? MPOL_MF_MOVE_ALL : MPOL_MF_MOVE);
1709
1710         mmput(mm);
1711 out:
1712         NODEMASK_SCRATCH_FREE(scratch);
1713
1714         return err;
1715
1716 out_put:
1717         put_task_struct(task);
1718         goto out;
1719 }
1720
1721 SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
1722                 const unsigned long __user *, old_nodes,
1723                 const unsigned long __user *, new_nodes)
1724 {
1725         return kernel_migrate_pages(pid, maxnode, old_nodes, new_nodes);
1726 }
1727
1728 /* Retrieve NUMA policy */
1729 static int kernel_get_mempolicy(int __user *policy,
1730                                 unsigned long __user *nmask,
1731                                 unsigned long maxnode,
1732                                 unsigned long addr,
1733                                 unsigned long flags)
1734 {
1735         int err;
1736         int pval;
1737         nodemask_t nodes;
1738
1739         if (nmask != NULL && maxnode < nr_node_ids)
1740                 return -EINVAL;
1741
1742         addr = untagged_addr(addr);
1743
1744         err = do_get_mempolicy(&pval, &nodes, addr, flags);
1745
1746         if (err)
1747                 return err;
1748
1749         if (policy && put_user(pval, policy))
1750                 return -EFAULT;
1751
1752         if (nmask)
1753                 err = copy_nodes_to_user(nmask, maxnode, &nodes);
1754
1755         return err;
1756 }
1757
1758 SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
1759                 unsigned long __user *, nmask, unsigned long, maxnode,
1760                 unsigned long, addr, unsigned long, flags)
1761 {
1762         return kernel_get_mempolicy(policy, nmask, maxnode, addr, flags);
1763 }
1764
1765 bool vma_migratable(struct vm_area_struct *vma)
1766 {
1767         if (vma->vm_flags & (VM_IO | VM_PFNMAP))
1768                 return false;
1769
1770         /*
1771          * DAX device mappings require predictable access latency, so avoid
1772          * incurring periodic faults.
1773          */
1774         if (vma_is_dax(vma))
1775                 return false;
1776
1777         if (is_vm_hugetlb_page(vma) &&
1778                 !hugepage_migration_supported(hstate_vma(vma)))
1779                 return false;
1780
1781         /*
1782          * Migration allocates pages in the highest zone. If we cannot
1783          * do so then migration (at least from node to node) is not
1784          * possible.
1785          */
1786         if (vma->vm_file &&
1787                 gfp_zone(mapping_gfp_mask(vma->vm_file->f_mapping))
1788                         < policy_zone)
1789                 return false;
1790         return true;
1791 }
1792
1793 struct mempolicy *__get_vma_policy(struct vm_area_struct *vma,
1794                                    unsigned long addr, pgoff_t *ilx)
1795 {
1796         *ilx = 0;
1797         return (vma->vm_ops && vma->vm_ops->get_policy) ?
1798                 vma->vm_ops->get_policy(vma, addr, ilx) : vma->vm_policy;
1799 }
1800
1801 /*
1802  * get_vma_policy(@vma, @addr, @order, @ilx)
1803  * @vma: virtual memory area whose policy is sought
1804  * @addr: address in @vma for shared policy lookup
1805  * @order: 0, or appropriate huge_page_order for interleaving
1806  * @ilx: interleave index (output), for use only when MPOL_INTERLEAVE or
1807  *       MPOL_WEIGHTED_INTERLEAVE
1808  *
1809  * Returns effective policy for a VMA at specified address.
1810  * Falls back to current->mempolicy or system default policy, as necessary.
1811  * Shared policies [those marked as MPOL_F_SHARED] require an extra reference
1812  * count--added by the get_policy() vm_op, as appropriate--to protect against
1813  * freeing by another task.  It is the caller's responsibility to free the
1814  * extra reference for shared policies.
1815  */
1816 struct mempolicy *get_vma_policy(struct vm_area_struct *vma,
1817                                  unsigned long addr, int order, pgoff_t *ilx)
1818 {
1819         struct mempolicy *pol;
1820
1821         pol = __get_vma_policy(vma, addr, ilx);
1822         if (!pol)
1823                 pol = get_task_policy(current);
1824         if (pol->mode == MPOL_INTERLEAVE ||
1825             pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
1826                 *ilx += vma->vm_pgoff >> order;
1827                 *ilx += (addr - vma->vm_start) >> (PAGE_SHIFT + order);
1828         }
1829         return pol;
1830 }
1831
1832 bool vma_policy_mof(struct vm_area_struct *vma)
1833 {
1834         struct mempolicy *pol;
1835
1836         if (vma->vm_ops && vma->vm_ops->get_policy) {
1837                 bool ret = false;
1838                 pgoff_t ilx;            /* ignored here */
1839
1840                 pol = vma->vm_ops->get_policy(vma, vma->vm_start, &ilx);
1841                 if (pol && (pol->flags & MPOL_F_MOF))
1842                         ret = true;
1843                 mpol_cond_put(pol);
1844
1845                 return ret;
1846         }
1847
1848         pol = vma->vm_policy;
1849         if (!pol)
1850                 pol = get_task_policy(current);
1851
1852         return pol->flags & MPOL_F_MOF;
1853 }
1854
1855 bool apply_policy_zone(struct mempolicy *policy, enum zone_type zone)
1856 {
1857         enum zone_type dynamic_policy_zone = policy_zone;
1858
1859         BUG_ON(dynamic_policy_zone == ZONE_MOVABLE);
1860
1861         /*
1862          * if policy->nodes has movable memory only,
1863          * we apply policy when gfp_zone(gfp) = ZONE_MOVABLE only.
1864          *
1865          * policy->nodes is intersect with node_states[N_MEMORY].
1866          * so if the following test fails, it implies
1867          * policy->nodes has movable memory only.
1868          */
1869         if (!nodes_intersects(policy->nodes, node_states[N_HIGH_MEMORY]))
1870                 dynamic_policy_zone = ZONE_MOVABLE;
1871
1872         return zone >= dynamic_policy_zone;
1873 }
1874
1875 static unsigned int weighted_interleave_nodes(struct mempolicy *policy)
1876 {
1877         unsigned int node;
1878         unsigned int cpuset_mems_cookie;
1879
1880 retry:
1881         /* to prevent miscount use tsk->mems_allowed_seq to detect rebind */
1882         cpuset_mems_cookie = read_mems_allowed_begin();
1883         node = current->il_prev;
1884         if (!current->il_weight || !node_isset(node, policy->nodes)) {
1885                 node = next_node_in(node, policy->nodes);
1886                 if (read_mems_allowed_retry(cpuset_mems_cookie))
1887                         goto retry;
1888                 if (node == MAX_NUMNODES)
1889                         return node;
1890                 current->il_prev = node;
1891                 current->il_weight = get_il_weight(node);
1892         }
1893         current->il_weight--;
1894         return node;
1895 }
1896
1897 /* Do dynamic interleaving for a process */
1898 static unsigned int interleave_nodes(struct mempolicy *policy)
1899 {
1900         unsigned int nid;
1901         unsigned int cpuset_mems_cookie;
1902
1903         /* to prevent miscount, use tsk->mems_allowed_seq to detect rebind */
1904         do {
1905                 cpuset_mems_cookie = read_mems_allowed_begin();
1906                 nid = next_node_in(current->il_prev, policy->nodes);
1907         } while (read_mems_allowed_retry(cpuset_mems_cookie));
1908
1909         if (nid < MAX_NUMNODES)
1910                 current->il_prev = nid;
1911         return nid;
1912 }
1913
1914 /*
1915  * Depending on the memory policy provide a node from which to allocate the
1916  * next slab entry.
1917  */
1918 unsigned int mempolicy_slab_node(void)
1919 {
1920         struct mempolicy *policy;
1921         int node = numa_mem_id();
1922
1923         if (!in_task())
1924                 return node;
1925
1926         policy = current->mempolicy;
1927         if (!policy)
1928                 return node;
1929
1930         switch (policy->mode) {
1931         case MPOL_PREFERRED:
1932                 return first_node(policy->nodes);
1933
1934         case MPOL_INTERLEAVE:
1935                 return interleave_nodes(policy);
1936
1937         case MPOL_WEIGHTED_INTERLEAVE:
1938                 return weighted_interleave_nodes(policy);
1939
1940         case MPOL_BIND:
1941         case MPOL_PREFERRED_MANY:
1942         {
1943                 struct zoneref *z;
1944
1945                 /*
1946                  * Follow bind policy behavior and start allocation at the
1947                  * first node.
1948                  */
1949                 struct zonelist *zonelist;
1950                 enum zone_type highest_zoneidx = gfp_zone(GFP_KERNEL);
1951                 zonelist = &NODE_DATA(node)->node_zonelists[ZONELIST_FALLBACK];
1952                 z = first_zones_zonelist(zonelist, highest_zoneidx,
1953                                                         &policy->nodes);
1954                 return z->zone ? zone_to_nid(z->zone) : node;
1955         }
1956         case MPOL_LOCAL:
1957                 return node;
1958
1959         default:
1960                 BUG();
1961         }
1962 }
1963
1964 static unsigned int read_once_policy_nodemask(struct mempolicy *pol,
1965                                               nodemask_t *mask)
1966 {
1967         /*
1968          * barrier stabilizes the nodemask locally so that it can be iterated
1969          * over safely without concern for changes. Allocators validate node
1970          * selection does not violate mems_allowed, so this is safe.
1971          */
1972         barrier();
1973         memcpy(mask, &pol->nodes, sizeof(nodemask_t));
1974         barrier();
1975         return nodes_weight(*mask);
1976 }
1977
1978 static unsigned int weighted_interleave_nid(struct mempolicy *pol, pgoff_t ilx)
1979 {
1980         nodemask_t nodemask;
1981         unsigned int target, nr_nodes;
1982         u8 *table;
1983         unsigned int weight_total = 0;
1984         u8 weight;
1985         int nid;
1986
1987         nr_nodes = read_once_policy_nodemask(pol, &nodemask);
1988         if (!nr_nodes)
1989                 return numa_node_id();
1990
1991         rcu_read_lock();
1992         table = rcu_dereference(iw_table);
1993         /* calculate the total weight */
1994         for_each_node_mask(nid, nodemask) {
1995                 /* detect system default usage */
1996                 weight = table ? table[nid] : 1;
1997                 weight = weight ? weight : 1;
1998                 weight_total += weight;
1999         }
2000
2001         /* Calculate the node offset based on totals */
2002         target = ilx % weight_total;
2003         nid = first_node(nodemask);
2004         while (target) {
2005                 /* detect system default usage */
2006                 weight = table ? table[nid] : 1;
2007                 weight = weight ? weight : 1;
2008                 if (target < weight)
2009                         break;
2010                 target -= weight;
2011                 nid = next_node_in(nid, nodemask);
2012         }
2013         rcu_read_unlock();
2014         return nid;
2015 }
2016
2017 /*
2018  * Do static interleaving for interleave index @ilx.  Returns the ilx'th
2019  * node in pol->nodes (starting from ilx=0), wrapping around if ilx
2020  * exceeds the number of present nodes.
2021  */
2022 static unsigned int interleave_nid(struct mempolicy *pol, pgoff_t ilx)
2023 {
2024         nodemask_t nodemask;
2025         unsigned int target, nnodes;
2026         int i;
2027         int nid;
2028
2029         nnodes = read_once_policy_nodemask(pol, &nodemask);
2030         if (!nnodes)
2031                 return numa_node_id();
2032         target = ilx % nnodes;
2033         nid = first_node(nodemask);
2034         for (i = 0; i < target; i++)
2035                 nid = next_node(nid, nodemask);
2036         return nid;
2037 }
2038
2039 /*
2040  * Return a nodemask representing a mempolicy for filtering nodes for
2041  * page allocation, together with preferred node id (or the input node id).
2042  */
2043 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
2044                                    pgoff_t ilx, int *nid)
2045 {
2046         nodemask_t *nodemask = NULL;
2047
2048         switch (pol->mode) {
2049         case MPOL_PREFERRED:
2050                 /* Override input node id */
2051                 *nid = first_node(pol->nodes);
2052                 break;
2053         case MPOL_PREFERRED_MANY:
2054                 nodemask = &pol->nodes;
2055                 if (pol->home_node != NUMA_NO_NODE)
2056                         *nid = pol->home_node;
2057                 break;
2058         case MPOL_BIND:
2059                 /* Restrict to nodemask (but not on lower zones) */
2060                 if (apply_policy_zone(pol, gfp_zone(gfp)) &&
2061                     cpuset_nodemask_valid_mems_allowed(&pol->nodes))
2062                         nodemask = &pol->nodes;
2063                 if (pol->home_node != NUMA_NO_NODE)
2064                         *nid = pol->home_node;
2065                 /*
2066                  * __GFP_THISNODE shouldn't even be used with the bind policy
2067                  * because we might easily break the expectation to stay on the
2068                  * requested node and not break the policy.
2069                  */
2070                 WARN_ON_ONCE(gfp & __GFP_THISNODE);
2071                 break;
2072         case MPOL_INTERLEAVE:
2073                 /* Override input node id */
2074                 *nid = (ilx == NO_INTERLEAVE_INDEX) ?
2075                         interleave_nodes(pol) : interleave_nid(pol, ilx);
2076                 break;
2077         case MPOL_WEIGHTED_INTERLEAVE:
2078                 *nid = (ilx == NO_INTERLEAVE_INDEX) ?
2079                         weighted_interleave_nodes(pol) :
2080                         weighted_interleave_nid(pol, ilx);
2081                 break;
2082         }
2083
2084         return nodemask;
2085 }
2086
2087 #ifdef CONFIG_HUGETLBFS
2088 /*
2089  * huge_node(@vma, @addr, @gfp_flags, @mpol)
2090  * @vma: virtual memory area whose policy is sought
2091  * @addr: address in @vma for shared policy lookup and interleave policy
2092  * @gfp_flags: for requested zone
2093  * @mpol: pointer to mempolicy pointer for reference counted mempolicy
2094  * @nodemask: pointer to nodemask pointer for 'bind' and 'prefer-many' policy
2095  *
2096  * Returns a nid suitable for a huge page allocation and a pointer
2097  * to the struct mempolicy for conditional unref after allocation.
2098  * If the effective policy is 'bind' or 'prefer-many', returns a pointer
2099  * to the mempolicy's @nodemask for filtering the zonelist.
2100  */
2101 int huge_node(struct vm_area_struct *vma, unsigned long addr, gfp_t gfp_flags,
2102                 struct mempolicy **mpol, nodemask_t **nodemask)
2103 {
2104         pgoff_t ilx;
2105         int nid;
2106
2107         nid = numa_node_id();
2108         *mpol = get_vma_policy(vma, addr, hstate_vma(vma)->order, &ilx);
2109         *nodemask = policy_nodemask(gfp_flags, *mpol, ilx, &nid);
2110         return nid;
2111 }
2112
2113 /*
2114  * init_nodemask_of_mempolicy
2115  *
2116  * If the current task's mempolicy is "default" [NULL], return 'false'
2117  * to indicate default policy.  Otherwise, extract the policy nodemask
2118  * for 'bind' or 'interleave' policy into the argument nodemask, or
2119  * initialize the argument nodemask to contain the single node for
2120  * 'preferred' or 'local' policy and return 'true' to indicate presence
2121  * of non-default mempolicy.
2122  *
2123  * We don't bother with reference counting the mempolicy [mpol_get/put]
2124  * because the current task is examining it's own mempolicy and a task's
2125  * mempolicy is only ever changed by the task itself.
2126  *
2127  * N.B., it is the caller's responsibility to free a returned nodemask.
2128  */
2129 bool init_nodemask_of_mempolicy(nodemask_t *mask)
2130 {
2131         struct mempolicy *mempolicy;
2132
2133         if (!(mask && current->mempolicy))
2134                 return false;
2135
2136         task_lock(current);
2137         mempolicy = current->mempolicy;
2138         switch (mempolicy->mode) {
2139         case MPOL_PREFERRED:
2140         case MPOL_PREFERRED_MANY:
2141         case MPOL_BIND:
2142         case MPOL_INTERLEAVE:
2143         case MPOL_WEIGHTED_INTERLEAVE:
2144                 *mask = mempolicy->nodes;
2145                 break;
2146
2147         case MPOL_LOCAL:
2148                 init_nodemask_of_node(mask, numa_node_id());
2149                 break;
2150
2151         default:
2152                 BUG();
2153         }
2154         task_unlock(current);
2155
2156         return true;
2157 }
2158 #endif
2159
2160 /*
2161  * mempolicy_in_oom_domain
2162  *
2163  * If tsk's mempolicy is "bind", check for intersection between mask and
2164  * the policy nodemask. Otherwise, return true for all other policies
2165  * including "interleave", as a tsk with "interleave" policy may have
2166  * memory allocated from all nodes in system.
2167  *
2168  * Takes task_lock(tsk) to prevent freeing of its mempolicy.
2169  */
2170 bool mempolicy_in_oom_domain(struct task_struct *tsk,
2171                                         const nodemask_t *mask)
2172 {
2173         struct mempolicy *mempolicy;
2174         bool ret = true;
2175
2176         if (!mask)
2177                 return ret;
2178
2179         task_lock(tsk);
2180         mempolicy = tsk->mempolicy;
2181         if (mempolicy && mempolicy->mode == MPOL_BIND)
2182                 ret = nodes_intersects(mempolicy->nodes, *mask);
2183         task_unlock(tsk);
2184
2185         return ret;
2186 }
2187
2188 static struct page *alloc_pages_preferred_many(gfp_t gfp, unsigned int order,
2189                                                 int nid, nodemask_t *nodemask)
2190 {
2191         struct page *page;
2192         gfp_t preferred_gfp;
2193
2194         /*
2195          * This is a two pass approach. The first pass will only try the
2196          * preferred nodes but skip the direct reclaim and allow the
2197          * allocation to fail, while the second pass will try all the
2198          * nodes in system.
2199          */
2200         preferred_gfp = gfp | __GFP_NOWARN;
2201         preferred_gfp &= ~(__GFP_DIRECT_RECLAIM | __GFP_NOFAIL);
2202         page = __alloc_pages(preferred_gfp, order, nid, nodemask);
2203         if (!page)
2204                 page = __alloc_pages(gfp, order, nid, NULL);
2205
2206         return page;
2207 }
2208
2209 /**
2210  * alloc_pages_mpol - Allocate pages according to NUMA mempolicy.
2211  * @gfp: GFP flags.
2212  * @order: Order of the page allocation.
2213  * @pol: Pointer to the NUMA mempolicy.
2214  * @ilx: Index for interleave mempolicy (also distinguishes alloc_pages()).
2215  * @nid: Preferred node (usually numa_node_id() but @mpol may override it).
2216  *
2217  * Return: The page on success or NULL if allocation fails.
2218  */
2219 struct page *alloc_pages_mpol(gfp_t gfp, unsigned int order,
2220                 struct mempolicy *pol, pgoff_t ilx, int nid)
2221 {
2222         nodemask_t *nodemask;
2223         struct page *page;
2224
2225         nodemask = policy_nodemask(gfp, pol, ilx, &nid);
2226
2227         if (pol->mode == MPOL_PREFERRED_MANY)
2228                 return alloc_pages_preferred_many(gfp, order, nid, nodemask);
2229
2230         if (IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) &&
2231             /* filter "hugepage" allocation, unless from alloc_pages() */
2232             order == HPAGE_PMD_ORDER && ilx != NO_INTERLEAVE_INDEX) {
2233                 /*
2234                  * For hugepage allocation and non-interleave policy which
2235                  * allows the current node (or other explicitly preferred
2236                  * node) we only try to allocate from the current/preferred
2237                  * node and don't fall back to other nodes, as the cost of
2238                  * remote accesses would likely offset THP benefits.
2239                  *
2240                  * If the policy is interleave or does not allow the current
2241                  * node in its nodemask, we allocate the standard way.
2242                  */
2243                 if (pol->mode != MPOL_INTERLEAVE &&
2244                     pol->mode != MPOL_WEIGHTED_INTERLEAVE &&
2245                     (!nodemask || node_isset(nid, *nodemask))) {
2246                         /*
2247                          * First, try to allocate THP only on local node, but
2248                          * don't reclaim unnecessarily, just compact.
2249                          */
2250                         page = __alloc_pages_node(nid,
2251                                 gfp | __GFP_THISNODE | __GFP_NORETRY, order);
2252                         if (page || !(gfp & __GFP_DIRECT_RECLAIM))
2253                                 return page;
2254                         /*
2255                          * If hugepage allocations are configured to always
2256                          * synchronous compact or the vma has been madvised
2257                          * to prefer hugepage backing, retry allowing remote
2258                          * memory with both reclaim and compact as well.
2259                          */
2260                 }
2261         }
2262
2263         page = __alloc_pages(gfp, order, nid, nodemask);
2264
2265         if (unlikely(pol->mode == MPOL_INTERLEAVE) && page) {
2266                 /* skip NUMA_INTERLEAVE_HIT update if numa stats is disabled */
2267                 if (static_branch_likely(&vm_numa_stat_key) &&
2268                     page_to_nid(page) == nid) {
2269                         preempt_disable();
2270                         __count_numa_event(page_zone(page), NUMA_INTERLEAVE_HIT);
2271                         preempt_enable();
2272                 }
2273         }
2274
2275         return page;
2276 }
2277
2278 /**
2279  * vma_alloc_folio - Allocate a folio for a VMA.
2280  * @gfp: GFP flags.
2281  * @order: Order of the folio.
2282  * @vma: Pointer to VMA.
2283  * @addr: Virtual address of the allocation.  Must be inside @vma.
2284  * @hugepage: Unused (was: For hugepages try only preferred node if possible).
2285  *
2286  * Allocate a folio for a specific address in @vma, using the appropriate
2287  * NUMA policy.  The caller must hold the mmap_lock of the mm_struct of the
2288  * VMA to prevent it from going away.  Should be used for all allocations
2289  * for folios that will be mapped into user space, excepting hugetlbfs, and
2290  * excepting where direct use of alloc_pages_mpol() is more appropriate.
2291  *
2292  * Return: The folio on success or NULL if allocation fails.
2293  */
2294 struct folio *vma_alloc_folio(gfp_t gfp, int order, struct vm_area_struct *vma,
2295                 unsigned long addr, bool hugepage)
2296 {
2297         struct mempolicy *pol;
2298         pgoff_t ilx;
2299         struct page *page;
2300
2301         pol = get_vma_policy(vma, addr, order, &ilx);
2302         page = alloc_pages_mpol(gfp | __GFP_COMP, order,
2303                                 pol, ilx, numa_node_id());
2304         mpol_cond_put(pol);
2305         return page_rmappable_folio(page);
2306 }
2307 EXPORT_SYMBOL(vma_alloc_folio);
2308
2309 /**
2310  * alloc_pages - Allocate pages.
2311  * @gfp: GFP flags.
2312  * @order: Power of two of number of pages to allocate.
2313  *
2314  * Allocate 1 << @order contiguous pages.  The physical address of the
2315  * first page is naturally aligned (eg an order-3 allocation will be aligned
2316  * to a multiple of 8 * PAGE_SIZE bytes).  The NUMA policy of the current
2317  * process is honoured when in process context.
2318  *
2319  * Context: Can be called from any context, providing the appropriate GFP
2320  * flags are used.
2321  * Return: The page on success or NULL if allocation fails.
2322  */
2323 struct page *alloc_pages(gfp_t gfp, unsigned int order)
2324 {
2325         struct mempolicy *pol = &default_policy;
2326
2327         /*
2328          * No reference counting needed for current->mempolicy
2329          * nor system default_policy
2330          */
2331         if (!in_interrupt() && !(gfp & __GFP_THISNODE))
2332                 pol = get_task_policy(current);
2333
2334         return alloc_pages_mpol(gfp, order,
2335                                 pol, NO_INTERLEAVE_INDEX, numa_node_id());
2336 }
2337 EXPORT_SYMBOL(alloc_pages);
2338
2339 struct folio *folio_alloc(gfp_t gfp, unsigned int order)
2340 {
2341         return page_rmappable_folio(alloc_pages(gfp | __GFP_COMP, order));
2342 }
2343 EXPORT_SYMBOL(folio_alloc);
2344
2345 static unsigned long alloc_pages_bulk_array_interleave(gfp_t gfp,
2346                 struct mempolicy *pol, unsigned long nr_pages,
2347                 struct page **page_array)
2348 {
2349         int nodes;
2350         unsigned long nr_pages_per_node;
2351         int delta;
2352         int i;
2353         unsigned long nr_allocated;
2354         unsigned long total_allocated = 0;
2355
2356         nodes = nodes_weight(pol->nodes);
2357         nr_pages_per_node = nr_pages / nodes;
2358         delta = nr_pages - nodes * nr_pages_per_node;
2359
2360         for (i = 0; i < nodes; i++) {
2361                 if (delta) {
2362                         nr_allocated = __alloc_pages_bulk(gfp,
2363                                         interleave_nodes(pol), NULL,
2364                                         nr_pages_per_node + 1, NULL,
2365                                         page_array);
2366                         delta--;
2367                 } else {
2368                         nr_allocated = __alloc_pages_bulk(gfp,
2369                                         interleave_nodes(pol), NULL,
2370                                         nr_pages_per_node, NULL, page_array);
2371                 }
2372
2373                 page_array += nr_allocated;
2374                 total_allocated += nr_allocated;
2375         }
2376
2377         return total_allocated;
2378 }
2379
2380 static unsigned long alloc_pages_bulk_array_weighted_interleave(gfp_t gfp,
2381                 struct mempolicy *pol, unsigned long nr_pages,
2382                 struct page **page_array)
2383 {
2384         struct task_struct *me = current;
2385         unsigned int cpuset_mems_cookie;
2386         unsigned long total_allocated = 0;
2387         unsigned long nr_allocated = 0;
2388         unsigned long rounds;
2389         unsigned long node_pages, delta;
2390         u8 *table, *weights, weight;
2391         unsigned int weight_total = 0;
2392         unsigned long rem_pages = nr_pages;
2393         nodemask_t nodes;
2394         int nnodes, node;
2395         int resume_node = MAX_NUMNODES - 1;
2396         u8 resume_weight = 0;
2397         int prev_node;
2398         int i;
2399
2400         if (!nr_pages)
2401                 return 0;
2402
2403         /* read the nodes onto the stack, retry if done during rebind */
2404         do {
2405                 cpuset_mems_cookie = read_mems_allowed_begin();
2406                 nnodes = read_once_policy_nodemask(pol, &nodes);
2407         } while (read_mems_allowed_retry(cpuset_mems_cookie));
2408
2409         /* if the nodemask has become invalid, we cannot do anything */
2410         if (!nnodes)
2411                 return 0;
2412
2413         /* Continue allocating from most recent node and adjust the nr_pages */
2414         node = me->il_prev;
2415         weight = me->il_weight;
2416         if (weight && node_isset(node, nodes)) {
2417                 node_pages = min(rem_pages, weight);
2418                 nr_allocated = __alloc_pages_bulk(gfp, node, NULL, node_pages,
2419                                                   NULL, page_array);
2420                 page_array += nr_allocated;
2421                 total_allocated += nr_allocated;
2422                 /* if that's all the pages, no need to interleave */
2423                 if (rem_pages <= weight) {
2424                         me->il_weight -= rem_pages;
2425                         return total_allocated;
2426                 }
2427                 /* Otherwise we adjust remaining pages, continue from there */
2428                 rem_pages -= weight;
2429         }
2430         /* clear active weight in case of an allocation failure */
2431         me->il_weight = 0;
2432         prev_node = node;
2433
2434         /* create a local copy of node weights to operate on outside rcu */
2435         weights = kzalloc(nr_node_ids, GFP_KERNEL);
2436         if (!weights)
2437                 return total_allocated;
2438
2439         rcu_read_lock();
2440         table = rcu_dereference(iw_table);
2441         if (table)
2442                 memcpy(weights, table, nr_node_ids);
2443         rcu_read_unlock();
2444
2445         /* calculate total, detect system default usage */
2446         for_each_node_mask(node, nodes) {
2447                 if (!weights[node])
2448                         weights[node] = 1;
2449                 weight_total += weights[node];
2450         }
2451
2452         /*
2453          * Calculate rounds/partial rounds to minimize __alloc_pages_bulk calls.
2454          * Track which node weighted interleave should resume from.
2455          *
2456          * if (rounds > 0) and (delta == 0), resume_node will always be
2457          * the node following prev_node and its weight.
2458          */
2459         rounds = rem_pages / weight_total;
2460         delta = rem_pages % weight_total;
2461         resume_node = next_node_in(prev_node, nodes);
2462         resume_weight = weights[resume_node];
2463         for (i = 0; i < nnodes; i++) {
2464                 node = next_node_in(prev_node, nodes);
2465                 weight = weights[node];
2466                 node_pages = weight * rounds;
2467                 /* If a delta exists, add this node's portion of the delta */
2468                 if (delta > weight) {
2469                         node_pages += weight;
2470                         delta -= weight;
2471                 } else if (delta) {
2472                         /* when delta is depleted, resume from that node */
2473                         node_pages += delta;
2474                         resume_node = node;
2475                         resume_weight = weight - delta;
2476                         delta = 0;
2477                 }
2478                 /* node_pages can be 0 if an allocation fails and rounds == 0 */
2479                 if (!node_pages)
2480                         break;
2481                 nr_allocated = __alloc_pages_bulk(gfp, node, NULL, node_pages,
2482                                                   NULL, page_array);
2483                 page_array += nr_allocated;
2484                 total_allocated += nr_allocated;
2485                 if (total_allocated == nr_pages)
2486                         break;
2487                 prev_node = node;
2488         }
2489         me->il_prev = resume_node;
2490         me->il_weight = resume_weight;
2491         kfree(weights);
2492         return total_allocated;
2493 }
2494
2495 static unsigned long alloc_pages_bulk_array_preferred_many(gfp_t gfp, int nid,
2496                 struct mempolicy *pol, unsigned long nr_pages,
2497                 struct page **page_array)
2498 {
2499         gfp_t preferred_gfp;
2500         unsigned long nr_allocated = 0;
2501
2502         preferred_gfp = gfp | __GFP_NOWARN;
2503         preferred_gfp &= ~(__GFP_DIRECT_RECLAIM | __GFP_NOFAIL);
2504
2505         nr_allocated  = __alloc_pages_bulk(preferred_gfp, nid, &pol->nodes,
2506                                            nr_pages, NULL, page_array);
2507
2508         if (nr_allocated < nr_pages)
2509                 nr_allocated += __alloc_pages_bulk(gfp, numa_node_id(), NULL,
2510                                 nr_pages - nr_allocated, NULL,
2511                                 page_array + nr_allocated);
2512         return nr_allocated;
2513 }
2514
2515 /* alloc pages bulk and mempolicy should be considered at the
2516  * same time in some situation such as vmalloc.
2517  *
2518  * It can accelerate memory allocation especially interleaving
2519  * allocate memory.
2520  */
2521 unsigned long alloc_pages_bulk_array_mempolicy(gfp_t gfp,
2522                 unsigned long nr_pages, struct page **page_array)
2523 {
2524         struct mempolicy *pol = &default_policy;
2525         nodemask_t *nodemask;
2526         int nid;
2527
2528         if (!in_interrupt() && !(gfp & __GFP_THISNODE))
2529                 pol = get_task_policy(current);
2530
2531         if (pol->mode == MPOL_INTERLEAVE)
2532                 return alloc_pages_bulk_array_interleave(gfp, pol,
2533                                                          nr_pages, page_array);
2534
2535         if (pol->mode == MPOL_WEIGHTED_INTERLEAVE)
2536                 return alloc_pages_bulk_array_weighted_interleave(
2537                                   gfp, pol, nr_pages, page_array);
2538
2539         if (pol->mode == MPOL_PREFERRED_MANY)
2540                 return alloc_pages_bulk_array_preferred_many(gfp,
2541                                 numa_node_id(), pol, nr_pages, page_array);
2542
2543         nid = numa_node_id();
2544         nodemask = policy_nodemask(gfp, pol, NO_INTERLEAVE_INDEX, &nid);
2545         return __alloc_pages_bulk(gfp, nid, nodemask,
2546                                   nr_pages, NULL, page_array);
2547 }
2548
2549 int vma_dup_policy(struct vm_area_struct *src, struct vm_area_struct *dst)
2550 {
2551         struct mempolicy *pol = mpol_dup(src->vm_policy);
2552
2553         if (IS_ERR(pol))
2554                 return PTR_ERR(pol);
2555         dst->vm_policy = pol;
2556         return 0;
2557 }
2558
2559 /*
2560  * If mpol_dup() sees current->cpuset == cpuset_being_rebound, then it
2561  * rebinds the mempolicy its copying by calling mpol_rebind_policy()
2562  * with the mems_allowed returned by cpuset_mems_allowed().  This
2563  * keeps mempolicies cpuset relative after its cpuset moves.  See
2564  * further kernel/cpuset.c update_nodemask().
2565  *
2566  * current's mempolicy may be rebinded by the other task(the task that changes
2567  * cpuset's mems), so we needn't do rebind work for current task.
2568  */
2569
2570 /* Slow path of a mempolicy duplicate */
2571 struct mempolicy *__mpol_dup(struct mempolicy *old)
2572 {
2573         struct mempolicy *new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
2574
2575         if (!new)
2576                 return ERR_PTR(-ENOMEM);
2577
2578         /* task's mempolicy is protected by alloc_lock */
2579         if (old == current->mempolicy) {
2580                 task_lock(current);
2581                 *new = *old;
2582                 task_unlock(current);
2583         } else
2584                 *new = *old;
2585
2586         if (current_cpuset_is_being_rebound()) {
2587                 nodemask_t mems = cpuset_mems_allowed(current);
2588                 mpol_rebind_policy(new, &mems);
2589         }
2590         atomic_set(&new->refcnt, 1);
2591         return new;
2592 }
2593
2594 /* Slow path of a mempolicy comparison */
2595 bool __mpol_equal(struct mempolicy *a, struct mempolicy *b)
2596 {
2597         if (!a || !b)
2598                 return false;
2599         if (a->mode != b->mode)
2600                 return false;
2601         if (a->flags != b->flags)
2602                 return false;
2603         if (a->home_node != b->home_node)
2604                 return false;
2605         if (mpol_store_user_nodemask(a))
2606                 if (!nodes_equal(a->w.user_nodemask, b->w.user_nodemask))
2607                         return false;
2608
2609         switch (a->mode) {
2610         case MPOL_BIND:
2611         case MPOL_INTERLEAVE:
2612         case MPOL_PREFERRED:
2613         case MPOL_PREFERRED_MANY:
2614         case MPOL_WEIGHTED_INTERLEAVE:
2615                 return !!nodes_equal(a->nodes, b->nodes);
2616         case MPOL_LOCAL:
2617                 return true;
2618         default:
2619                 BUG();
2620                 return false;
2621         }
2622 }
2623
2624 /*
2625  * Shared memory backing store policy support.
2626  *
2627  * Remember policies even when nobody has shared memory mapped.
2628  * The policies are kept in Red-Black tree linked from the inode.
2629  * They are protected by the sp->lock rwlock, which should be held
2630  * for any accesses to the tree.
2631  */
2632
2633 /*
2634  * lookup first element intersecting start-end.  Caller holds sp->lock for
2635  * reading or for writing
2636  */
2637 static struct sp_node *sp_lookup(struct shared_policy *sp,
2638                                         pgoff_t start, pgoff_t end)
2639 {
2640         struct rb_node *n = sp->root.rb_node;
2641
2642         while (n) {
2643                 struct sp_node *p = rb_entry(n, struct sp_node, nd);
2644
2645                 if (start >= p->end)
2646                         n = n->rb_right;
2647                 else if (end <= p->start)
2648                         n = n->rb_left;
2649                 else
2650                         break;
2651         }
2652         if (!n)
2653                 return NULL;
2654         for (;;) {
2655                 struct sp_node *w = NULL;
2656                 struct rb_node *prev = rb_prev(n);
2657                 if (!prev)
2658                         break;
2659                 w = rb_entry(prev, struct sp_node, nd);
2660                 if (w->end <= start)
2661                         break;
2662                 n = prev;
2663         }
2664         return rb_entry(n, struct sp_node, nd);
2665 }
2666
2667 /*
2668  * Insert a new shared policy into the list.  Caller holds sp->lock for
2669  * writing.
2670  */
2671 static void sp_insert(struct shared_policy *sp, struct sp_node *new)
2672 {
2673         struct rb_node **p = &sp->root.rb_node;
2674         struct rb_node *parent = NULL;
2675         struct sp_node *nd;
2676
2677         while (*p) {
2678                 parent = *p;
2679                 nd = rb_entry(parent, struct sp_node, nd);
2680                 if (new->start < nd->start)
2681                         p = &(*p)->rb_left;
2682                 else if (new->end > nd->end)
2683                         p = &(*p)->rb_right;
2684                 else
2685                         BUG();
2686         }
2687         rb_link_node(&new->nd, parent, p);
2688         rb_insert_color(&new->nd, &sp->root);
2689 }
2690
2691 /* Find shared policy intersecting idx */
2692 struct mempolicy *mpol_shared_policy_lookup(struct shared_policy *sp,
2693                                                 pgoff_t idx)
2694 {
2695         struct mempolicy *pol = NULL;
2696         struct sp_node *sn;
2697
2698         if (!sp->root.rb_node)
2699                 return NULL;
2700         read_lock(&sp->lock);
2701         sn = sp_lookup(sp, idx, idx+1);
2702         if (sn) {
2703                 mpol_get(sn->policy);
2704                 pol = sn->policy;
2705         }
2706         read_unlock(&sp->lock);
2707         return pol;
2708 }
2709
2710 static void sp_free(struct sp_node *n)
2711 {
2712         mpol_put(n->policy);
2713         kmem_cache_free(sn_cache, n);
2714 }
2715
2716 /**
2717  * mpol_misplaced - check whether current folio node is valid in policy
2718  *
2719  * @folio: folio to be checked
2720  * @vma: vm area where folio mapped
2721  * @addr: virtual address in @vma for shared policy lookup and interleave policy
2722  *
2723  * Lookup current policy node id for vma,addr and "compare to" folio's
2724  * node id.  Policy determination "mimics" alloc_page_vma().
2725  * Called from fault path where we know the vma and faulting address.
2726  *
2727  * Return: NUMA_NO_NODE if the page is in a node that is valid for this
2728  * policy, or a suitable node ID to allocate a replacement folio from.
2729  */
2730 int mpol_misplaced(struct folio *folio, struct vm_area_struct *vma,
2731                    unsigned long addr)
2732 {
2733         struct mempolicy *pol;
2734         pgoff_t ilx;
2735         struct zoneref *z;
2736         int curnid = folio_nid(folio);
2737         int thiscpu = raw_smp_processor_id();
2738         int thisnid = cpu_to_node(thiscpu);
2739         int polnid = NUMA_NO_NODE;
2740         int ret = NUMA_NO_NODE;
2741
2742         pol = get_vma_policy(vma, addr, folio_order(folio), &ilx);
2743         if (!(pol->flags & MPOL_F_MOF))
2744                 goto out;
2745
2746         switch (pol->mode) {
2747         case MPOL_INTERLEAVE:
2748                 polnid = interleave_nid(pol, ilx);
2749                 break;
2750
2751         case MPOL_WEIGHTED_INTERLEAVE:
2752                 polnid = weighted_interleave_nid(pol, ilx);
2753                 break;
2754
2755         case MPOL_PREFERRED:
2756                 if (node_isset(curnid, pol->nodes))
2757                         goto out;
2758                 polnid = first_node(pol->nodes);
2759                 break;
2760
2761         case MPOL_LOCAL:
2762                 polnid = numa_node_id();
2763                 break;
2764
2765         case MPOL_BIND:
2766                 /* Optimize placement among multiple nodes via NUMA balancing */
2767                 if (pol->flags & MPOL_F_MORON) {
2768                         if (node_isset(thisnid, pol->nodes))
2769                                 break;
2770                         goto out;
2771                 }
2772                 fallthrough;
2773
2774         case MPOL_PREFERRED_MANY:
2775                 /*
2776                  * use current page if in policy nodemask,
2777                  * else select nearest allowed node, if any.
2778                  * If no allowed nodes, use current [!misplaced].
2779                  */
2780                 if (node_isset(curnid, pol->nodes))
2781                         goto out;
2782                 z = first_zones_zonelist(
2783                                 node_zonelist(numa_node_id(), GFP_HIGHUSER),
2784                                 gfp_zone(GFP_HIGHUSER),
2785                                 &pol->nodes);
2786                 polnid = zone_to_nid(z->zone);
2787                 break;
2788
2789         default:
2790                 BUG();
2791         }
2792
2793         /* Migrate the folio towards the node whose CPU is referencing it */
2794         if (pol->flags & MPOL_F_MORON) {
2795                 polnid = thisnid;
2796
2797                 if (!should_numa_migrate_memory(current, folio, curnid,
2798                                                 thiscpu))
2799                         goto out;
2800         }
2801
2802         if (curnid != polnid)
2803                 ret = polnid;
2804 out:
2805         mpol_cond_put(pol);
2806
2807         return ret;
2808 }
2809
2810 /*
2811  * Drop the (possibly final) reference to task->mempolicy.  It needs to be
2812  * dropped after task->mempolicy is set to NULL so that any allocation done as
2813  * part of its kmem_cache_free(), such as by KASAN, doesn't reference a freed
2814  * policy.
2815  */
2816 void mpol_put_task_policy(struct task_struct *task)
2817 {
2818         struct mempolicy *pol;
2819
2820         task_lock(task);
2821         pol = task->mempolicy;
2822         task->mempolicy = NULL;
2823         task_unlock(task);
2824         mpol_put(pol);
2825 }
2826
2827 static void sp_delete(struct shared_policy *sp, struct sp_node *n)
2828 {
2829         rb_erase(&n->nd, &sp->root);
2830         sp_free(n);
2831 }
2832
2833 static void sp_node_init(struct sp_node *node, unsigned long start,
2834                         unsigned long end, struct mempolicy *pol)
2835 {
2836         node->start = start;
2837         node->end = end;
2838         node->policy = pol;
2839 }
2840
2841 static struct sp_node *sp_alloc(unsigned long start, unsigned long end,
2842                                 struct mempolicy *pol)
2843 {
2844         struct sp_node *n;
2845         struct mempolicy *newpol;
2846
2847         n = kmem_cache_alloc(sn_cache, GFP_KERNEL);
2848         if (!n)
2849                 return NULL;
2850
2851         newpol = mpol_dup(pol);
2852         if (IS_ERR(newpol)) {
2853                 kmem_cache_free(sn_cache, n);
2854                 return NULL;
2855         }
2856         newpol->flags |= MPOL_F_SHARED;
2857         sp_node_init(n, start, end, newpol);
2858
2859         return n;
2860 }
2861
2862 /* Replace a policy range. */
2863 static int shared_policy_replace(struct shared_policy *sp, pgoff_t start,
2864                                  pgoff_t end, struct sp_node *new)
2865 {
2866         struct sp_node *n;
2867         struct sp_node *n_new = NULL;
2868         struct mempolicy *mpol_new = NULL;
2869         int ret = 0;
2870
2871 restart:
2872         write_lock(&sp->lock);
2873         n = sp_lookup(sp, start, end);
2874         /* Take care of old policies in the same range. */
2875         while (n && n->start < end) {
2876                 struct rb_node *next = rb_next(&n->nd);
2877                 if (n->start >= start) {
2878                         if (n->end <= end)
2879                                 sp_delete(sp, n);
2880                         else
2881                                 n->start = end;
2882                 } else {
2883                         /* Old policy spanning whole new range. */
2884                         if (n->end > end) {
2885                                 if (!n_new)
2886                                         goto alloc_new;
2887
2888                                 *mpol_new = *n->policy;
2889                                 atomic_set(&mpol_new->refcnt, 1);
2890                                 sp_node_init(n_new, end, n->end, mpol_new);
2891                                 n->end = start;
2892                                 sp_insert(sp, n_new);
2893                                 n_new = NULL;
2894                                 mpol_new = NULL;
2895                                 break;
2896                         } else
2897                                 n->end = start;
2898                 }
2899                 if (!next)
2900                         break;
2901                 n = rb_entry(next, struct sp_node, nd);
2902         }
2903         if (new)
2904                 sp_insert(sp, new);
2905         write_unlock(&sp->lock);
2906         ret = 0;
2907
2908 err_out:
2909         if (mpol_new)
2910                 mpol_put(mpol_new);
2911         if (n_new)
2912                 kmem_cache_free(sn_cache, n_new);
2913
2914         return ret;
2915
2916 alloc_new:
2917         write_unlock(&sp->lock);
2918         ret = -ENOMEM;
2919         n_new = kmem_cache_alloc(sn_cache, GFP_KERNEL);
2920         if (!n_new)
2921                 goto err_out;
2922         mpol_new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
2923         if (!mpol_new)
2924                 goto err_out;
2925         atomic_set(&mpol_new->refcnt, 1);
2926         goto restart;
2927 }
2928
2929 /**
2930  * mpol_shared_policy_init - initialize shared policy for inode
2931  * @sp: pointer to inode shared policy
2932  * @mpol:  struct mempolicy to install
2933  *
2934  * Install non-NULL @mpol in inode's shared policy rb-tree.
2935  * On entry, the current task has a reference on a non-NULL @mpol.
2936  * This must be released on exit.
2937  * This is called at get_inode() calls and we can use GFP_KERNEL.
2938  */
2939 void mpol_shared_policy_init(struct shared_policy *sp, struct mempolicy *mpol)
2940 {
2941         int ret;
2942
2943         sp->root = RB_ROOT;             /* empty tree == default mempolicy */
2944         rwlock_init(&sp->lock);
2945
2946         if (mpol) {
2947                 struct sp_node *sn;
2948                 struct mempolicy *npol;
2949                 NODEMASK_SCRATCH(scratch);
2950
2951                 if (!scratch)
2952                         goto put_mpol;
2953
2954                 /* contextualize the tmpfs mount point mempolicy to this file */
2955                 npol = mpol_new(mpol->mode, mpol->flags, &mpol->w.user_nodemask);
2956                 if (IS_ERR(npol))
2957                         goto free_scratch; /* no valid nodemask intersection */
2958
2959                 task_lock(current);
2960                 ret = mpol_set_nodemask(npol, &mpol->w.user_nodemask, scratch);
2961                 task_unlock(current);
2962                 if (ret)
2963                         goto put_npol;
2964
2965                 /* alloc node covering entire file; adds ref to file's npol */
2966                 sn = sp_alloc(0, MAX_LFS_FILESIZE >> PAGE_SHIFT, npol);
2967                 if (sn)
2968                         sp_insert(sp, sn);
2969 put_npol:
2970                 mpol_put(npol); /* drop initial ref on file's npol */
2971 free_scratch:
2972                 NODEMASK_SCRATCH_FREE(scratch);
2973 put_mpol:
2974                 mpol_put(mpol); /* drop our incoming ref on sb mpol */
2975         }
2976 }
2977
2978 int mpol_set_shared_policy(struct shared_policy *sp,
2979                         struct vm_area_struct *vma, struct mempolicy *pol)
2980 {
2981         int err;
2982         struct sp_node *new = NULL;
2983         unsigned long sz = vma_pages(vma);
2984
2985         if (pol) {
2986                 new = sp_alloc(vma->vm_pgoff, vma->vm_pgoff + sz, pol);
2987                 if (!new)
2988                         return -ENOMEM;
2989         }
2990         err = shared_policy_replace(sp, vma->vm_pgoff, vma->vm_pgoff + sz, new);
2991         if (err && new)
2992                 sp_free(new);
2993         return err;
2994 }
2995
2996 /* Free a backing policy store on inode delete. */
2997 void mpol_free_shared_policy(struct shared_policy *sp)
2998 {
2999         struct sp_node *n;
3000         struct rb_node *next;
3001
3002         if (!sp->root.rb_node)
3003                 return;
3004         write_lock(&sp->lock);
3005         next = rb_first(&sp->root);
3006         while (next) {
3007                 n = rb_entry(next, struct sp_node, nd);
3008                 next = rb_next(&n->nd);
3009                 sp_delete(sp, n);
3010         }
3011         write_unlock(&sp->lock);
3012 }
3013
3014 #ifdef CONFIG_NUMA_BALANCING
3015 static int __initdata numabalancing_override;
3016
3017 static void __init check_numabalancing_enable(void)
3018 {
3019         bool numabalancing_default = false;
3020
3021         if (IS_ENABLED(CONFIG_NUMA_BALANCING_DEFAULT_ENABLED))
3022                 numabalancing_default = true;
3023
3024         /* Parsed by setup_numabalancing. override == 1 enables, -1 disables */
3025         if (numabalancing_override)
3026                 set_numabalancing_state(numabalancing_override == 1);
3027
3028         if (num_online_nodes() > 1 && !numabalancing_override) {
3029                 pr_info("%s automatic NUMA balancing. Configure with numa_balancing= or the kernel.numa_balancing sysctl\n",
3030                         numabalancing_default ? "Enabling" : "Disabling");
3031                 set_numabalancing_state(numabalancing_default);
3032         }
3033 }
3034
3035 static int __init setup_numabalancing(char *str)
3036 {
3037         int ret = 0;
3038         if (!str)
3039                 goto out;
3040
3041         if (!strcmp(str, "enable")) {
3042                 numabalancing_override = 1;
3043                 ret = 1;
3044         } else if (!strcmp(str, "disable")) {
3045                 numabalancing_override = -1;
3046                 ret = 1;
3047         }
3048 out:
3049         if (!ret)
3050                 pr_warn("Unable to parse numa_balancing=\n");
3051
3052         return ret;
3053 }
3054 __setup("numa_balancing=", setup_numabalancing);
3055 #else
3056 static inline void __init check_numabalancing_enable(void)
3057 {
3058 }
3059 #endif /* CONFIG_NUMA_BALANCING */
3060
3061 void __init numa_policy_init(void)
3062 {
3063         nodemask_t interleave_nodes;
3064         unsigned long largest = 0;
3065         int nid, prefer = 0;
3066
3067         policy_cache = kmem_cache_create("numa_policy",
3068                                          sizeof(struct mempolicy),
3069                                          0, SLAB_PANIC, NULL);
3070
3071         sn_cache = kmem_cache_create("shared_policy_node",
3072                                      sizeof(struct sp_node),
3073                                      0, SLAB_PANIC, NULL);
3074
3075         for_each_node(nid) {
3076                 preferred_node_policy[nid] = (struct mempolicy) {
3077                         .refcnt = ATOMIC_INIT(1),
3078                         .mode = MPOL_PREFERRED,
3079                         .flags = MPOL_F_MOF | MPOL_F_MORON,
3080                         .nodes = nodemask_of_node(nid),
3081                 };
3082         }
3083
3084         /*
3085          * Set interleaving policy for system init. Interleaving is only
3086          * enabled across suitably sized nodes (default is >= 16MB), or
3087          * fall back to the largest node if they're all smaller.
3088          */
3089         nodes_clear(interleave_nodes);
3090         for_each_node_state(nid, N_MEMORY) {
3091                 unsigned long total_pages = node_present_pages(nid);
3092
3093                 /* Preserve the largest node */
3094                 if (largest < total_pages) {
3095                         largest = total_pages;
3096                         prefer = nid;
3097                 }
3098
3099                 /* Interleave this node? */
3100                 if ((total_pages << PAGE_SHIFT) >= (16 << 20))
3101                         node_set(nid, interleave_nodes);
3102         }
3103
3104         /* All too small, use the largest */
3105         if (unlikely(nodes_empty(interleave_nodes)))
3106                 node_set(prefer, interleave_nodes);
3107
3108         if (do_set_mempolicy(MPOL_INTERLEAVE, 0, &interleave_nodes))
3109                 pr_err("%s: interleaving failed\n", __func__);
3110
3111         check_numabalancing_enable();
3112 }
3113
3114 /* Reset policy of current process to default */
3115 void numa_default_policy(void)
3116 {
3117         do_set_mempolicy(MPOL_DEFAULT, 0, NULL);
3118 }
3119
3120 /*
3121  * Parse and format mempolicy from/to strings
3122  */
3123 static const char * const policy_modes[] =
3124 {
3125         [MPOL_DEFAULT]    = "default",
3126         [MPOL_PREFERRED]  = "prefer",
3127         [MPOL_BIND]       = "bind",
3128         [MPOL_INTERLEAVE] = "interleave",
3129         [MPOL_WEIGHTED_INTERLEAVE] = "weighted interleave",
3130         [MPOL_LOCAL]      = "local",
3131         [MPOL_PREFERRED_MANY]  = "prefer (many)",
3132 };
3133
3134 #ifdef CONFIG_TMPFS
3135 /**
3136  * mpol_parse_str - parse string to mempolicy, for tmpfs mpol mount option.
3137  * @str:  string containing mempolicy to parse
3138  * @mpol:  pointer to struct mempolicy pointer, returned on success.
3139  *
3140  * Format of input:
3141  *      <mode>[=<flags>][:<nodelist>]
3142  *
3143  * Return: %0 on success, else %1
3144  */
3145 int mpol_parse_str(char *str, struct mempolicy **mpol)
3146 {
3147         struct mempolicy *new = NULL;
3148         unsigned short mode_flags;
3149         nodemask_t nodes;
3150         char *nodelist = strchr(str, ':');
3151         char *flags = strchr(str, '=');
3152         int err = 1, mode;
3153
3154         if (flags)
3155                 *flags++ = '\0';        /* terminate mode string */
3156
3157         if (nodelist) {
3158                 /* NUL-terminate mode or flags string */
3159                 *nodelist++ = '\0';
3160                 if (nodelist_parse(nodelist, nodes))
3161                         goto out;
3162                 if (!nodes_subset(nodes, node_states[N_MEMORY]))
3163                         goto out;
3164         } else
3165                 nodes_clear(nodes);
3166
3167         mode = match_string(policy_modes, MPOL_MAX, str);
3168         if (mode < 0)
3169                 goto out;
3170
3171         switch (mode) {
3172         case MPOL_PREFERRED:
3173                 /*
3174                  * Insist on a nodelist of one node only, although later
3175                  * we use first_node(nodes) to grab a single node, so here
3176                  * nodelist (or nodes) cannot be empty.
3177                  */
3178                 if (nodelist) {
3179                         char *rest = nodelist;
3180                         while (isdigit(*rest))
3181                                 rest++;
3182                         if (*rest)
3183                                 goto out;
3184                         if (nodes_empty(nodes))
3185                                 goto out;
3186                 }
3187                 break;
3188         case MPOL_INTERLEAVE:
3189         case MPOL_WEIGHTED_INTERLEAVE:
3190                 /*
3191                  * Default to online nodes with memory if no nodelist
3192                  */
3193                 if (!nodelist)
3194                         nodes = node_states[N_MEMORY];
3195                 break;
3196         case MPOL_LOCAL:
3197                 /*
3198                  * Don't allow a nodelist;  mpol_new() checks flags
3199                  */
3200                 if (nodelist)
3201                         goto out;
3202                 break;
3203         case MPOL_DEFAULT:
3204                 /*
3205                  * Insist on a empty nodelist
3206                  */
3207                 if (!nodelist)
3208                         err = 0;
3209                 goto out;
3210         case MPOL_PREFERRED_MANY:
3211         case MPOL_BIND:
3212                 /*
3213                  * Insist on a nodelist
3214                  */
3215                 if (!nodelist)
3216                         goto out;
3217         }
3218
3219         mode_flags = 0;
3220         if (flags) {
3221                 /*
3222                  * Currently, we only support two mutually exclusive
3223                  * mode flags.
3224                  */
3225                 if (!strcmp(flags, "static"))
3226                         mode_flags |= MPOL_F_STATIC_NODES;
3227                 else if (!strcmp(flags, "relative"))
3228                         mode_flags |= MPOL_F_RELATIVE_NODES;
3229                 else
3230                         goto out;
3231         }
3232
3233         new = mpol_new(mode, mode_flags, &nodes);
3234         if (IS_ERR(new))
3235                 goto out;
3236
3237         /*
3238          * Save nodes for mpol_to_str() to show the tmpfs mount options
3239          * for /proc/mounts, /proc/pid/mounts and /proc/pid/mountinfo.
3240          */
3241         if (mode != MPOL_PREFERRED) {
3242                 new->nodes = nodes;
3243         } else if (nodelist) {
3244                 nodes_clear(new->nodes);
3245                 node_set(first_node(nodes), new->nodes);
3246         } else {
3247                 new->mode = MPOL_LOCAL;
3248         }
3249
3250         /*
3251          * Save nodes for contextualization: this will be used to "clone"
3252          * the mempolicy in a specific context [cpuset] at a later time.
3253          */
3254         new->w.user_nodemask = nodes;
3255
3256         err = 0;
3257
3258 out:
3259         /* Restore string for error message */
3260         if (nodelist)
3261                 *--nodelist = ':';
3262         if (flags)
3263                 *--flags = '=';
3264         if (!err)
3265                 *mpol = new;
3266         return err;
3267 }
3268 #endif /* CONFIG_TMPFS */
3269
3270 /**
3271  * mpol_to_str - format a mempolicy structure for printing
3272  * @buffer:  to contain formatted mempolicy string
3273  * @maxlen:  length of @buffer
3274  * @pol:  pointer to mempolicy to be formatted
3275  *
3276  * Convert @pol into a string.  If @buffer is too short, truncate the string.
3277  * Recommend a @maxlen of at least 32 for the longest mode, "interleave", the
3278  * longest flag, "relative", and to display at least a few node ids.
3279  */
3280 void mpol_to_str(char *buffer, int maxlen, struct mempolicy *pol)
3281 {
3282         char *p = buffer;
3283         nodemask_t nodes = NODE_MASK_NONE;
3284         unsigned short mode = MPOL_DEFAULT;
3285         unsigned short flags = 0;
3286
3287         if (pol && pol != &default_policy && !(pol->flags & MPOL_F_MORON)) {
3288                 mode = pol->mode;
3289                 flags = pol->flags;
3290         }
3291
3292         switch (mode) {
3293         case MPOL_DEFAULT:
3294         case MPOL_LOCAL:
3295                 break;
3296         case MPOL_PREFERRED:
3297         case MPOL_PREFERRED_MANY:
3298         case MPOL_BIND:
3299         case MPOL_INTERLEAVE:
3300         case MPOL_WEIGHTED_INTERLEAVE:
3301                 nodes = pol->nodes;
3302                 break;
3303         default:
3304                 WARN_ON_ONCE(1);
3305                 snprintf(p, maxlen, "unknown");
3306                 return;
3307         }
3308
3309         p += snprintf(p, maxlen, "%s", policy_modes[mode]);
3310
3311         if (flags & MPOL_MODE_FLAGS) {
3312                 p += snprintf(p, buffer + maxlen - p, "=");
3313
3314                 /*
3315                  * Currently, the only defined flags are mutually exclusive
3316                  */
3317                 if (flags & MPOL_F_STATIC_NODES)
3318                         p += snprintf(p, buffer + maxlen - p, "static");
3319                 else if (flags & MPOL_F_RELATIVE_NODES)
3320                         p += snprintf(p, buffer + maxlen - p, "relative");
3321         }
3322
3323         if (!nodes_empty(nodes))
3324                 p += scnprintf(p, buffer + maxlen - p, ":%*pbl",
3325                                nodemask_pr_args(&nodes));
3326 }
3327
3328 #ifdef CONFIG_SYSFS
3329 struct iw_node_attr {
3330         struct kobj_attribute kobj_attr;
3331         int nid;
3332 };
3333
3334 static ssize_t node_show(struct kobject *kobj, struct kobj_attribute *attr,
3335                          char *buf)
3336 {
3337         struct iw_node_attr *node_attr;
3338         u8 weight;
3339
3340         node_attr = container_of(attr, struct iw_node_attr, kobj_attr);
3341         weight = get_il_weight(node_attr->nid);
3342         return sysfs_emit(buf, "%d\n", weight);
3343 }
3344
3345 static ssize_t node_store(struct kobject *kobj, struct kobj_attribute *attr,
3346                           const char *buf, size_t count)
3347 {
3348         struct iw_node_attr *node_attr;
3349         u8 *new;
3350         u8 *old;
3351         u8 weight = 0;
3352
3353         node_attr = container_of(attr, struct iw_node_attr, kobj_attr);
3354         if (count == 0 || sysfs_streq(buf, ""))
3355                 weight = 0;
3356         else if (kstrtou8(buf, 0, &weight))
3357                 return -EINVAL;
3358
3359         new = kzalloc(nr_node_ids, GFP_KERNEL);
3360         if (!new)
3361                 return -ENOMEM;
3362
3363         mutex_lock(&iw_table_lock);
3364         old = rcu_dereference_protected(iw_table,
3365                                         lockdep_is_held(&iw_table_lock));
3366         if (old)
3367                 memcpy(new, old, nr_node_ids);
3368         new[node_attr->nid] = weight;
3369         rcu_assign_pointer(iw_table, new);
3370         mutex_unlock(&iw_table_lock);
3371         synchronize_rcu();
3372         kfree(old);
3373         return count;
3374 }
3375
3376 static struct iw_node_attr **node_attrs;
3377
3378 static void sysfs_wi_node_release(struct iw_node_attr *node_attr,
3379                                   struct kobject *parent)
3380 {
3381         if (!node_attr)
3382                 return;
3383         sysfs_remove_file(parent, &node_attr->kobj_attr.attr);
3384         kfree(node_attr->kobj_attr.attr.name);
3385         kfree(node_attr);
3386 }
3387
3388 static void sysfs_wi_release(struct kobject *wi_kobj)
3389 {
3390         int i;
3391
3392         for (i = 0; i < nr_node_ids; i++)
3393                 sysfs_wi_node_release(node_attrs[i], wi_kobj);
3394         kobject_put(wi_kobj);
3395 }
3396
3397 static const struct kobj_type wi_ktype = {
3398         .sysfs_ops = &kobj_sysfs_ops,
3399         .release = sysfs_wi_release,
3400 };
3401
3402 static int add_weight_node(int nid, struct kobject *wi_kobj)
3403 {
3404         struct iw_node_attr *node_attr;
3405         char *name;
3406
3407         node_attr = kzalloc(sizeof(*node_attr), GFP_KERNEL);
3408         if (!node_attr)
3409                 return -ENOMEM;
3410
3411         name = kasprintf(GFP_KERNEL, "node%d", nid);
3412         if (!name) {
3413                 kfree(node_attr);
3414                 return -ENOMEM;
3415         }
3416
3417         sysfs_attr_init(&node_attr->kobj_attr.attr);
3418         node_attr->kobj_attr.attr.name = name;
3419         node_attr->kobj_attr.attr.mode = 0644;
3420         node_attr->kobj_attr.show = node_show;
3421         node_attr->kobj_attr.store = node_store;
3422         node_attr->nid = nid;
3423
3424         if (sysfs_create_file(wi_kobj, &node_attr->kobj_attr.attr)) {
3425                 kfree(node_attr->kobj_attr.attr.name);
3426                 kfree(node_attr);
3427                 pr_err("failed to add attribute to weighted_interleave\n");
3428                 return -ENOMEM;
3429         }
3430
3431         node_attrs[nid] = node_attr;
3432         return 0;
3433 }
3434
3435 static int add_weighted_interleave_group(struct kobject *root_kobj)
3436 {
3437         struct kobject *wi_kobj;
3438         int nid, err;
3439
3440         wi_kobj = kzalloc(sizeof(struct kobject), GFP_KERNEL);
3441         if (!wi_kobj)
3442                 return -ENOMEM;
3443
3444         err = kobject_init_and_add(wi_kobj, &wi_ktype, root_kobj,
3445                                    "weighted_interleave");
3446         if (err) {
3447                 kfree(wi_kobj);
3448                 return err;
3449         }
3450
3451         for_each_node_state(nid, N_POSSIBLE) {
3452                 err = add_weight_node(nid, wi_kobj);
3453                 if (err) {
3454                         pr_err("failed to add sysfs [node%d]\n", nid);
3455                         break;
3456                 }
3457         }
3458         if (err)
3459                 kobject_put(wi_kobj);
3460         return 0;
3461 }
3462
3463 static void mempolicy_kobj_release(struct kobject *kobj)
3464 {
3465         u8 *old;
3466
3467         mutex_lock(&iw_table_lock);
3468         old = rcu_dereference_protected(iw_table,
3469                                         lockdep_is_held(&iw_table_lock));
3470         rcu_assign_pointer(iw_table, NULL);
3471         mutex_unlock(&iw_table_lock);
3472         synchronize_rcu();
3473         kfree(old);
3474         kfree(node_attrs);
3475         kfree(kobj);
3476 }
3477
3478 static const struct kobj_type mempolicy_ktype = {
3479         .release = mempolicy_kobj_release
3480 };
3481
3482 static int __init mempolicy_sysfs_init(void)
3483 {
3484         int err;
3485         static struct kobject *mempolicy_kobj;
3486
3487         mempolicy_kobj = kzalloc(sizeof(*mempolicy_kobj), GFP_KERNEL);
3488         if (!mempolicy_kobj) {
3489                 err = -ENOMEM;
3490                 goto err_out;
3491         }
3492
3493         node_attrs = kcalloc(nr_node_ids, sizeof(struct iw_node_attr *),
3494                              GFP_KERNEL);
3495         if (!node_attrs) {
3496                 err = -ENOMEM;
3497                 goto mempol_out;
3498         }
3499
3500         err = kobject_init_and_add(mempolicy_kobj, &mempolicy_ktype, mm_kobj,
3501                                    "mempolicy");
3502         if (err)
3503                 goto node_out;
3504
3505         err = add_weighted_interleave_group(mempolicy_kobj);
3506         if (err) {
3507                 pr_err("mempolicy sysfs structure failed to initialize\n");
3508                 kobject_put(mempolicy_kobj);
3509                 return err;
3510         }
3511
3512         return err;
3513 node_out:
3514         kfree(node_attrs);
3515 mempol_out:
3516         kfree(mempolicy_kobj);
3517 err_out:
3518         pr_err("failed to add mempolicy kobject to the system\n");
3519         return err;
3520 }
3521
3522 late_initcall(mempolicy_sysfs_init);
3523 #endif /* CONFIG_SYSFS */