Merge tag 'sched-core-2024-09-19' of git://git.kernel.org/pub/scm/linux/kernel/git...
[linux-2.6-block.git] / mm / mempolicy.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Simple NUMA memory policy for the Linux kernel.
4  *
5  * Copyright 2003,2004 Andi Kleen, SuSE Labs.
6  * (C) Copyright 2005 Christoph Lameter, Silicon Graphics, Inc.
7  *
8  * NUMA policy allows the user to give hints in which node(s) memory should
9  * be allocated.
10  *
11  * Support four policies per VMA and per process:
12  *
13  * The VMA policy has priority over the process policy for a page fault.
14  *
15  * interleave     Allocate memory interleaved over a set of nodes,
16  *                with normal fallback if it fails.
17  *                For VMA based allocations this interleaves based on the
18  *                offset into the backing object or offset into the mapping
19  *                for anonymous memory. For process policy an process counter
20  *                is used.
21  *
22  * weighted interleave
23  *                Allocate memory interleaved over a set of nodes based on
24  *                a set of weights (per-node), with normal fallback if it
25  *                fails.  Otherwise operates the same as interleave.
26  *                Example: nodeset(0,1) & weights (2,1) - 2 pages allocated
27  *                on node 0 for every 1 page allocated on node 1.
28  *
29  * bind           Only allocate memory on a specific set of nodes,
30  *                no fallback.
31  *                FIXME: memory is allocated starting with the first node
32  *                to the last. It would be better if bind would truly restrict
33  *                the allocation to memory nodes instead
34  *
35  * preferred      Try a specific node first before normal fallback.
36  *                As a special case NUMA_NO_NODE here means do the allocation
37  *                on the local CPU. This is normally identical to default,
38  *                but useful to set in a VMA when you have a non default
39  *                process policy.
40  *
41  * preferred many Try a set of nodes first before normal fallback. This is
42  *                similar to preferred without the special case.
43  *
44  * default        Allocate on the local node first, or when on a VMA
45  *                use the process policy. This is what Linux always did
46  *                in a NUMA aware kernel and still does by, ahem, default.
47  *
48  * The process policy is applied for most non interrupt memory allocations
49  * in that process' context. Interrupts ignore the policies and always
50  * try to allocate on the local CPU. The VMA policy is only applied for memory
51  * allocations for a VMA in the VM.
52  *
53  * Currently there are a few corner cases in swapping where the policy
54  * is not applied, but the majority should be handled. When process policy
55  * is used it is not remembered over swap outs/swap ins.
56  *
57  * Only the highest zone in the zone hierarchy gets policied. Allocations
58  * requesting a lower zone just use default policy. This implies that
59  * on systems with highmem kernel lowmem allocation don't get policied.
60  * Same with GFP_DMA allocations.
61  *
62  * For shmem/tmpfs shared memory the policy is shared between
63  * all users and remembered even when nobody has memory mapped.
64  */
65
66 /* Notebook:
67    fix mmap readahead to honour policy and enable policy for any page cache
68    object
69    statistics for bigpages
70    global policy for page cache? currently it uses process policy. Requires
71    first item above.
72    handle mremap for shared memory (currently ignored for the policy)
73    grows down?
74    make bind policy root only? It can trigger oom much faster and the
75    kernel is not always grateful with that.
76 */
77
78 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
79
80 #include <linux/mempolicy.h>
81 #include <linux/pagewalk.h>
82 #include <linux/highmem.h>
83 #include <linux/hugetlb.h>
84 #include <linux/kernel.h>
85 #include <linux/sched.h>
86 #include <linux/sched/mm.h>
87 #include <linux/sched/numa_balancing.h>
88 #include <linux/sched/task.h>
89 #include <linux/nodemask.h>
90 #include <linux/cpuset.h>
91 #include <linux/slab.h>
92 #include <linux/string.h>
93 #include <linux/export.h>
94 #include <linux/nsproxy.h>
95 #include <linux/interrupt.h>
96 #include <linux/init.h>
97 #include <linux/compat.h>
98 #include <linux/ptrace.h>
99 #include <linux/swap.h>
100 #include <linux/seq_file.h>
101 #include <linux/proc_fs.h>
102 #include <linux/migrate.h>
103 #include <linux/ksm.h>
104 #include <linux/rmap.h>
105 #include <linux/security.h>
106 #include <linux/syscalls.h>
107 #include <linux/ctype.h>
108 #include <linux/mm_inline.h>
109 #include <linux/mmu_notifier.h>
110 #include <linux/printk.h>
111 #include <linux/swapops.h>
112
113 #include <asm/tlbflush.h>
114 #include <asm/tlb.h>
115 #include <linux/uaccess.h>
116
117 #include "internal.h"
118
119 /* Internal flags */
120 #define MPOL_MF_DISCONTIG_OK (MPOL_MF_INTERNAL << 0)    /* Skip checks for continuous vmas */
121 #define MPOL_MF_INVERT       (MPOL_MF_INTERNAL << 1)    /* Invert check for nodemask */
122 #define MPOL_MF_WRLOCK       (MPOL_MF_INTERNAL << 2)    /* Write-lock walked vmas */
123
124 static struct kmem_cache *policy_cache;
125 static struct kmem_cache *sn_cache;
126
127 /* Highest zone. An specific allocation for a zone below that is not
128    policied. */
129 enum zone_type policy_zone = 0;
130
131 /*
132  * run-time system-wide default policy => local allocation
133  */
134 static struct mempolicy default_policy = {
135         .refcnt = ATOMIC_INIT(1), /* never free it */
136         .mode = MPOL_LOCAL,
137 };
138
139 static struct mempolicy preferred_node_policy[MAX_NUMNODES];
140
141 /*
142  * iw_table is the sysfs-set interleave weight table, a value of 0 denotes
143  * system-default value should be used. A NULL iw_table also denotes that
144  * system-default values should be used. Until the system-default table
145  * is implemented, the system-default is always 1.
146  *
147  * iw_table is RCU protected
148  */
149 static u8 __rcu *iw_table;
150 static DEFINE_MUTEX(iw_table_lock);
151
152 static u8 get_il_weight(int node)
153 {
154         u8 *table;
155         u8 weight;
156
157         rcu_read_lock();
158         table = rcu_dereference(iw_table);
159         /* if no iw_table, use system default */
160         weight = table ? table[node] : 1;
161         /* if value in iw_table is 0, use system default */
162         weight = weight ? weight : 1;
163         rcu_read_unlock();
164         return weight;
165 }
166
167 /**
168  * numa_nearest_node - Find nearest node by state
169  * @node: Node id to start the search
170  * @state: State to filter the search
171  *
172  * Lookup the closest node by distance if @nid is not in state.
173  *
174  * Return: this @node if it is in state, otherwise the closest node by distance
175  */
176 int numa_nearest_node(int node, unsigned int state)
177 {
178         int min_dist = INT_MAX, dist, n, min_node;
179
180         if (state >= NR_NODE_STATES)
181                 return -EINVAL;
182
183         if (node == NUMA_NO_NODE || node_state(node, state))
184                 return node;
185
186         min_node = node;
187         for_each_node_state(n, state) {
188                 dist = node_distance(node, n);
189                 if (dist < min_dist) {
190                         min_dist = dist;
191                         min_node = n;
192                 }
193         }
194
195         return min_node;
196 }
197 EXPORT_SYMBOL_GPL(numa_nearest_node);
198
199 struct mempolicy *get_task_policy(struct task_struct *p)
200 {
201         struct mempolicy *pol = p->mempolicy;
202         int node;
203
204         if (pol)
205                 return pol;
206
207         node = numa_node_id();
208         if (node != NUMA_NO_NODE) {
209                 pol = &preferred_node_policy[node];
210                 /* preferred_node_policy is not initialised early in boot */
211                 if (pol->mode)
212                         return pol;
213         }
214
215         return &default_policy;
216 }
217
218 static const struct mempolicy_operations {
219         int (*create)(struct mempolicy *pol, const nodemask_t *nodes);
220         void (*rebind)(struct mempolicy *pol, const nodemask_t *nodes);
221 } mpol_ops[MPOL_MAX];
222
223 static inline int mpol_store_user_nodemask(const struct mempolicy *pol)
224 {
225         return pol->flags & MPOL_MODE_FLAGS;
226 }
227
228 static void mpol_relative_nodemask(nodemask_t *ret, const nodemask_t *orig,
229                                    const nodemask_t *rel)
230 {
231         nodemask_t tmp;
232         nodes_fold(tmp, *orig, nodes_weight(*rel));
233         nodes_onto(*ret, tmp, *rel);
234 }
235
236 static int mpol_new_nodemask(struct mempolicy *pol, const nodemask_t *nodes)
237 {
238         if (nodes_empty(*nodes))
239                 return -EINVAL;
240         pol->nodes = *nodes;
241         return 0;
242 }
243
244 static int mpol_new_preferred(struct mempolicy *pol, const nodemask_t *nodes)
245 {
246         if (nodes_empty(*nodes))
247                 return -EINVAL;
248
249         nodes_clear(pol->nodes);
250         node_set(first_node(*nodes), pol->nodes);
251         return 0;
252 }
253
254 /*
255  * mpol_set_nodemask is called after mpol_new() to set up the nodemask, if
256  * any, for the new policy.  mpol_new() has already validated the nodes
257  * parameter with respect to the policy mode and flags.
258  *
259  * Must be called holding task's alloc_lock to protect task's mems_allowed
260  * and mempolicy.  May also be called holding the mmap_lock for write.
261  */
262 static int mpol_set_nodemask(struct mempolicy *pol,
263                      const nodemask_t *nodes, struct nodemask_scratch *nsc)
264 {
265         int ret;
266
267         /*
268          * Default (pol==NULL) resp. local memory policies are not a
269          * subject of any remapping. They also do not need any special
270          * constructor.
271          */
272         if (!pol || pol->mode == MPOL_LOCAL)
273                 return 0;
274
275         /* Check N_MEMORY */
276         nodes_and(nsc->mask1,
277                   cpuset_current_mems_allowed, node_states[N_MEMORY]);
278
279         VM_BUG_ON(!nodes);
280
281         if (pol->flags & MPOL_F_RELATIVE_NODES)
282                 mpol_relative_nodemask(&nsc->mask2, nodes, &nsc->mask1);
283         else
284                 nodes_and(nsc->mask2, *nodes, nsc->mask1);
285
286         if (mpol_store_user_nodemask(pol))
287                 pol->w.user_nodemask = *nodes;
288         else
289                 pol->w.cpuset_mems_allowed = cpuset_current_mems_allowed;
290
291         ret = mpol_ops[pol->mode].create(pol, &nsc->mask2);
292         return ret;
293 }
294
295 /*
296  * This function just creates a new policy, does some check and simple
297  * initialization. You must invoke mpol_set_nodemask() to set nodes.
298  */
299 static struct mempolicy *mpol_new(unsigned short mode, unsigned short flags,
300                                   nodemask_t *nodes)
301 {
302         struct mempolicy *policy;
303
304         if (mode == MPOL_DEFAULT) {
305                 if (nodes && !nodes_empty(*nodes))
306                         return ERR_PTR(-EINVAL);
307                 return NULL;
308         }
309         VM_BUG_ON(!nodes);
310
311         /*
312          * MPOL_PREFERRED cannot be used with MPOL_F_STATIC_NODES or
313          * MPOL_F_RELATIVE_NODES if the nodemask is empty (local allocation).
314          * All other modes require a valid pointer to a non-empty nodemask.
315          */
316         if (mode == MPOL_PREFERRED) {
317                 if (nodes_empty(*nodes)) {
318                         if (((flags & MPOL_F_STATIC_NODES) ||
319                              (flags & MPOL_F_RELATIVE_NODES)))
320                                 return ERR_PTR(-EINVAL);
321
322                         mode = MPOL_LOCAL;
323                 }
324         } else if (mode == MPOL_LOCAL) {
325                 if (!nodes_empty(*nodes) ||
326                     (flags & MPOL_F_STATIC_NODES) ||
327                     (flags & MPOL_F_RELATIVE_NODES))
328                         return ERR_PTR(-EINVAL);
329         } else if (nodes_empty(*nodes))
330                 return ERR_PTR(-EINVAL);
331
332         policy = kmem_cache_alloc(policy_cache, GFP_KERNEL);
333         if (!policy)
334                 return ERR_PTR(-ENOMEM);
335         atomic_set(&policy->refcnt, 1);
336         policy->mode = mode;
337         policy->flags = flags;
338         policy->home_node = NUMA_NO_NODE;
339
340         return policy;
341 }
342
343 /* Slow path of a mpol destructor. */
344 void __mpol_put(struct mempolicy *pol)
345 {
346         if (!atomic_dec_and_test(&pol->refcnt))
347                 return;
348         kmem_cache_free(policy_cache, pol);
349 }
350
351 static void mpol_rebind_default(struct mempolicy *pol, const nodemask_t *nodes)
352 {
353 }
354
355 static void mpol_rebind_nodemask(struct mempolicy *pol, const nodemask_t *nodes)
356 {
357         nodemask_t tmp;
358
359         if (pol->flags & MPOL_F_STATIC_NODES)
360                 nodes_and(tmp, pol->w.user_nodemask, *nodes);
361         else if (pol->flags & MPOL_F_RELATIVE_NODES)
362                 mpol_relative_nodemask(&tmp, &pol->w.user_nodemask, nodes);
363         else {
364                 nodes_remap(tmp, pol->nodes, pol->w.cpuset_mems_allowed,
365                                                                 *nodes);
366                 pol->w.cpuset_mems_allowed = *nodes;
367         }
368
369         if (nodes_empty(tmp))
370                 tmp = *nodes;
371
372         pol->nodes = tmp;
373 }
374
375 static void mpol_rebind_preferred(struct mempolicy *pol,
376                                                 const nodemask_t *nodes)
377 {
378         pol->w.cpuset_mems_allowed = *nodes;
379 }
380
381 /*
382  * mpol_rebind_policy - Migrate a policy to a different set of nodes
383  *
384  * Per-vma policies are protected by mmap_lock. Allocations using per-task
385  * policies are protected by task->mems_allowed_seq to prevent a premature
386  * OOM/allocation failure due to parallel nodemask modification.
387  */
388 static void mpol_rebind_policy(struct mempolicy *pol, const nodemask_t *newmask)
389 {
390         if (!pol || pol->mode == MPOL_LOCAL)
391                 return;
392         if (!mpol_store_user_nodemask(pol) &&
393             nodes_equal(pol->w.cpuset_mems_allowed, *newmask))
394                 return;
395
396         mpol_ops[pol->mode].rebind(pol, newmask);
397 }
398
399 /*
400  * Wrapper for mpol_rebind_policy() that just requires task
401  * pointer, and updates task mempolicy.
402  *
403  * Called with task's alloc_lock held.
404  */
405 void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new)
406 {
407         mpol_rebind_policy(tsk->mempolicy, new);
408 }
409
410 /*
411  * Rebind each vma in mm to new nodemask.
412  *
413  * Call holding a reference to mm.  Takes mm->mmap_lock during call.
414  */
415 void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new)
416 {
417         struct vm_area_struct *vma;
418         VMA_ITERATOR(vmi, mm, 0);
419
420         mmap_write_lock(mm);
421         for_each_vma(vmi, vma) {
422                 vma_start_write(vma);
423                 mpol_rebind_policy(vma->vm_policy, new);
424         }
425         mmap_write_unlock(mm);
426 }
427
428 static const struct mempolicy_operations mpol_ops[MPOL_MAX] = {
429         [MPOL_DEFAULT] = {
430                 .rebind = mpol_rebind_default,
431         },
432         [MPOL_INTERLEAVE] = {
433                 .create = mpol_new_nodemask,
434                 .rebind = mpol_rebind_nodemask,
435         },
436         [MPOL_PREFERRED] = {
437                 .create = mpol_new_preferred,
438                 .rebind = mpol_rebind_preferred,
439         },
440         [MPOL_BIND] = {
441                 .create = mpol_new_nodemask,
442                 .rebind = mpol_rebind_nodemask,
443         },
444         [MPOL_LOCAL] = {
445                 .rebind = mpol_rebind_default,
446         },
447         [MPOL_PREFERRED_MANY] = {
448                 .create = mpol_new_nodemask,
449                 .rebind = mpol_rebind_preferred,
450         },
451         [MPOL_WEIGHTED_INTERLEAVE] = {
452                 .create = mpol_new_nodemask,
453                 .rebind = mpol_rebind_nodemask,
454         },
455 };
456
457 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
458                                 unsigned long flags);
459 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
460                                 pgoff_t ilx, int *nid);
461
462 static bool strictly_unmovable(unsigned long flags)
463 {
464         /*
465          * STRICT without MOVE flags lets do_mbind() fail immediately with -EIO
466          * if any misplaced page is found.
467          */
468         return (flags & (MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ==
469                          MPOL_MF_STRICT;
470 }
471
472 struct migration_mpol {         /* for alloc_migration_target_by_mpol() */
473         struct mempolicy *pol;
474         pgoff_t ilx;
475 };
476
477 struct queue_pages {
478         struct list_head *pagelist;
479         unsigned long flags;
480         nodemask_t *nmask;
481         unsigned long start;
482         unsigned long end;
483         struct vm_area_struct *first;
484         struct folio *large;            /* note last large folio encountered */
485         long nr_failed;                 /* could not be isolated at this time */
486 };
487
488 /*
489  * Check if the folio's nid is in qp->nmask.
490  *
491  * If MPOL_MF_INVERT is set in qp->flags, check if the nid is
492  * in the invert of qp->nmask.
493  */
494 static inline bool queue_folio_required(struct folio *folio,
495                                         struct queue_pages *qp)
496 {
497         int nid = folio_nid(folio);
498         unsigned long flags = qp->flags;
499
500         return node_isset(nid, *qp->nmask) == !(flags & MPOL_MF_INVERT);
501 }
502
503 static void queue_folios_pmd(pmd_t *pmd, struct mm_walk *walk)
504 {
505         struct folio *folio;
506         struct queue_pages *qp = walk->private;
507
508         if (unlikely(is_pmd_migration_entry(*pmd))) {
509                 qp->nr_failed++;
510                 return;
511         }
512         folio = pmd_folio(*pmd);
513         if (is_huge_zero_folio(folio)) {
514                 walk->action = ACTION_CONTINUE;
515                 return;
516         }
517         if (!queue_folio_required(folio, qp))
518                 return;
519         if (!(qp->flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
520             !vma_migratable(walk->vma) ||
521             !migrate_folio_add(folio, qp->pagelist, qp->flags))
522                 qp->nr_failed++;
523 }
524
525 /*
526  * Scan through folios, checking if they satisfy the required conditions,
527  * moving them from LRU to local pagelist for migration if they do (or not).
528  *
529  * queue_folios_pte_range() has two possible return values:
530  * 0 - continue walking to scan for more, even if an existing folio on the
531  *     wrong node could not be isolated and queued for migration.
532  * -EIO - only MPOL_MF_STRICT was specified, without MPOL_MF_MOVE or ..._ALL,
533  *        and an existing folio was on a node that does not follow the policy.
534  */
535 static int queue_folios_pte_range(pmd_t *pmd, unsigned long addr,
536                         unsigned long end, struct mm_walk *walk)
537 {
538         struct vm_area_struct *vma = walk->vma;
539         struct folio *folio;
540         struct queue_pages *qp = walk->private;
541         unsigned long flags = qp->flags;
542         pte_t *pte, *mapped_pte;
543         pte_t ptent;
544         spinlock_t *ptl;
545
546         ptl = pmd_trans_huge_lock(pmd, vma);
547         if (ptl) {
548                 queue_folios_pmd(pmd, walk);
549                 spin_unlock(ptl);
550                 goto out;
551         }
552
553         mapped_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
554         if (!pte) {
555                 walk->action = ACTION_AGAIN;
556                 return 0;
557         }
558         for (; addr != end; pte++, addr += PAGE_SIZE) {
559                 ptent = ptep_get(pte);
560                 if (pte_none(ptent))
561                         continue;
562                 if (!pte_present(ptent)) {
563                         if (is_migration_entry(pte_to_swp_entry(ptent)))
564                                 qp->nr_failed++;
565                         continue;
566                 }
567                 folio = vm_normal_folio(vma, addr, ptent);
568                 if (!folio || folio_is_zone_device(folio))
569                         continue;
570                 /*
571                  * vm_normal_folio() filters out zero pages, but there might
572                  * still be reserved folios to skip, perhaps in a VDSO.
573                  */
574                 if (folio_test_reserved(folio))
575                         continue;
576                 if (!queue_folio_required(folio, qp))
577                         continue;
578                 if (folio_test_large(folio)) {
579                         /*
580                          * A large folio can only be isolated from LRU once,
581                          * but may be mapped by many PTEs (and Copy-On-Write may
582                          * intersperse PTEs of other, order 0, folios).  This is
583                          * a common case, so don't mistake it for failure (but
584                          * there can be other cases of multi-mapped pages which
585                          * this quick check does not help to filter out - and a
586                          * search of the pagelist might grow to be prohibitive).
587                          *
588                          * migrate_pages(&pagelist) returns nr_failed folios, so
589                          * check "large" now so that queue_pages_range() returns
590                          * a comparable nr_failed folios.  This does imply that
591                          * if folio could not be isolated for some racy reason
592                          * at its first PTE, later PTEs will not give it another
593                          * chance of isolation; but keeps the accounting simple.
594                          */
595                         if (folio == qp->large)
596                                 continue;
597                         qp->large = folio;
598                 }
599                 if (!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
600                     !vma_migratable(vma) ||
601                     !migrate_folio_add(folio, qp->pagelist, flags)) {
602                         qp->nr_failed++;
603                         if (strictly_unmovable(flags))
604                                 break;
605                 }
606         }
607         pte_unmap_unlock(mapped_pte, ptl);
608         cond_resched();
609 out:
610         if (qp->nr_failed && strictly_unmovable(flags))
611                 return -EIO;
612         return 0;
613 }
614
615 static int queue_folios_hugetlb(pte_t *pte, unsigned long hmask,
616                                unsigned long addr, unsigned long end,
617                                struct mm_walk *walk)
618 {
619 #ifdef CONFIG_HUGETLB_PAGE
620         struct queue_pages *qp = walk->private;
621         unsigned long flags = qp->flags;
622         struct folio *folio;
623         spinlock_t *ptl;
624         pte_t entry;
625
626         ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
627         entry = huge_ptep_get(walk->mm, addr, pte);
628         if (!pte_present(entry)) {
629                 if (unlikely(is_hugetlb_entry_migration(entry)))
630                         qp->nr_failed++;
631                 goto unlock;
632         }
633         folio = pfn_folio(pte_pfn(entry));
634         if (!queue_folio_required(folio, qp))
635                 goto unlock;
636         if (!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)) ||
637             !vma_migratable(walk->vma)) {
638                 qp->nr_failed++;
639                 goto unlock;
640         }
641         /*
642          * Unless MPOL_MF_MOVE_ALL, we try to avoid migrating a shared folio.
643          * Choosing not to migrate a shared folio is not counted as a failure.
644          *
645          * See folio_likely_mapped_shared() on possible imprecision when we
646          * cannot easily detect if a folio is shared.
647          */
648         if ((flags & MPOL_MF_MOVE_ALL) ||
649             (!folio_likely_mapped_shared(folio) && !hugetlb_pmd_shared(pte)))
650                 if (!isolate_hugetlb(folio, qp->pagelist))
651                         qp->nr_failed++;
652 unlock:
653         spin_unlock(ptl);
654         if (qp->nr_failed && strictly_unmovable(flags))
655                 return -EIO;
656 #endif
657         return 0;
658 }
659
660 #ifdef CONFIG_NUMA_BALANCING
661 /*
662  * This is used to mark a range of virtual addresses to be inaccessible.
663  * These are later cleared by a NUMA hinting fault. Depending on these
664  * faults, pages may be migrated for better NUMA placement.
665  *
666  * This is assuming that NUMA faults are handled using PROT_NONE. If
667  * an architecture makes a different choice, it will need further
668  * changes to the core.
669  */
670 unsigned long change_prot_numa(struct vm_area_struct *vma,
671                         unsigned long addr, unsigned long end)
672 {
673         struct mmu_gather tlb;
674         long nr_updated;
675
676         tlb_gather_mmu(&tlb, vma->vm_mm);
677
678         nr_updated = change_protection(&tlb, vma, addr, end, MM_CP_PROT_NUMA);
679         if (nr_updated > 0)
680                 count_vm_numa_events(NUMA_PTE_UPDATES, nr_updated);
681
682         tlb_finish_mmu(&tlb);
683
684         return nr_updated;
685 }
686 #endif /* CONFIG_NUMA_BALANCING */
687
688 static int queue_pages_test_walk(unsigned long start, unsigned long end,
689                                 struct mm_walk *walk)
690 {
691         struct vm_area_struct *next, *vma = walk->vma;
692         struct queue_pages *qp = walk->private;
693         unsigned long flags = qp->flags;
694
695         /* range check first */
696         VM_BUG_ON_VMA(!range_in_vma(vma, start, end), vma);
697
698         if (!qp->first) {
699                 qp->first = vma;
700                 if (!(flags & MPOL_MF_DISCONTIG_OK) &&
701                         (qp->start < vma->vm_start))
702                         /* hole at head side of range */
703                         return -EFAULT;
704         }
705         next = find_vma(vma->vm_mm, vma->vm_end);
706         if (!(flags & MPOL_MF_DISCONTIG_OK) &&
707                 ((vma->vm_end < qp->end) &&
708                 (!next || vma->vm_end < next->vm_start)))
709                 /* hole at middle or tail of range */
710                 return -EFAULT;
711
712         /*
713          * Need check MPOL_MF_STRICT to return -EIO if possible
714          * regardless of vma_migratable
715          */
716         if (!vma_migratable(vma) &&
717             !(flags & MPOL_MF_STRICT))
718                 return 1;
719
720         /*
721          * Check page nodes, and queue pages to move, in the current vma.
722          * But if no moving, and no strict checking, the scan can be skipped.
723          */
724         if (flags & (MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
725                 return 0;
726         return 1;
727 }
728
729 static const struct mm_walk_ops queue_pages_walk_ops = {
730         .hugetlb_entry          = queue_folios_hugetlb,
731         .pmd_entry              = queue_folios_pte_range,
732         .test_walk              = queue_pages_test_walk,
733         .walk_lock              = PGWALK_RDLOCK,
734 };
735
736 static const struct mm_walk_ops queue_pages_lock_vma_walk_ops = {
737         .hugetlb_entry          = queue_folios_hugetlb,
738         .pmd_entry              = queue_folios_pte_range,
739         .test_walk              = queue_pages_test_walk,
740         .walk_lock              = PGWALK_WRLOCK,
741 };
742
743 /*
744  * Walk through page tables and collect pages to be migrated.
745  *
746  * If pages found in a given range are not on the required set of @nodes,
747  * and migration is allowed, they are isolated and queued to @pagelist.
748  *
749  * queue_pages_range() may return:
750  * 0 - all pages already on the right node, or successfully queued for moving
751  *     (or neither strict checking nor moving requested: only range checking).
752  * >0 - this number of misplaced folios could not be queued for moving
753  *      (a hugetlbfs page or a transparent huge page being counted as 1).
754  * -EIO - a misplaced page found, when MPOL_MF_STRICT specified without MOVEs.
755  * -EFAULT - a hole in the memory range, when MPOL_MF_DISCONTIG_OK unspecified.
756  */
757 static long
758 queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
759                 nodemask_t *nodes, unsigned long flags,
760                 struct list_head *pagelist)
761 {
762         int err;
763         struct queue_pages qp = {
764                 .pagelist = pagelist,
765                 .flags = flags,
766                 .nmask = nodes,
767                 .start = start,
768                 .end = end,
769                 .first = NULL,
770         };
771         const struct mm_walk_ops *ops = (flags & MPOL_MF_WRLOCK) ?
772                         &queue_pages_lock_vma_walk_ops : &queue_pages_walk_ops;
773
774         err = walk_page_range(mm, start, end, ops, &qp);
775
776         if (!qp.first)
777                 /* whole range in hole */
778                 err = -EFAULT;
779
780         return err ? : qp.nr_failed;
781 }
782
783 /*
784  * Apply policy to a single VMA
785  * This must be called with the mmap_lock held for writing.
786  */
787 static int vma_replace_policy(struct vm_area_struct *vma,
788                                 struct mempolicy *pol)
789 {
790         int err;
791         struct mempolicy *old;
792         struct mempolicy *new;
793
794         vma_assert_write_locked(vma);
795
796         new = mpol_dup(pol);
797         if (IS_ERR(new))
798                 return PTR_ERR(new);
799
800         if (vma->vm_ops && vma->vm_ops->set_policy) {
801                 err = vma->vm_ops->set_policy(vma, new);
802                 if (err)
803                         goto err_out;
804         }
805
806         old = vma->vm_policy;
807         vma->vm_policy = new; /* protected by mmap_lock */
808         mpol_put(old);
809
810         return 0;
811  err_out:
812         mpol_put(new);
813         return err;
814 }
815
816 /* Split or merge the VMA (if required) and apply the new policy */
817 static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
818                 struct vm_area_struct **prev, unsigned long start,
819                 unsigned long end, struct mempolicy *new_pol)
820 {
821         unsigned long vmstart, vmend;
822
823         vmend = min(end, vma->vm_end);
824         if (start > vma->vm_start) {
825                 *prev = vma;
826                 vmstart = start;
827         } else {
828                 vmstart = vma->vm_start;
829         }
830
831         if (mpol_equal(vma->vm_policy, new_pol)) {
832                 *prev = vma;
833                 return 0;
834         }
835
836         vma =  vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
837         if (IS_ERR(vma))
838                 return PTR_ERR(vma);
839
840         *prev = vma;
841         return vma_replace_policy(vma, new_pol);
842 }
843
844 /* Set the process memory policy */
845 static long do_set_mempolicy(unsigned short mode, unsigned short flags,
846                              nodemask_t *nodes)
847 {
848         struct mempolicy *new, *old;
849         NODEMASK_SCRATCH(scratch);
850         int ret;
851
852         if (!scratch)
853                 return -ENOMEM;
854
855         new = mpol_new(mode, flags, nodes);
856         if (IS_ERR(new)) {
857                 ret = PTR_ERR(new);
858                 goto out;
859         }
860
861         task_lock(current);
862         ret = mpol_set_nodemask(new, nodes, scratch);
863         if (ret) {
864                 task_unlock(current);
865                 mpol_put(new);
866                 goto out;
867         }
868
869         old = current->mempolicy;
870         current->mempolicy = new;
871         if (new && (new->mode == MPOL_INTERLEAVE ||
872                     new->mode == MPOL_WEIGHTED_INTERLEAVE)) {
873                 current->il_prev = MAX_NUMNODES-1;
874                 current->il_weight = 0;
875         }
876         task_unlock(current);
877         mpol_put(old);
878         ret = 0;
879 out:
880         NODEMASK_SCRATCH_FREE(scratch);
881         return ret;
882 }
883
884 /*
885  * Return nodemask for policy for get_mempolicy() query
886  *
887  * Called with task's alloc_lock held
888  */
889 static void get_policy_nodemask(struct mempolicy *pol, nodemask_t *nodes)
890 {
891         nodes_clear(*nodes);
892         if (pol == &default_policy)
893                 return;
894
895         switch (pol->mode) {
896         case MPOL_BIND:
897         case MPOL_INTERLEAVE:
898         case MPOL_PREFERRED:
899         case MPOL_PREFERRED_MANY:
900         case MPOL_WEIGHTED_INTERLEAVE:
901                 *nodes = pol->nodes;
902                 break;
903         case MPOL_LOCAL:
904                 /* return empty node mask for local allocation */
905                 break;
906         default:
907                 BUG();
908         }
909 }
910
911 static int lookup_node(struct mm_struct *mm, unsigned long addr)
912 {
913         struct page *p = NULL;
914         int ret;
915
916         ret = get_user_pages_fast(addr & PAGE_MASK, 1, 0, &p);
917         if (ret > 0) {
918                 ret = page_to_nid(p);
919                 put_page(p);
920         }
921         return ret;
922 }
923
924 /* Retrieve NUMA policy */
925 static long do_get_mempolicy(int *policy, nodemask_t *nmask,
926                              unsigned long addr, unsigned long flags)
927 {
928         int err;
929         struct mm_struct *mm = current->mm;
930         struct vm_area_struct *vma = NULL;
931         struct mempolicy *pol = current->mempolicy, *pol_refcount = NULL;
932
933         if (flags &
934                 ~(unsigned long)(MPOL_F_NODE|MPOL_F_ADDR|MPOL_F_MEMS_ALLOWED))
935                 return -EINVAL;
936
937         if (flags & MPOL_F_MEMS_ALLOWED) {
938                 if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
939                         return -EINVAL;
940                 *policy = 0;    /* just so it's initialized */
941                 task_lock(current);
942                 *nmask  = cpuset_current_mems_allowed;
943                 task_unlock(current);
944                 return 0;
945         }
946
947         if (flags & MPOL_F_ADDR) {
948                 pgoff_t ilx;            /* ignored here */
949                 /*
950                  * Do NOT fall back to task policy if the
951                  * vma/shared policy at addr is NULL.  We
952                  * want to return MPOL_DEFAULT in this case.
953                  */
954                 mmap_read_lock(mm);
955                 vma = vma_lookup(mm, addr);
956                 if (!vma) {
957                         mmap_read_unlock(mm);
958                         return -EFAULT;
959                 }
960                 pol = __get_vma_policy(vma, addr, &ilx);
961         } else if (addr)
962                 return -EINVAL;
963
964         if (!pol)
965                 pol = &default_policy;  /* indicates default behavior */
966
967         if (flags & MPOL_F_NODE) {
968                 if (flags & MPOL_F_ADDR) {
969                         /*
970                          * Take a refcount on the mpol, because we are about to
971                          * drop the mmap_lock, after which only "pol" remains
972                          * valid, "vma" is stale.
973                          */
974                         pol_refcount = pol;
975                         vma = NULL;
976                         mpol_get(pol);
977                         mmap_read_unlock(mm);
978                         err = lookup_node(mm, addr);
979                         if (err < 0)
980                                 goto out;
981                         *policy = err;
982                 } else if (pol == current->mempolicy &&
983                                 pol->mode == MPOL_INTERLEAVE) {
984                         *policy = next_node_in(current->il_prev, pol->nodes);
985                 } else if (pol == current->mempolicy &&
986                                 pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
987                         if (current->il_weight)
988                                 *policy = current->il_prev;
989                         else
990                                 *policy = next_node_in(current->il_prev,
991                                                        pol->nodes);
992                 } else {
993                         err = -EINVAL;
994                         goto out;
995                 }
996         } else {
997                 *policy = pol == &default_policy ? MPOL_DEFAULT :
998                                                 pol->mode;
999                 /*
1000                  * Internal mempolicy flags must be masked off before exposing
1001                  * the policy to userspace.
1002                  */
1003                 *policy |= (pol->flags & MPOL_MODE_FLAGS);
1004         }
1005
1006         err = 0;
1007         if (nmask) {
1008                 if (mpol_store_user_nodemask(pol)) {
1009                         *nmask = pol->w.user_nodemask;
1010                 } else {
1011                         task_lock(current);
1012                         get_policy_nodemask(pol, nmask);
1013                         task_unlock(current);
1014                 }
1015         }
1016
1017  out:
1018         mpol_cond_put(pol);
1019         if (vma)
1020                 mmap_read_unlock(mm);
1021         if (pol_refcount)
1022                 mpol_put(pol_refcount);
1023         return err;
1024 }
1025
1026 #ifdef CONFIG_MIGRATION
1027 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
1028                                 unsigned long flags)
1029 {
1030         /*
1031          * Unless MPOL_MF_MOVE_ALL, we try to avoid migrating a shared folio.
1032          * Choosing not to migrate a shared folio is not counted as a failure.
1033          *
1034          * See folio_likely_mapped_shared() on possible imprecision when we
1035          * cannot easily detect if a folio is shared.
1036          */
1037         if ((flags & MPOL_MF_MOVE_ALL) || !folio_likely_mapped_shared(folio)) {
1038                 if (folio_isolate_lru(folio)) {
1039                         list_add_tail(&folio->lru, foliolist);
1040                         node_stat_mod_folio(folio,
1041                                 NR_ISOLATED_ANON + folio_is_file_lru(folio),
1042                                 folio_nr_pages(folio));
1043                 } else {
1044                         /*
1045                          * Non-movable folio may reach here.  And, there may be
1046                          * temporary off LRU folios or non-LRU movable folios.
1047                          * Treat them as unmovable folios since they can't be
1048                          * isolated, so they can't be moved at the moment.
1049                          */
1050                         return false;
1051                 }
1052         }
1053         return true;
1054 }
1055
1056 /*
1057  * Migrate pages from one node to a target node.
1058  * Returns error or the number of pages not migrated.
1059  */
1060 static long migrate_to_node(struct mm_struct *mm, int source, int dest,
1061                             int flags)
1062 {
1063         nodemask_t nmask;
1064         struct vm_area_struct *vma;
1065         LIST_HEAD(pagelist);
1066         long nr_failed;
1067         long err = 0;
1068         struct migration_target_control mtc = {
1069                 .nid = dest,
1070                 .gfp_mask = GFP_HIGHUSER_MOVABLE | __GFP_THISNODE,
1071                 .reason = MR_SYSCALL,
1072         };
1073
1074         nodes_clear(nmask);
1075         node_set(source, nmask);
1076
1077         VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)));
1078
1079         mmap_read_lock(mm);
1080         vma = find_vma(mm, 0);
1081
1082         /*
1083          * This does not migrate the range, but isolates all pages that
1084          * need migration.  Between passing in the full user address
1085          * space range and MPOL_MF_DISCONTIG_OK, this call cannot fail,
1086          * but passes back the count of pages which could not be isolated.
1087          */
1088         nr_failed = queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask,
1089                                       flags | MPOL_MF_DISCONTIG_OK, &pagelist);
1090         mmap_read_unlock(mm);
1091
1092         if (!list_empty(&pagelist)) {
1093                 err = migrate_pages(&pagelist, alloc_migration_target, NULL,
1094                         (unsigned long)&mtc, MIGRATE_SYNC, MR_SYSCALL, NULL);
1095                 if (err)
1096                         putback_movable_pages(&pagelist);
1097         }
1098
1099         if (err >= 0)
1100                 err += nr_failed;
1101         return err;
1102 }
1103
1104 /*
1105  * Move pages between the two nodesets so as to preserve the physical
1106  * layout as much as possible.
1107  *
1108  * Returns the number of page that could not be moved.
1109  */
1110 int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
1111                      const nodemask_t *to, int flags)
1112 {
1113         long nr_failed = 0;
1114         long err = 0;
1115         nodemask_t tmp;
1116
1117         lru_cache_disable();
1118
1119         /*
1120          * Find a 'source' bit set in 'tmp' whose corresponding 'dest'
1121          * bit in 'to' is not also set in 'tmp'.  Clear the found 'source'
1122          * bit in 'tmp', and return that <source, dest> pair for migration.
1123          * The pair of nodemasks 'to' and 'from' define the map.
1124          *
1125          * If no pair of bits is found that way, fallback to picking some
1126          * pair of 'source' and 'dest' bits that are not the same.  If the
1127          * 'source' and 'dest' bits are the same, this represents a node
1128          * that will be migrating to itself, so no pages need move.
1129          *
1130          * If no bits are left in 'tmp', or if all remaining bits left
1131          * in 'tmp' correspond to the same bit in 'to', return false
1132          * (nothing left to migrate).
1133          *
1134          * This lets us pick a pair of nodes to migrate between, such that
1135          * if possible the dest node is not already occupied by some other
1136          * source node, minimizing the risk of overloading the memory on a
1137          * node that would happen if we migrated incoming memory to a node
1138          * before migrating outgoing memory source that same node.
1139          *
1140          * A single scan of tmp is sufficient.  As we go, we remember the
1141          * most recent <s, d> pair that moved (s != d).  If we find a pair
1142          * that not only moved, but what's better, moved to an empty slot
1143          * (d is not set in tmp), then we break out then, with that pair.
1144          * Otherwise when we finish scanning from_tmp, we at least have the
1145          * most recent <s, d> pair that moved.  If we get all the way through
1146          * the scan of tmp without finding any node that moved, much less
1147          * moved to an empty node, then there is nothing left worth migrating.
1148          */
1149
1150         tmp = *from;
1151         while (!nodes_empty(tmp)) {
1152                 int s, d;
1153                 int source = NUMA_NO_NODE;
1154                 int dest = 0;
1155
1156                 for_each_node_mask(s, tmp) {
1157
1158                         /*
1159                          * do_migrate_pages() tries to maintain the relative
1160                          * node relationship of the pages established between
1161                          * threads and memory areas.
1162                          *
1163                          * However if the number of source nodes is not equal to
1164                          * the number of destination nodes we can not preserve
1165                          * this node relative relationship.  In that case, skip
1166                          * copying memory from a node that is in the destination
1167                          * mask.
1168                          *
1169                          * Example: [2,3,4] -> [3,4,5] moves everything.
1170                          *          [0-7] - > [3,4,5] moves only 0,1,2,6,7.
1171                          */
1172
1173                         if ((nodes_weight(*from) != nodes_weight(*to)) &&
1174                                                 (node_isset(s, *to)))
1175                                 continue;
1176
1177                         d = node_remap(s, *from, *to);
1178                         if (s == d)
1179                                 continue;
1180
1181                         source = s;     /* Node moved. Memorize */
1182                         dest = d;
1183
1184                         /* dest not in remaining from nodes? */
1185                         if (!node_isset(dest, tmp))
1186                                 break;
1187                 }
1188                 if (source == NUMA_NO_NODE)
1189                         break;
1190
1191                 node_clear(source, tmp);
1192                 err = migrate_to_node(mm, source, dest, flags);
1193                 if (err > 0)
1194                         nr_failed += err;
1195                 if (err < 0)
1196                         break;
1197         }
1198
1199         lru_cache_enable();
1200         if (err < 0)
1201                 return err;
1202         return (nr_failed < INT_MAX) ? nr_failed : INT_MAX;
1203 }
1204
1205 /*
1206  * Allocate a new folio for page migration, according to NUMA mempolicy.
1207  */
1208 static struct folio *alloc_migration_target_by_mpol(struct folio *src,
1209                                                     unsigned long private)
1210 {
1211         struct migration_mpol *mmpol = (struct migration_mpol *)private;
1212         struct mempolicy *pol = mmpol->pol;
1213         pgoff_t ilx = mmpol->ilx;
1214         unsigned int order;
1215         int nid = numa_node_id();
1216         gfp_t gfp;
1217
1218         order = folio_order(src);
1219         ilx += src->index >> order;
1220
1221         if (folio_test_hugetlb(src)) {
1222                 nodemask_t *nodemask;
1223                 struct hstate *h;
1224
1225                 h = folio_hstate(src);
1226                 gfp = htlb_alloc_mask(h);
1227                 nodemask = policy_nodemask(gfp, pol, ilx, &nid);
1228                 return alloc_hugetlb_folio_nodemask(h, nid, nodemask, gfp,
1229                                 htlb_allow_alloc_fallback(MR_MEMPOLICY_MBIND));
1230         }
1231
1232         if (folio_test_large(src))
1233                 gfp = GFP_TRANSHUGE;
1234         else
1235                 gfp = GFP_HIGHUSER_MOVABLE | __GFP_RETRY_MAYFAIL | __GFP_COMP;
1236
1237         return folio_alloc_mpol(gfp, order, pol, ilx, nid);
1238 }
1239 #else
1240
1241 static bool migrate_folio_add(struct folio *folio, struct list_head *foliolist,
1242                                 unsigned long flags)
1243 {
1244         return false;
1245 }
1246
1247 int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
1248                      const nodemask_t *to, int flags)
1249 {
1250         return -ENOSYS;
1251 }
1252
1253 static struct folio *alloc_migration_target_by_mpol(struct folio *src,
1254                                                     unsigned long private)
1255 {
1256         return NULL;
1257 }
1258 #endif
1259
1260 static long do_mbind(unsigned long start, unsigned long len,
1261                      unsigned short mode, unsigned short mode_flags,
1262                      nodemask_t *nmask, unsigned long flags)
1263 {
1264         struct mm_struct *mm = current->mm;
1265         struct vm_area_struct *vma, *prev;
1266         struct vma_iterator vmi;
1267         struct migration_mpol mmpol;
1268         struct mempolicy *new;
1269         unsigned long end;
1270         long err;
1271         long nr_failed;
1272         LIST_HEAD(pagelist);
1273
1274         if (flags & ~(unsigned long)MPOL_MF_VALID)
1275                 return -EINVAL;
1276         if ((flags & MPOL_MF_MOVE_ALL) && !capable(CAP_SYS_NICE))
1277                 return -EPERM;
1278
1279         if (start & ~PAGE_MASK)
1280                 return -EINVAL;
1281
1282         if (mode == MPOL_DEFAULT)
1283                 flags &= ~MPOL_MF_STRICT;
1284
1285         len = PAGE_ALIGN(len);
1286         end = start + len;
1287
1288         if (end < start)
1289                 return -EINVAL;
1290         if (end == start)
1291                 return 0;
1292
1293         new = mpol_new(mode, mode_flags, nmask);
1294         if (IS_ERR(new))
1295                 return PTR_ERR(new);
1296
1297         /*
1298          * If we are using the default policy then operation
1299          * on discontinuous address spaces is okay after all
1300          */
1301         if (!new)
1302                 flags |= MPOL_MF_DISCONTIG_OK;
1303
1304         if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
1305                 lru_cache_disable();
1306         {
1307                 NODEMASK_SCRATCH(scratch);
1308                 if (scratch) {
1309                         mmap_write_lock(mm);
1310                         err = mpol_set_nodemask(new, nmask, scratch);
1311                         if (err)
1312                                 mmap_write_unlock(mm);
1313                 } else
1314                         err = -ENOMEM;
1315                 NODEMASK_SCRATCH_FREE(scratch);
1316         }
1317         if (err)
1318                 goto mpol_out;
1319
1320         /*
1321          * Lock the VMAs before scanning for pages to migrate,
1322          * to ensure we don't miss a concurrently inserted page.
1323          */
1324         nr_failed = queue_pages_range(mm, start, end, nmask,
1325                         flags | MPOL_MF_INVERT | MPOL_MF_WRLOCK, &pagelist);
1326
1327         if (nr_failed < 0) {
1328                 err = nr_failed;
1329                 nr_failed = 0;
1330         } else {
1331                 vma_iter_init(&vmi, mm, start);
1332                 prev = vma_prev(&vmi);
1333                 for_each_vma_range(vmi, vma, end) {
1334                         err = mbind_range(&vmi, vma, &prev, start, end, new);
1335                         if (err)
1336                                 break;
1337                 }
1338         }
1339
1340         if (!err && !list_empty(&pagelist)) {
1341                 /* Convert MPOL_DEFAULT's NULL to task or default policy */
1342                 if (!new) {
1343                         new = get_task_policy(current);
1344                         mpol_get(new);
1345                 }
1346                 mmpol.pol = new;
1347                 mmpol.ilx = 0;
1348
1349                 /*
1350                  * In the interleaved case, attempt to allocate on exactly the
1351                  * targeted nodes, for the first VMA to be migrated; for later
1352                  * VMAs, the nodes will still be interleaved from the targeted
1353                  * nodemask, but one by one may be selected differently.
1354                  */
1355                 if (new->mode == MPOL_INTERLEAVE ||
1356                     new->mode == MPOL_WEIGHTED_INTERLEAVE) {
1357                         struct folio *folio;
1358                         unsigned int order;
1359                         unsigned long addr = -EFAULT;
1360
1361                         list_for_each_entry(folio, &pagelist, lru) {
1362                                 if (!folio_test_ksm(folio))
1363                                         break;
1364                         }
1365                         if (!list_entry_is_head(folio, &pagelist, lru)) {
1366                                 vma_iter_init(&vmi, mm, start);
1367                                 for_each_vma_range(vmi, vma, end) {
1368                                         addr = page_address_in_vma(
1369                                                 folio_page(folio, 0), vma);
1370                                         if (addr != -EFAULT)
1371                                                 break;
1372                                 }
1373                         }
1374                         if (addr != -EFAULT) {
1375                                 order = folio_order(folio);
1376                                 /* We already know the pol, but not the ilx */
1377                                 mpol_cond_put(get_vma_policy(vma, addr, order,
1378                                                              &mmpol.ilx));
1379                                 /* Set base from which to increment by index */
1380                                 mmpol.ilx -= folio->index >> order;
1381                         }
1382                 }
1383         }
1384
1385         mmap_write_unlock(mm);
1386
1387         if (!err && !list_empty(&pagelist)) {
1388                 nr_failed |= migrate_pages(&pagelist,
1389                                 alloc_migration_target_by_mpol, NULL,
1390                                 (unsigned long)&mmpol, MIGRATE_SYNC,
1391                                 MR_MEMPOLICY_MBIND, NULL);
1392         }
1393
1394         if (nr_failed && (flags & MPOL_MF_STRICT))
1395                 err = -EIO;
1396         if (!list_empty(&pagelist))
1397                 putback_movable_pages(&pagelist);
1398 mpol_out:
1399         mpol_put(new);
1400         if (flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL))
1401                 lru_cache_enable();
1402         return err;
1403 }
1404
1405 /*
1406  * User space interface with variable sized bitmaps for nodelists.
1407  */
1408 static int get_bitmap(unsigned long *mask, const unsigned long __user *nmask,
1409                       unsigned long maxnode)
1410 {
1411         unsigned long nlongs = BITS_TO_LONGS(maxnode);
1412         int ret;
1413
1414         if (in_compat_syscall())
1415                 ret = compat_get_bitmap(mask,
1416                                         (const compat_ulong_t __user *)nmask,
1417                                         maxnode);
1418         else
1419                 ret = copy_from_user(mask, nmask,
1420                                      nlongs * sizeof(unsigned long));
1421
1422         if (ret)
1423                 return -EFAULT;
1424
1425         if (maxnode % BITS_PER_LONG)
1426                 mask[nlongs - 1] &= (1UL << (maxnode % BITS_PER_LONG)) - 1;
1427
1428         return 0;
1429 }
1430
1431 /* Copy a node mask from user space. */
1432 static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
1433                      unsigned long maxnode)
1434 {
1435         --maxnode;
1436         nodes_clear(*nodes);
1437         if (maxnode == 0 || !nmask)
1438                 return 0;
1439         if (maxnode > PAGE_SIZE*BITS_PER_BYTE)
1440                 return -EINVAL;
1441
1442         /*
1443          * When the user specified more nodes than supported just check
1444          * if the non supported part is all zero, one word at a time,
1445          * starting at the end.
1446          */
1447         while (maxnode > MAX_NUMNODES) {
1448                 unsigned long bits = min_t(unsigned long, maxnode, BITS_PER_LONG);
1449                 unsigned long t;
1450
1451                 if (get_bitmap(&t, &nmask[(maxnode - 1) / BITS_PER_LONG], bits))
1452                         return -EFAULT;
1453
1454                 if (maxnode - bits >= MAX_NUMNODES) {
1455                         maxnode -= bits;
1456                 } else {
1457                         maxnode = MAX_NUMNODES;
1458                         t &= ~((1UL << (MAX_NUMNODES % BITS_PER_LONG)) - 1);
1459                 }
1460                 if (t)
1461                         return -EINVAL;
1462         }
1463
1464         return get_bitmap(nodes_addr(*nodes), nmask, maxnode);
1465 }
1466
1467 /* Copy a kernel node mask to user space */
1468 static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
1469                               nodemask_t *nodes)
1470 {
1471         unsigned long copy = ALIGN(maxnode-1, 64) / 8;
1472         unsigned int nbytes = BITS_TO_LONGS(nr_node_ids) * sizeof(long);
1473         bool compat = in_compat_syscall();
1474
1475         if (compat)
1476                 nbytes = BITS_TO_COMPAT_LONGS(nr_node_ids) * sizeof(compat_long_t);
1477
1478         if (copy > nbytes) {
1479                 if (copy > PAGE_SIZE)
1480                         return -EINVAL;
1481                 if (clear_user((char __user *)mask + nbytes, copy - nbytes))
1482                         return -EFAULT;
1483                 copy = nbytes;
1484                 maxnode = nr_node_ids;
1485         }
1486
1487         if (compat)
1488                 return compat_put_bitmap((compat_ulong_t __user *)mask,
1489                                          nodes_addr(*nodes), maxnode);
1490
1491         return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
1492 }
1493
1494 /* Basic parameter sanity check used by both mbind() and set_mempolicy() */
1495 static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
1496 {
1497         *flags = *mode & MPOL_MODE_FLAGS;
1498         *mode &= ~MPOL_MODE_FLAGS;
1499
1500         if ((unsigned int)(*mode) >=  MPOL_MAX)
1501                 return -EINVAL;
1502         if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
1503                 return -EINVAL;
1504         if (*flags & MPOL_F_NUMA_BALANCING) {
1505                 if (*mode == MPOL_BIND || *mode == MPOL_PREFERRED_MANY)
1506                         *flags |= (MPOL_F_MOF | MPOL_F_MORON);
1507                 else
1508                         return -EINVAL;
1509         }
1510         return 0;
1511 }
1512
1513 static long kernel_mbind(unsigned long start, unsigned long len,
1514                          unsigned long mode, const unsigned long __user *nmask,
1515                          unsigned long maxnode, unsigned int flags)
1516 {
1517         unsigned short mode_flags;
1518         nodemask_t nodes;
1519         int lmode = mode;
1520         int err;
1521
1522         start = untagged_addr(start);
1523         err = sanitize_mpol_flags(&lmode, &mode_flags);
1524         if (err)
1525                 return err;
1526
1527         err = get_nodes(&nodes, nmask, maxnode);
1528         if (err)
1529                 return err;
1530
1531         return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
1532 }
1533
1534 SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, len,
1535                 unsigned long, home_node, unsigned long, flags)
1536 {
1537         struct mm_struct *mm = current->mm;
1538         struct vm_area_struct *vma, *prev;
1539         struct mempolicy *new, *old;
1540         unsigned long end;
1541         int err = -ENOENT;
1542         VMA_ITERATOR(vmi, mm, start);
1543
1544         start = untagged_addr(start);
1545         if (start & ~PAGE_MASK)
1546                 return -EINVAL;
1547         /*
1548          * flags is used for future extension if any.
1549          */
1550         if (flags != 0)
1551                 return -EINVAL;
1552
1553         /*
1554          * Check home_node is online to avoid accessing uninitialized
1555          * NODE_DATA.
1556          */
1557         if (home_node >= MAX_NUMNODES || !node_online(home_node))
1558                 return -EINVAL;
1559
1560         len = PAGE_ALIGN(len);
1561         end = start + len;
1562
1563         if (end < start)
1564                 return -EINVAL;
1565         if (end == start)
1566                 return 0;
1567         mmap_write_lock(mm);
1568         prev = vma_prev(&vmi);
1569         for_each_vma_range(vmi, vma, end) {
1570                 /*
1571                  * If any vma in the range got policy other than MPOL_BIND
1572                  * or MPOL_PREFERRED_MANY we return error. We don't reset
1573                  * the home node for vmas we already updated before.
1574                  */
1575                 old = vma_policy(vma);
1576                 if (!old) {
1577                         prev = vma;
1578                         continue;
1579                 }
1580                 if (old->mode != MPOL_BIND && old->mode != MPOL_PREFERRED_MANY) {
1581                         err = -EOPNOTSUPP;
1582                         break;
1583                 }
1584                 new = mpol_dup(old);
1585                 if (IS_ERR(new)) {
1586                         err = PTR_ERR(new);
1587                         break;
1588                 }
1589
1590                 vma_start_write(vma);
1591                 new->home_node = home_node;
1592                 err = mbind_range(&vmi, vma, &prev, start, end, new);
1593                 mpol_put(new);
1594                 if (err)
1595                         break;
1596         }
1597         mmap_write_unlock(mm);
1598         return err;
1599 }
1600
1601 SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
1602                 unsigned long, mode, const unsigned long __user *, nmask,
1603                 unsigned long, maxnode, unsigned int, flags)
1604 {
1605         return kernel_mbind(start, len, mode, nmask, maxnode, flags);
1606 }
1607
1608 /* Set the process memory policy */
1609 static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
1610                                  unsigned long maxnode)
1611 {
1612         unsigned short mode_flags;
1613         nodemask_t nodes;
1614         int lmode = mode;
1615         int err;
1616
1617         err = sanitize_mpol_flags(&lmode, &mode_flags);
1618         if (err)
1619                 return err;
1620
1621         err = get_nodes(&nodes, nmask, maxnode);
1622         if (err)
1623                 return err;
1624
1625         return do_set_mempolicy(lmode, mode_flags, &nodes);
1626 }
1627
1628 SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
1629                 unsigned long, maxnode)
1630 {
1631         return kernel_set_mempolicy(mode, nmask, maxnode);
1632 }
1633
1634 static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
1635                                 const unsigned long __user *old_nodes,
1636                                 const unsigned long __user *new_nodes)
1637 {
1638         struct mm_struct *mm = NULL;
1639         struct task_struct *task;
1640         nodemask_t task_nodes;
1641         int err;
1642         nodemask_t *old;
1643         nodemask_t *new;
1644         NODEMASK_SCRATCH(scratch);
1645
1646         if (!scratch)
1647                 return -ENOMEM;
1648
1649         old = &scratch->mask1;
1650         new = &scratch->mask2;
1651
1652         err = get_nodes(old, old_nodes, maxnode);
1653         if (err)
1654                 goto out;
1655
1656         err = get_nodes(new, new_nodes, maxnode);
1657         if (err)
1658                 goto out;
1659
1660         /* Find the mm_struct */
1661         rcu_read_lock();
1662         task = pid ? find_task_by_vpid(pid) : current;
1663         if (!task) {
1664                 rcu_read_unlock();
1665                 err = -ESRCH;
1666                 goto out;
1667         }
1668         get_task_struct(task);
1669
1670         err = -EINVAL;
1671
1672         /*
1673          * Check if this process has the right to modify the specified process.
1674          * Use the regular "ptrace_may_access()" checks.
1675          */
1676         if (!ptrace_may_access(task, PTRACE_MODE_READ_REALCREDS)) {
1677                 rcu_read_unlock();
1678                 err = -EPERM;
1679                 goto out_put;
1680         }
1681         rcu_read_unlock();
1682
1683         task_nodes = cpuset_mems_allowed(task);
1684         /* Is the user allowed to access the target nodes? */
1685         if (!nodes_subset(*new, task_nodes) && !capable(CAP_SYS_NICE)) {
1686                 err = -EPERM;
1687                 goto out_put;
1688         }
1689
1690         task_nodes = cpuset_mems_allowed(current);
1691         nodes_and(*new, *new, task_nodes);
1692         if (nodes_empty(*new))
1693                 goto out_put;
1694
1695         err = security_task_movememory(task);
1696         if (err)
1697                 goto out_put;
1698
1699         mm = get_task_mm(task);
1700         put_task_struct(task);
1701
1702         if (!mm) {
1703                 err = -EINVAL;
1704                 goto out;
1705         }
1706
1707         err = do_migrate_pages(mm, old, new,
1708                 capable(CAP_SYS_NICE) ? MPOL_MF_MOVE_ALL : MPOL_MF_MOVE);
1709
1710         mmput(mm);
1711 out:
1712         NODEMASK_SCRATCH_FREE(scratch);
1713
1714         return err;
1715
1716 out_put:
1717         put_task_struct(task);
1718         goto out;
1719 }
1720
1721 SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
1722                 const unsigned long __user *, old_nodes,
1723                 const unsigned long __user *, new_nodes)
1724 {
1725         return kernel_migrate_pages(pid, maxnode, old_nodes, new_nodes);
1726 }
1727
1728 /* Retrieve NUMA policy */
1729 static int kernel_get_mempolicy(int __user *policy,
1730                                 unsigned long __user *nmask,
1731                                 unsigned long maxnode,
1732                                 unsigned long addr,
1733                                 unsigned long flags)
1734 {
1735         int err;
1736         int pval;
1737         nodemask_t nodes;
1738
1739         if (nmask != NULL && maxnode < nr_node_ids)
1740                 return -EINVAL;
1741
1742         addr = untagged_addr(addr);
1743
1744         err = do_get_mempolicy(&pval, &nodes, addr, flags);
1745
1746         if (err)
1747                 return err;
1748
1749         if (policy && put_user(pval, policy))
1750                 return -EFAULT;
1751
1752         if (nmask)
1753                 err = copy_nodes_to_user(nmask, maxnode, &nodes);
1754
1755         return err;
1756 }
1757
1758 SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
1759                 unsigned long __user *, nmask, unsigned long, maxnode,
1760                 unsigned long, addr, unsigned long, flags)
1761 {
1762         return kernel_get_mempolicy(policy, nmask, maxnode, addr, flags);
1763 }
1764
1765 bool vma_migratable(struct vm_area_struct *vma)
1766 {
1767         if (vma->vm_flags & (VM_IO | VM_PFNMAP))
1768                 return false;
1769
1770         /*
1771          * DAX device mappings require predictable access latency, so avoid
1772          * incurring periodic faults.
1773          */
1774         if (vma_is_dax(vma))
1775                 return false;
1776
1777         if (is_vm_hugetlb_page(vma) &&
1778                 !hugepage_migration_supported(hstate_vma(vma)))
1779                 return false;
1780
1781         /*
1782          * Migration allocates pages in the highest zone. If we cannot
1783          * do so then migration (at least from node to node) is not
1784          * possible.
1785          */
1786         if (vma->vm_file &&
1787                 gfp_zone(mapping_gfp_mask(vma->vm_file->f_mapping))
1788                         < policy_zone)
1789                 return false;
1790         return true;
1791 }
1792
1793 struct mempolicy *__get_vma_policy(struct vm_area_struct *vma,
1794                                    unsigned long addr, pgoff_t *ilx)
1795 {
1796         *ilx = 0;
1797         return (vma->vm_ops && vma->vm_ops->get_policy) ?
1798                 vma->vm_ops->get_policy(vma, addr, ilx) : vma->vm_policy;
1799 }
1800
1801 /*
1802  * get_vma_policy(@vma, @addr, @order, @ilx)
1803  * @vma: virtual memory area whose policy is sought
1804  * @addr: address in @vma for shared policy lookup
1805  * @order: 0, or appropriate huge_page_order for interleaving
1806  * @ilx: interleave index (output), for use only when MPOL_INTERLEAVE or
1807  *       MPOL_WEIGHTED_INTERLEAVE
1808  *
1809  * Returns effective policy for a VMA at specified address.
1810  * Falls back to current->mempolicy or system default policy, as necessary.
1811  * Shared policies [those marked as MPOL_F_SHARED] require an extra reference
1812  * count--added by the get_policy() vm_op, as appropriate--to protect against
1813  * freeing by another task.  It is the caller's responsibility to free the
1814  * extra reference for shared policies.
1815  */
1816 struct mempolicy *get_vma_policy(struct vm_area_struct *vma,
1817                                  unsigned long addr, int order, pgoff_t *ilx)
1818 {
1819         struct mempolicy *pol;
1820
1821         pol = __get_vma_policy(vma, addr, ilx);
1822         if (!pol)
1823                 pol = get_task_policy(current);
1824         if (pol->mode == MPOL_INTERLEAVE ||
1825             pol->mode == MPOL_WEIGHTED_INTERLEAVE) {
1826                 *ilx += vma->vm_pgoff >> order;
1827                 *ilx += (addr - vma->vm_start) >> (PAGE_SHIFT + order);
1828         }
1829         return pol;
1830 }
1831
1832 bool vma_policy_mof(struct vm_area_struct *vma)
1833 {
1834         struct mempolicy *pol;
1835
1836         if (vma->vm_ops && vma->vm_ops->get_policy) {
1837                 bool ret = false;
1838                 pgoff_t ilx;            /* ignored here */
1839
1840                 pol = vma->vm_ops->get_policy(vma, vma->vm_start, &ilx);
1841                 if (pol && (pol->flags & MPOL_F_MOF))
1842                         ret = true;
1843                 mpol_cond_put(pol);
1844
1845                 return ret;
1846         }
1847
1848         pol = vma->vm_policy;
1849         if (!pol)
1850                 pol = get_task_policy(current);
1851
1852         return pol->flags & MPOL_F_MOF;
1853 }
1854
1855 bool apply_policy_zone(struct mempolicy *policy, enum zone_type zone)
1856 {
1857         enum zone_type dynamic_policy_zone = policy_zone;
1858
1859         BUG_ON(dynamic_policy_zone == ZONE_MOVABLE);
1860
1861         /*
1862          * if policy->nodes has movable memory only,
1863          * we apply policy when gfp_zone(gfp) = ZONE_MOVABLE only.
1864          *
1865          * policy->nodes is intersect with node_states[N_MEMORY].
1866          * so if the following test fails, it implies
1867          * policy->nodes has movable memory only.
1868          */
1869         if (!nodes_intersects(policy->nodes, node_states[N_HIGH_MEMORY]))
1870                 dynamic_policy_zone = ZONE_MOVABLE;
1871
1872         return zone >= dynamic_policy_zone;
1873 }
1874
1875 static unsigned int weighted_interleave_nodes(struct mempolicy *policy)
1876 {
1877         unsigned int node;
1878         unsigned int cpuset_mems_cookie;
1879
1880 retry:
1881         /* to prevent miscount use tsk->mems_allowed_seq to detect rebind */
1882         cpuset_mems_cookie = read_mems_allowed_begin();
1883         node = current->il_prev;
1884         if (!current->il_weight || !node_isset(node, policy->nodes)) {
1885                 node = next_node_in(node, policy->nodes);
1886                 if (read_mems_allowed_retry(cpuset_mems_cookie))
1887                         goto retry;
1888                 if (node == MAX_NUMNODES)
1889                         return node;
1890                 current->il_prev = node;
1891                 current->il_weight = get_il_weight(node);
1892         }
1893         current->il_weight--;
1894         return node;
1895 }
1896
1897 /* Do dynamic interleaving for a process */
1898 static unsigned int interleave_nodes(struct mempolicy *policy)
1899 {
1900         unsigned int nid;
1901         unsigned int cpuset_mems_cookie;
1902
1903         /* to prevent miscount, use tsk->mems_allowed_seq to detect rebind */
1904         do {
1905                 cpuset_mems_cookie = read_mems_allowed_begin();
1906                 nid = next_node_in(current->il_prev, policy->nodes);
1907         } while (read_mems_allowed_retry(cpuset_mems_cookie));
1908
1909         if (nid < MAX_NUMNODES)
1910                 current->il_prev = nid;
1911         return nid;
1912 }
1913
1914 /*
1915  * Depending on the memory policy provide a node from which to allocate the
1916  * next slab entry.
1917  */
1918 unsigned int mempolicy_slab_node(void)
1919 {
1920         struct mempolicy *policy;
1921         int node = numa_mem_id();
1922
1923         if (!in_task())
1924                 return node;
1925
1926         policy = current->mempolicy;
1927         if (!policy)
1928                 return node;
1929
1930         switch (policy->mode) {
1931         case MPOL_PREFERRED:
1932                 return first_node(policy->nodes);
1933
1934         case MPOL_INTERLEAVE:
1935                 return interleave_nodes(policy);
1936
1937         case MPOL_WEIGHTED_INTERLEAVE:
1938                 return weighted_interleave_nodes(policy);
1939
1940         case MPOL_BIND:
1941         case MPOL_PREFERRED_MANY:
1942         {
1943                 struct zoneref *z;
1944
1945                 /*
1946                  * Follow bind policy behavior and start allocation at the
1947                  * first node.
1948                  */
1949                 struct zonelist *zonelist;
1950                 enum zone_type highest_zoneidx = gfp_zone(GFP_KERNEL);
1951                 zonelist = &NODE_DATA(node)->node_zonelists[ZONELIST_FALLBACK];
1952                 z = first_zones_zonelist(zonelist, highest_zoneidx,
1953                                                         &policy->nodes);
1954                 return z->zone ? zone_to_nid(z->zone) : node;
1955         }
1956         case MPOL_LOCAL:
1957                 return node;
1958
1959         default:
1960                 BUG();
1961         }
1962 }
1963
1964 static unsigned int read_once_policy_nodemask(struct mempolicy *pol,
1965                                               nodemask_t *mask)
1966 {
1967         /*
1968          * barrier stabilizes the nodemask locally so that it can be iterated
1969          * over safely without concern for changes. Allocators validate node
1970          * selection does not violate mems_allowed, so this is safe.
1971          */
1972         barrier();
1973         memcpy(mask, &pol->nodes, sizeof(nodemask_t));
1974         barrier();
1975         return nodes_weight(*mask);
1976 }
1977
1978 static unsigned int weighted_interleave_nid(struct mempolicy *pol, pgoff_t ilx)
1979 {
1980         nodemask_t nodemask;
1981         unsigned int target, nr_nodes;
1982         u8 *table;
1983         unsigned int weight_total = 0;
1984         u8 weight;
1985         int nid;
1986
1987         nr_nodes = read_once_policy_nodemask(pol, &nodemask);
1988         if (!nr_nodes)
1989                 return numa_node_id();
1990
1991         rcu_read_lock();
1992         table = rcu_dereference(iw_table);
1993         /* calculate the total weight */
1994         for_each_node_mask(nid, nodemask) {
1995                 /* detect system default usage */
1996                 weight = table ? table[nid] : 1;
1997                 weight = weight ? weight : 1;
1998                 weight_total += weight;
1999         }
2000
2001         /* Calculate the node offset based on totals */
2002         target = ilx % weight_total;
2003         nid = first_node(nodemask);
2004         while (target) {
2005                 /* detect system default usage */
2006                 weight = table ? table[nid] : 1;
2007                 weight = weight ? weight : 1;
2008                 if (target < weight)
2009                         break;
2010                 target -= weight;
2011                 nid = next_node_in(nid, nodemask);
2012         }
2013         rcu_read_unlock();
2014         return nid;
2015 }
2016
2017 /*
2018  * Do static interleaving for interleave index @ilx.  Returns the ilx'th
2019  * node in pol->nodes (starting from ilx=0), wrapping around if ilx
2020  * exceeds the number of present nodes.
2021  */
2022 static unsigned int interleave_nid(struct mempolicy *pol, pgoff_t ilx)
2023 {
2024         nodemask_t nodemask;
2025         unsigned int target, nnodes;
2026         int i;
2027         int nid;
2028
2029         nnodes = read_once_policy_nodemask(pol, &nodemask);
2030         if (!nnodes)
2031                 return numa_node_id();
2032         target = ilx % nnodes;
2033         nid = first_node(nodemask);
2034         for (i = 0; i < target; i++)
2035                 nid = next_node(nid, nodemask);
2036         return nid;
2037 }
2038
2039 /*
2040  * Return a nodemask representing a mempolicy for filtering nodes for
2041  * page allocation, together with preferred node id (or the input node id).
2042  */
2043 static nodemask_t *policy_nodemask(gfp_t gfp, struct mempolicy *pol,
2044                                    pgoff_t ilx, int *nid)
2045 {
2046         nodemask_t *nodemask = NULL;
2047
2048         switch (pol->mode) {
2049         case MPOL_PREFERRED:
2050                 /* Override input node id */
2051                 *nid = first_node(pol->nodes);
2052                 break;
2053         case MPOL_PREFERRED_MANY:
2054                 nodemask = &pol->nodes;
2055                 if (pol->home_node != NUMA_NO_NODE)
2056                         *nid = pol->home_node;
2057                 break;
2058         case MPOL_BIND:
2059                 /* Restrict to nodemask (but not on lower zones) */
2060                 if (apply_policy_zone(pol, gfp_zone(gfp)) &&
2061                     cpuset_nodemask_valid_mems_allowed(&pol->nodes))
2062                         nodemask = &pol->nodes;
2063                 if (pol->home_node != NUMA_NO_NODE)
2064                         *nid = pol->home_node;
2065                 /*
2066                  * __GFP_THISNODE shouldn't even be used with the bind policy
2067                  * because we might easily break the expectation to stay on the
2068                  * requested node and not break the policy.
2069                  */
2070                 WARN_ON_ONCE(gfp & __GFP_THISNODE);
2071                 break;
2072         case MPOL_INTERLEAVE:
2073                 /* Override input node id */
2074                 *nid = (ilx == NO_INTERLEAVE_INDEX) ?
2075                         interleave_nodes(pol) : interleave_nid(pol, ilx);
2076                 break;
2077         case MPOL_WEIGHTED_INTERLEAVE:
2078                 *nid = (ilx == NO_INTERLEAVE_INDEX) ?
2079                         weighted_interleave_nodes(pol) :
2080                         weighted_interleave_nid(pol, ilx);
2081                 break;
2082         }
2083
2084         return nodemask;
2085 }
2086
2087 #ifdef CONFIG_HUGETLBFS
2088 /*
2089  * huge_node(@vma, @addr, @gfp_flags, @mpol)
2090  * @vma: virtual memory area whose policy is sought
2091  * @addr: address in @vma for shared policy lookup and interleave policy
2092  * @gfp_flags: for requested zone
2093  * @mpol: pointer to mempolicy pointer for reference counted mempolicy
2094  * @nodemask: pointer to nodemask pointer for 'bind' and 'prefer-many' policy
2095  *
2096  * Returns a nid suitable for a huge page allocation and a pointer
2097  * to the struct mempolicy for conditional unref after allocation.
2098  * If the effective policy is 'bind' or 'prefer-many', returns a pointer
2099  * to the mempolicy's @nodemask for filtering the zonelist.
2100  */
2101 int huge_node(struct vm_area_struct *vma, unsigned long addr, gfp_t gfp_flags,
2102                 struct mempolicy **mpol, nodemask_t **nodemask)
2103 {
2104         pgoff_t ilx;
2105         int nid;
2106
2107         nid = numa_node_id();
2108         *mpol = get_vma_policy(vma, addr, hstate_vma(vma)->order, &ilx);
2109         *nodemask = policy_nodemask(gfp_flags, *mpol, ilx, &nid);
2110         return nid;
2111 }
2112
2113 /*
2114  * init_nodemask_of_mempolicy
2115  *
2116  * If the current task's mempolicy is "default" [NULL], return 'false'
2117  * to indicate default policy.  Otherwise, extract the policy nodemask
2118  * for 'bind' or 'interleave' policy into the argument nodemask, or
2119  * initialize the argument nodemask to contain the single node for
2120  * 'preferred' or 'local' policy and return 'true' to indicate presence
2121  * of non-default mempolicy.
2122  *
2123  * We don't bother with reference counting the mempolicy [mpol_get/put]
2124  * because the current task is examining it's own mempolicy and a task's
2125  * mempolicy is only ever changed by the task itself.
2126  *
2127  * N.B., it is the caller's responsibility to free a returned nodemask.
2128  */
2129 bool init_nodemask_of_mempolicy(nodemask_t *mask)
2130 {
2131         struct mempolicy *mempolicy;
2132
2133         if (!(mask && current->mempolicy))
2134                 return false;
2135
2136         task_lock(current);
2137         mempolicy = current->mempolicy;
2138         switch (mempolicy->mode) {
2139         case MPOL_PREFERRED:
2140         case MPOL_PREFERRED_MANY:
2141         case MPOL_BIND:
2142         case MPOL_INTERLEAVE:
2143         case MPOL_WEIGHTED_INTERLEAVE:
2144                 *mask = mempolicy->nodes;
2145                 break;
2146
2147         case MPOL_LOCAL:
2148                 init_nodemask_of_node(mask, numa_node_id());
2149                 break;
2150
2151         default:
2152                 BUG();
2153         }
2154         task_unlock(current);
2155
2156         return true;
2157 }
2158 #endif
2159
2160 /*
2161  * mempolicy_in_oom_domain
2162  *
2163  * If tsk's mempolicy is "bind", check for intersection between mask and
2164  * the policy nodemask. Otherwise, return true for all other policies
2165  * including "interleave", as a tsk with "interleave" policy may have
2166  * memory allocated from all nodes in system.
2167  *
2168  * Takes task_lock(tsk) to prevent freeing of its mempolicy.
2169  */
2170 bool mempolicy_in_oom_domain(struct task_struct *tsk,
2171                                         const nodemask_t *mask)
2172 {
2173         struct mempolicy *mempolicy;
2174         bool ret = true;
2175
2176         if (!mask)
2177                 return ret;
2178
2179         task_lock(tsk);
2180         mempolicy = tsk->mempolicy;
2181         if (mempolicy && mempolicy->mode == MPOL_BIND)
2182                 ret = nodes_intersects(mempolicy->nodes, *mask);
2183         task_unlock(tsk);
2184
2185         return ret;
2186 }
2187
2188 static struct page *alloc_pages_preferred_many(gfp_t gfp, unsigned int order,
2189                                                 int nid, nodemask_t *nodemask)
2190 {
2191         struct page *page;
2192         gfp_t preferred_gfp;
2193
2194         /*
2195          * This is a two pass approach. The first pass will only try the
2196          * preferred nodes but skip the direct reclaim and allow the
2197          * allocation to fail, while the second pass will try all the
2198          * nodes in system.
2199          */
2200         preferred_gfp = gfp | __GFP_NOWARN;
2201         preferred_gfp &= ~(__GFP_DIRECT_RECLAIM | __GFP_NOFAIL);
2202         page = __alloc_pages_noprof(preferred_gfp, order, nid, nodemask);
2203         if (!page)
2204                 page = __alloc_pages_noprof(gfp, order, nid, NULL);
2205
2206         return page;
2207 }
2208
2209 /**
2210  * alloc_pages_mpol - Allocate pages according to NUMA mempolicy.
2211  * @gfp: GFP flags.
2212  * @order: Order of the page allocation.
2213  * @pol: Pointer to the NUMA mempolicy.
2214  * @ilx: Index for interleave mempolicy (also distinguishes alloc_pages()).
2215  * @nid: Preferred node (usually numa_node_id() but @mpol may override it).
2216  *
2217  * Return: The page on success or NULL if allocation fails.
2218  */
2219 struct page *alloc_pages_mpol_noprof(gfp_t gfp, unsigned int order,
2220                 struct mempolicy *pol, pgoff_t ilx, int nid)
2221 {
2222         nodemask_t *nodemask;
2223         struct page *page;
2224
2225         nodemask = policy_nodemask(gfp, pol, ilx, &nid);
2226
2227         if (pol->mode == MPOL_PREFERRED_MANY)
2228                 return alloc_pages_preferred_many(gfp, order, nid, nodemask);
2229
2230         if (IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) &&
2231             /* filter "hugepage" allocation, unless from alloc_pages() */
2232             order == HPAGE_PMD_ORDER && ilx != NO_INTERLEAVE_INDEX) {
2233                 /*
2234                  * For hugepage allocation and non-interleave policy which
2235                  * allows the current node (or other explicitly preferred
2236                  * node) we only try to allocate from the current/preferred
2237                  * node and don't fall back to other nodes, as the cost of
2238                  * remote accesses would likely offset THP benefits.
2239                  *
2240                  * If the policy is interleave or does not allow the current
2241                  * node in its nodemask, we allocate the standard way.
2242                  */
2243                 if (pol->mode != MPOL_INTERLEAVE &&
2244                     pol->mode != MPOL_WEIGHTED_INTERLEAVE &&
2245                     (!nodemask || node_isset(nid, *nodemask))) {
2246                         /*
2247                          * First, try to allocate THP only on local node, but
2248                          * don't reclaim unnecessarily, just compact.
2249                          */
2250                         page = __alloc_pages_node_noprof(nid,
2251                                 gfp | __GFP_THISNODE | __GFP_NORETRY, order);
2252                         if (page || !(gfp & __GFP_DIRECT_RECLAIM))
2253                                 return page;
2254                         /*
2255                          * If hugepage allocations are configured to always
2256                          * synchronous compact or the vma has been madvised
2257                          * to prefer hugepage backing, retry allowing remote
2258                          * memory with both reclaim and compact as well.
2259                          */
2260                 }
2261         }
2262
2263         page = __alloc_pages_noprof(gfp, order, nid, nodemask);
2264
2265         if (unlikely(pol->mode == MPOL_INTERLEAVE) && page) {
2266                 /* skip NUMA_INTERLEAVE_HIT update if numa stats is disabled */
2267                 if (static_branch_likely(&vm_numa_stat_key) &&
2268                     page_to_nid(page) == nid) {
2269                         preempt_disable();
2270                         __count_numa_event(page_zone(page), NUMA_INTERLEAVE_HIT);
2271                         preempt_enable();
2272                 }
2273         }
2274
2275         return page;
2276 }
2277
2278 struct folio *folio_alloc_mpol_noprof(gfp_t gfp, unsigned int order,
2279                 struct mempolicy *pol, pgoff_t ilx, int nid)
2280 {
2281         return page_rmappable_folio(alloc_pages_mpol_noprof(gfp | __GFP_COMP,
2282                                                         order, pol, ilx, nid));
2283 }
2284
2285 /**
2286  * vma_alloc_folio - Allocate a folio for a VMA.
2287  * @gfp: GFP flags.
2288  * @order: Order of the folio.
2289  * @vma: Pointer to VMA.
2290  * @addr: Virtual address of the allocation.  Must be inside @vma.
2291  * @hugepage: Unused (was: For hugepages try only preferred node if possible).
2292  *
2293  * Allocate a folio for a specific address in @vma, using the appropriate
2294  * NUMA policy.  The caller must hold the mmap_lock of the mm_struct of the
2295  * VMA to prevent it from going away.  Should be used for all allocations
2296  * for folios that will be mapped into user space, excepting hugetlbfs, and
2297  * excepting where direct use of alloc_pages_mpol() is more appropriate.
2298  *
2299  * Return: The folio on success or NULL if allocation fails.
2300  */
2301 struct folio *vma_alloc_folio_noprof(gfp_t gfp, int order, struct vm_area_struct *vma,
2302                 unsigned long addr, bool hugepage)
2303 {
2304         struct mempolicy *pol;
2305         pgoff_t ilx;
2306         struct folio *folio;
2307
2308         if (vma->vm_flags & VM_DROPPABLE)
2309                 gfp |= __GFP_NOWARN;
2310
2311         pol = get_vma_policy(vma, addr, order, &ilx);
2312         folio = folio_alloc_mpol_noprof(gfp, order, pol, ilx, numa_node_id());
2313         mpol_cond_put(pol);
2314         return folio;
2315 }
2316 EXPORT_SYMBOL(vma_alloc_folio_noprof);
2317
2318 /**
2319  * alloc_pages - Allocate pages.
2320  * @gfp: GFP flags.
2321  * @order: Power of two of number of pages to allocate.
2322  *
2323  * Allocate 1 << @order contiguous pages.  The physical address of the
2324  * first page is naturally aligned (eg an order-3 allocation will be aligned
2325  * to a multiple of 8 * PAGE_SIZE bytes).  The NUMA policy of the current
2326  * process is honoured when in process context.
2327  *
2328  * Context: Can be called from any context, providing the appropriate GFP
2329  * flags are used.
2330  * Return: The page on success or NULL if allocation fails.
2331  */
2332 struct page *alloc_pages_noprof(gfp_t gfp, unsigned int order)
2333 {
2334         struct mempolicy *pol = &default_policy;
2335
2336         /*
2337          * No reference counting needed for current->mempolicy
2338          * nor system default_policy
2339          */
2340         if (!in_interrupt() && !(gfp & __GFP_THISNODE))
2341                 pol = get_task_policy(current);
2342
2343         return alloc_pages_mpol_noprof(gfp, order, pol, NO_INTERLEAVE_INDEX,
2344                                        numa_node_id());
2345 }
2346 EXPORT_SYMBOL(alloc_pages_noprof);
2347
2348 struct folio *folio_alloc_noprof(gfp_t gfp, unsigned int order)
2349 {
2350         return page_rmappable_folio(alloc_pages_noprof(gfp | __GFP_COMP, order));
2351 }
2352 EXPORT_SYMBOL(folio_alloc_noprof);
2353
2354 static unsigned long alloc_pages_bulk_array_interleave(gfp_t gfp,
2355                 struct mempolicy *pol, unsigned long nr_pages,
2356                 struct page **page_array)
2357 {
2358         int nodes;
2359         unsigned long nr_pages_per_node;
2360         int delta;
2361         int i;
2362         unsigned long nr_allocated;
2363         unsigned long total_allocated = 0;
2364
2365         nodes = nodes_weight(pol->nodes);
2366         nr_pages_per_node = nr_pages / nodes;
2367         delta = nr_pages - nodes * nr_pages_per_node;
2368
2369         for (i = 0; i < nodes; i++) {
2370                 if (delta) {
2371                         nr_allocated = alloc_pages_bulk_noprof(gfp,
2372                                         interleave_nodes(pol), NULL,
2373                                         nr_pages_per_node + 1, NULL,
2374                                         page_array);
2375                         delta--;
2376                 } else {
2377                         nr_allocated = alloc_pages_bulk_noprof(gfp,
2378                                         interleave_nodes(pol), NULL,
2379                                         nr_pages_per_node, NULL, page_array);
2380                 }
2381
2382                 page_array += nr_allocated;
2383                 total_allocated += nr_allocated;
2384         }
2385
2386         return total_allocated;
2387 }
2388
2389 static unsigned long alloc_pages_bulk_array_weighted_interleave(gfp_t gfp,
2390                 struct mempolicy *pol, unsigned long nr_pages,
2391                 struct page **page_array)
2392 {
2393         struct task_struct *me = current;
2394         unsigned int cpuset_mems_cookie;
2395         unsigned long total_allocated = 0;
2396         unsigned long nr_allocated = 0;
2397         unsigned long rounds;
2398         unsigned long node_pages, delta;
2399         u8 *table, *weights, weight;
2400         unsigned int weight_total = 0;
2401         unsigned long rem_pages = nr_pages;
2402         nodemask_t nodes;
2403         int nnodes, node;
2404         int resume_node = MAX_NUMNODES - 1;
2405         u8 resume_weight = 0;
2406         int prev_node;
2407         int i;
2408
2409         if (!nr_pages)
2410                 return 0;
2411
2412         /* read the nodes onto the stack, retry if done during rebind */
2413         do {
2414                 cpuset_mems_cookie = read_mems_allowed_begin();
2415                 nnodes = read_once_policy_nodemask(pol, &nodes);
2416         } while (read_mems_allowed_retry(cpuset_mems_cookie));
2417
2418         /* if the nodemask has become invalid, we cannot do anything */
2419         if (!nnodes)
2420                 return 0;
2421
2422         /* Continue allocating from most recent node and adjust the nr_pages */
2423         node = me->il_prev;
2424         weight = me->il_weight;
2425         if (weight && node_isset(node, nodes)) {
2426                 node_pages = min(rem_pages, weight);
2427                 nr_allocated = __alloc_pages_bulk(gfp, node, NULL, node_pages,
2428                                                   NULL, page_array);
2429                 page_array += nr_allocated;
2430                 total_allocated += nr_allocated;
2431                 /* if that's all the pages, no need to interleave */
2432                 if (rem_pages <= weight) {
2433                         me->il_weight -= rem_pages;
2434                         return total_allocated;
2435                 }
2436                 /* Otherwise we adjust remaining pages, continue from there */
2437                 rem_pages -= weight;
2438         }
2439         /* clear active weight in case of an allocation failure */
2440         me->il_weight = 0;
2441         prev_node = node;
2442
2443         /* create a local copy of node weights to operate on outside rcu */
2444         weights = kzalloc(nr_node_ids, GFP_KERNEL);
2445         if (!weights)
2446                 return total_allocated;
2447
2448         rcu_read_lock();
2449         table = rcu_dereference(iw_table);
2450         if (table)
2451                 memcpy(weights, table, nr_node_ids);
2452         rcu_read_unlock();
2453
2454         /* calculate total, detect system default usage */
2455         for_each_node_mask(node, nodes) {
2456                 if (!weights[node])
2457                         weights[node] = 1;
2458                 weight_total += weights[node];
2459         }
2460
2461         /*
2462          * Calculate rounds/partial rounds to minimize __alloc_pages_bulk calls.
2463          * Track which node weighted interleave should resume from.
2464          *
2465          * if (rounds > 0) and (delta == 0), resume_node will always be
2466          * the node following prev_node and its weight.
2467          */
2468         rounds = rem_pages / weight_total;
2469         delta = rem_pages % weight_total;
2470         resume_node = next_node_in(prev_node, nodes);
2471         resume_weight = weights[resume_node];
2472         for (i = 0; i < nnodes; i++) {
2473                 node = next_node_in(prev_node, nodes);
2474                 weight = weights[node];
2475                 node_pages = weight * rounds;
2476                 /* If a delta exists, add this node's portion of the delta */
2477                 if (delta > weight) {
2478                         node_pages += weight;
2479                         delta -= weight;
2480                 } else if (delta) {
2481                         /* when delta is depleted, resume from that node */
2482                         node_pages += delta;
2483                         resume_node = node;
2484                         resume_weight = weight - delta;
2485                         delta = 0;
2486                 }
2487                 /* node_pages can be 0 if an allocation fails and rounds == 0 */
2488                 if (!node_pages)
2489                         break;
2490                 nr_allocated = __alloc_pages_bulk(gfp, node, NULL, node_pages,
2491                                                   NULL, page_array);
2492                 page_array += nr_allocated;
2493                 total_allocated += nr_allocated;
2494                 if (total_allocated == nr_pages)
2495                         break;
2496                 prev_node = node;
2497         }
2498         me->il_prev = resume_node;
2499         me->il_weight = resume_weight;
2500         kfree(weights);
2501         return total_allocated;
2502 }
2503
2504 static unsigned long alloc_pages_bulk_array_preferred_many(gfp_t gfp, int nid,
2505                 struct mempolicy *pol, unsigned long nr_pages,
2506                 struct page **page_array)
2507 {
2508         gfp_t preferred_gfp;
2509         unsigned long nr_allocated = 0;
2510
2511         preferred_gfp = gfp | __GFP_NOWARN;
2512         preferred_gfp &= ~(__GFP_DIRECT_RECLAIM | __GFP_NOFAIL);
2513
2514         nr_allocated  = alloc_pages_bulk_noprof(preferred_gfp, nid, &pol->nodes,
2515                                            nr_pages, NULL, page_array);
2516
2517         if (nr_allocated < nr_pages)
2518                 nr_allocated += alloc_pages_bulk_noprof(gfp, numa_node_id(), NULL,
2519                                 nr_pages - nr_allocated, NULL,
2520                                 page_array + nr_allocated);
2521         return nr_allocated;
2522 }
2523
2524 /* alloc pages bulk and mempolicy should be considered at the
2525  * same time in some situation such as vmalloc.
2526  *
2527  * It can accelerate memory allocation especially interleaving
2528  * allocate memory.
2529  */
2530 unsigned long alloc_pages_bulk_array_mempolicy_noprof(gfp_t gfp,
2531                 unsigned long nr_pages, struct page **page_array)
2532 {
2533         struct mempolicy *pol = &default_policy;
2534         nodemask_t *nodemask;
2535         int nid;
2536
2537         if (!in_interrupt() && !(gfp & __GFP_THISNODE))
2538                 pol = get_task_policy(current);
2539
2540         if (pol->mode == MPOL_INTERLEAVE)
2541                 return alloc_pages_bulk_array_interleave(gfp, pol,
2542                                                          nr_pages, page_array);
2543
2544         if (pol->mode == MPOL_WEIGHTED_INTERLEAVE)
2545                 return alloc_pages_bulk_array_weighted_interleave(
2546                                   gfp, pol, nr_pages, page_array);
2547
2548         if (pol->mode == MPOL_PREFERRED_MANY)
2549                 return alloc_pages_bulk_array_preferred_many(gfp,
2550                                 numa_node_id(), pol, nr_pages, page_array);
2551
2552         nid = numa_node_id();
2553         nodemask = policy_nodemask(gfp, pol, NO_INTERLEAVE_INDEX, &nid);
2554         return alloc_pages_bulk_noprof(gfp, nid, nodemask,
2555                                        nr_pages, NULL, page_array);
2556 }
2557
2558 int vma_dup_policy(struct vm_area_struct *src, struct vm_area_struct *dst)
2559 {
2560         struct mempolicy *pol = mpol_dup(src->vm_policy);
2561
2562         if (IS_ERR(pol))
2563                 return PTR_ERR(pol);
2564         dst->vm_policy = pol;
2565         return 0;
2566 }
2567
2568 /*
2569  * If mpol_dup() sees current->cpuset == cpuset_being_rebound, then it
2570  * rebinds the mempolicy its copying by calling mpol_rebind_policy()
2571  * with the mems_allowed returned by cpuset_mems_allowed().  This
2572  * keeps mempolicies cpuset relative after its cpuset moves.  See
2573  * further kernel/cpuset.c update_nodemask().
2574  *
2575  * current's mempolicy may be rebinded by the other task(the task that changes
2576  * cpuset's mems), so we needn't do rebind work for current task.
2577  */
2578
2579 /* Slow path of a mempolicy duplicate */
2580 struct mempolicy *__mpol_dup(struct mempolicy *old)
2581 {
2582         struct mempolicy *new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
2583
2584         if (!new)
2585                 return ERR_PTR(-ENOMEM);
2586
2587         /* task's mempolicy is protected by alloc_lock */
2588         if (old == current->mempolicy) {
2589                 task_lock(current);
2590                 *new = *old;
2591                 task_unlock(current);
2592         } else
2593                 *new = *old;
2594
2595         if (current_cpuset_is_being_rebound()) {
2596                 nodemask_t mems = cpuset_mems_allowed(current);
2597                 mpol_rebind_policy(new, &mems);
2598         }
2599         atomic_set(&new->refcnt, 1);
2600         return new;
2601 }
2602
2603 /* Slow path of a mempolicy comparison */
2604 bool __mpol_equal(struct mempolicy *a, struct mempolicy *b)
2605 {
2606         if (!a || !b)
2607                 return false;
2608         if (a->mode != b->mode)
2609                 return false;
2610         if (a->flags != b->flags)
2611                 return false;
2612         if (a->home_node != b->home_node)
2613                 return false;
2614         if (mpol_store_user_nodemask(a))
2615                 if (!nodes_equal(a->w.user_nodemask, b->w.user_nodemask))
2616                         return false;
2617
2618         switch (a->mode) {
2619         case MPOL_BIND:
2620         case MPOL_INTERLEAVE:
2621         case MPOL_PREFERRED:
2622         case MPOL_PREFERRED_MANY:
2623         case MPOL_WEIGHTED_INTERLEAVE:
2624                 return !!nodes_equal(a->nodes, b->nodes);
2625         case MPOL_LOCAL:
2626                 return true;
2627         default:
2628                 BUG();
2629                 return false;
2630         }
2631 }
2632
2633 /*
2634  * Shared memory backing store policy support.
2635  *
2636  * Remember policies even when nobody has shared memory mapped.
2637  * The policies are kept in Red-Black tree linked from the inode.
2638  * They are protected by the sp->lock rwlock, which should be held
2639  * for any accesses to the tree.
2640  */
2641
2642 /*
2643  * lookup first element intersecting start-end.  Caller holds sp->lock for
2644  * reading or for writing
2645  */
2646 static struct sp_node *sp_lookup(struct shared_policy *sp,
2647                                         pgoff_t start, pgoff_t end)
2648 {
2649         struct rb_node *n = sp->root.rb_node;
2650
2651         while (n) {
2652                 struct sp_node *p = rb_entry(n, struct sp_node, nd);
2653
2654                 if (start >= p->end)
2655                         n = n->rb_right;
2656                 else if (end <= p->start)
2657                         n = n->rb_left;
2658                 else
2659                         break;
2660         }
2661         if (!n)
2662                 return NULL;
2663         for (;;) {
2664                 struct sp_node *w = NULL;
2665                 struct rb_node *prev = rb_prev(n);
2666                 if (!prev)
2667                         break;
2668                 w = rb_entry(prev, struct sp_node, nd);
2669                 if (w->end <= start)
2670                         break;
2671                 n = prev;
2672         }
2673         return rb_entry(n, struct sp_node, nd);
2674 }
2675
2676 /*
2677  * Insert a new shared policy into the list.  Caller holds sp->lock for
2678  * writing.
2679  */
2680 static void sp_insert(struct shared_policy *sp, struct sp_node *new)
2681 {
2682         struct rb_node **p = &sp->root.rb_node;
2683         struct rb_node *parent = NULL;
2684         struct sp_node *nd;
2685
2686         while (*p) {
2687                 parent = *p;
2688                 nd = rb_entry(parent, struct sp_node, nd);
2689                 if (new->start < nd->start)
2690                         p = &(*p)->rb_left;
2691                 else if (new->end > nd->end)
2692                         p = &(*p)->rb_right;
2693                 else
2694                         BUG();
2695         }
2696         rb_link_node(&new->nd, parent, p);
2697         rb_insert_color(&new->nd, &sp->root);
2698 }
2699
2700 /* Find shared policy intersecting idx */
2701 struct mempolicy *mpol_shared_policy_lookup(struct shared_policy *sp,
2702                                                 pgoff_t idx)
2703 {
2704         struct mempolicy *pol = NULL;
2705         struct sp_node *sn;
2706
2707         if (!sp->root.rb_node)
2708                 return NULL;
2709         read_lock(&sp->lock);
2710         sn = sp_lookup(sp, idx, idx+1);
2711         if (sn) {
2712                 mpol_get(sn->policy);
2713                 pol = sn->policy;
2714         }
2715         read_unlock(&sp->lock);
2716         return pol;
2717 }
2718
2719 static void sp_free(struct sp_node *n)
2720 {
2721         mpol_put(n->policy);
2722         kmem_cache_free(sn_cache, n);
2723 }
2724
2725 /**
2726  * mpol_misplaced - check whether current folio node is valid in policy
2727  *
2728  * @folio: folio to be checked
2729  * @vmf: structure describing the fault
2730  * @addr: virtual address in @vma for shared policy lookup and interleave policy
2731  *
2732  * Lookup current policy node id for vma,addr and "compare to" folio's
2733  * node id.  Policy determination "mimics" alloc_page_vma().
2734  * Called from fault path where we know the vma and faulting address.
2735  *
2736  * Return: NUMA_NO_NODE if the page is in a node that is valid for this
2737  * policy, or a suitable node ID to allocate a replacement folio from.
2738  */
2739 int mpol_misplaced(struct folio *folio, struct vm_fault *vmf,
2740                    unsigned long addr)
2741 {
2742         struct mempolicy *pol;
2743         pgoff_t ilx;
2744         struct zoneref *z;
2745         int curnid = folio_nid(folio);
2746         struct vm_area_struct *vma = vmf->vma;
2747         int thiscpu = raw_smp_processor_id();
2748         int thisnid = numa_node_id();
2749         int polnid = NUMA_NO_NODE;
2750         int ret = NUMA_NO_NODE;
2751
2752         /*
2753          * Make sure ptl is held so that we don't preempt and we
2754          * have a stable smp processor id
2755          */
2756         lockdep_assert_held(vmf->ptl);
2757         pol = get_vma_policy(vma, addr, folio_order(folio), &ilx);
2758         if (!(pol->flags & MPOL_F_MOF))
2759                 goto out;
2760
2761         switch (pol->mode) {
2762         case MPOL_INTERLEAVE:
2763                 polnid = interleave_nid(pol, ilx);
2764                 break;
2765
2766         case MPOL_WEIGHTED_INTERLEAVE:
2767                 polnid = weighted_interleave_nid(pol, ilx);
2768                 break;
2769
2770         case MPOL_PREFERRED:
2771                 if (node_isset(curnid, pol->nodes))
2772                         goto out;
2773                 polnid = first_node(pol->nodes);
2774                 break;
2775
2776         case MPOL_LOCAL:
2777                 polnid = numa_node_id();
2778                 break;
2779
2780         case MPOL_BIND:
2781         case MPOL_PREFERRED_MANY:
2782                 /*
2783                  * Even though MPOL_PREFERRED_MANY can allocate pages outside
2784                  * policy nodemask we don't allow numa migration to nodes
2785                  * outside policy nodemask for now. This is done so that if we
2786                  * want demotion to slow memory to happen, before allocating
2787                  * from some DRAM node say 'x', we will end up using a
2788                  * MPOL_PREFERRED_MANY mask excluding node 'x'. In such scenario
2789                  * we should not promote to node 'x' from slow memory node.
2790                  */
2791                 if (pol->flags & MPOL_F_MORON) {
2792                         /*
2793                          * Optimize placement among multiple nodes
2794                          * via NUMA balancing
2795                          */
2796                         if (node_isset(thisnid, pol->nodes))
2797                                 break;
2798                         goto out;
2799                 }
2800
2801                 /*
2802                  * use current page if in policy nodemask,
2803                  * else select nearest allowed node, if any.
2804                  * If no allowed nodes, use current [!misplaced].
2805                  */
2806                 if (node_isset(curnid, pol->nodes))
2807                         goto out;
2808                 z = first_zones_zonelist(
2809                                 node_zonelist(thisnid, GFP_HIGHUSER),
2810                                 gfp_zone(GFP_HIGHUSER),
2811                                 &pol->nodes);
2812                 polnid = zone_to_nid(z->zone);
2813                 break;
2814
2815         default:
2816                 BUG();
2817         }
2818
2819         /* Migrate the folio towards the node whose CPU is referencing it */
2820         if (pol->flags & MPOL_F_MORON) {
2821                 polnid = thisnid;
2822
2823                 if (!should_numa_migrate_memory(current, folio, curnid,
2824                                                 thiscpu))
2825                         goto out;
2826         }
2827
2828         if (curnid != polnid)
2829                 ret = polnid;
2830 out:
2831         mpol_cond_put(pol);
2832
2833         return ret;
2834 }
2835
2836 /*
2837  * Drop the (possibly final) reference to task->mempolicy.  It needs to be
2838  * dropped after task->mempolicy is set to NULL so that any allocation done as
2839  * part of its kmem_cache_free(), such as by KASAN, doesn't reference a freed
2840  * policy.
2841  */
2842 void mpol_put_task_policy(struct task_struct *task)
2843 {
2844         struct mempolicy *pol;
2845
2846         task_lock(task);
2847         pol = task->mempolicy;
2848         task->mempolicy = NULL;
2849         task_unlock(task);
2850         mpol_put(pol);
2851 }
2852
2853 static void sp_delete(struct shared_policy *sp, struct sp_node *n)
2854 {
2855         rb_erase(&n->nd, &sp->root);
2856         sp_free(n);
2857 }
2858
2859 static void sp_node_init(struct sp_node *node, unsigned long start,
2860                         unsigned long end, struct mempolicy *pol)
2861 {
2862         node->start = start;
2863         node->end = end;
2864         node->policy = pol;
2865 }
2866
2867 static struct sp_node *sp_alloc(unsigned long start, unsigned long end,
2868                                 struct mempolicy *pol)
2869 {
2870         struct sp_node *n;
2871         struct mempolicy *newpol;
2872
2873         n = kmem_cache_alloc(sn_cache, GFP_KERNEL);
2874         if (!n)
2875                 return NULL;
2876
2877         newpol = mpol_dup(pol);
2878         if (IS_ERR(newpol)) {
2879                 kmem_cache_free(sn_cache, n);
2880                 return NULL;
2881         }
2882         newpol->flags |= MPOL_F_SHARED;
2883         sp_node_init(n, start, end, newpol);
2884
2885         return n;
2886 }
2887
2888 /* Replace a policy range. */
2889 static int shared_policy_replace(struct shared_policy *sp, pgoff_t start,
2890                                  pgoff_t end, struct sp_node *new)
2891 {
2892         struct sp_node *n;
2893         struct sp_node *n_new = NULL;
2894         struct mempolicy *mpol_new = NULL;
2895         int ret = 0;
2896
2897 restart:
2898         write_lock(&sp->lock);
2899         n = sp_lookup(sp, start, end);
2900         /* Take care of old policies in the same range. */
2901         while (n && n->start < end) {
2902                 struct rb_node *next = rb_next(&n->nd);
2903                 if (n->start >= start) {
2904                         if (n->end <= end)
2905                                 sp_delete(sp, n);
2906                         else
2907                                 n->start = end;
2908                 } else {
2909                         /* Old policy spanning whole new range. */
2910                         if (n->end > end) {
2911                                 if (!n_new)
2912                                         goto alloc_new;
2913
2914                                 *mpol_new = *n->policy;
2915                                 atomic_set(&mpol_new->refcnt, 1);
2916                                 sp_node_init(n_new, end, n->end, mpol_new);
2917                                 n->end = start;
2918                                 sp_insert(sp, n_new);
2919                                 n_new = NULL;
2920                                 mpol_new = NULL;
2921                                 break;
2922                         } else
2923                                 n->end = start;
2924                 }
2925                 if (!next)
2926                         break;
2927                 n = rb_entry(next, struct sp_node, nd);
2928         }
2929         if (new)
2930                 sp_insert(sp, new);
2931         write_unlock(&sp->lock);
2932         ret = 0;
2933
2934 err_out:
2935         if (mpol_new)
2936                 mpol_put(mpol_new);
2937         if (n_new)
2938                 kmem_cache_free(sn_cache, n_new);
2939
2940         return ret;
2941
2942 alloc_new:
2943         write_unlock(&sp->lock);
2944         ret = -ENOMEM;
2945         n_new = kmem_cache_alloc(sn_cache, GFP_KERNEL);
2946         if (!n_new)
2947                 goto err_out;
2948         mpol_new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
2949         if (!mpol_new)
2950                 goto err_out;
2951         atomic_set(&mpol_new->refcnt, 1);
2952         goto restart;
2953 }
2954
2955 /**
2956  * mpol_shared_policy_init - initialize shared policy for inode
2957  * @sp: pointer to inode shared policy
2958  * @mpol:  struct mempolicy to install
2959  *
2960  * Install non-NULL @mpol in inode's shared policy rb-tree.
2961  * On entry, the current task has a reference on a non-NULL @mpol.
2962  * This must be released on exit.
2963  * This is called at get_inode() calls and we can use GFP_KERNEL.
2964  */
2965 void mpol_shared_policy_init(struct shared_policy *sp, struct mempolicy *mpol)
2966 {
2967         int ret;
2968
2969         sp->root = RB_ROOT;             /* empty tree == default mempolicy */
2970         rwlock_init(&sp->lock);
2971
2972         if (mpol) {
2973                 struct sp_node *sn;
2974                 struct mempolicy *npol;
2975                 NODEMASK_SCRATCH(scratch);
2976
2977                 if (!scratch)
2978                         goto put_mpol;
2979
2980                 /* contextualize the tmpfs mount point mempolicy to this file */
2981                 npol = mpol_new(mpol->mode, mpol->flags, &mpol->w.user_nodemask);
2982                 if (IS_ERR(npol))
2983                         goto free_scratch; /* no valid nodemask intersection */
2984
2985                 task_lock(current);
2986                 ret = mpol_set_nodemask(npol, &mpol->w.user_nodemask, scratch);
2987                 task_unlock(current);
2988                 if (ret)
2989                         goto put_npol;
2990
2991                 /* alloc node covering entire file; adds ref to file's npol */
2992                 sn = sp_alloc(0, MAX_LFS_FILESIZE >> PAGE_SHIFT, npol);
2993                 if (sn)
2994                         sp_insert(sp, sn);
2995 put_npol:
2996                 mpol_put(npol); /* drop initial ref on file's npol */
2997 free_scratch:
2998                 NODEMASK_SCRATCH_FREE(scratch);
2999 put_mpol:
3000                 mpol_put(mpol); /* drop our incoming ref on sb mpol */
3001         }
3002 }
3003
3004 int mpol_set_shared_policy(struct shared_policy *sp,
3005                         struct vm_area_struct *vma, struct mempolicy *pol)
3006 {
3007         int err;
3008         struct sp_node *new = NULL;
3009         unsigned long sz = vma_pages(vma);
3010
3011         if (pol) {
3012                 new = sp_alloc(vma->vm_pgoff, vma->vm_pgoff + sz, pol);
3013                 if (!new)
3014                         return -ENOMEM;
3015         }
3016         err = shared_policy_replace(sp, vma->vm_pgoff, vma->vm_pgoff + sz, new);
3017         if (err && new)
3018                 sp_free(new);
3019         return err;
3020 }
3021
3022 /* Free a backing policy store on inode delete. */
3023 void mpol_free_shared_policy(struct shared_policy *sp)
3024 {
3025         struct sp_node *n;
3026         struct rb_node *next;
3027
3028         if (!sp->root.rb_node)
3029                 return;
3030         write_lock(&sp->lock);
3031         next = rb_first(&sp->root);
3032         while (next) {
3033                 n = rb_entry(next, struct sp_node, nd);
3034                 next = rb_next(&n->nd);
3035                 sp_delete(sp, n);
3036         }
3037         write_unlock(&sp->lock);
3038 }
3039
3040 #ifdef CONFIG_NUMA_BALANCING
3041 static int __initdata numabalancing_override;
3042
3043 static void __init check_numabalancing_enable(void)
3044 {
3045         bool numabalancing_default = false;
3046
3047         if (IS_ENABLED(CONFIG_NUMA_BALANCING_DEFAULT_ENABLED))
3048                 numabalancing_default = true;
3049
3050         /* Parsed by setup_numabalancing. override == 1 enables, -1 disables */
3051         if (numabalancing_override)
3052                 set_numabalancing_state(numabalancing_override == 1);
3053
3054         if (num_online_nodes() > 1 && !numabalancing_override) {
3055                 pr_info("%s automatic NUMA balancing. Configure with numa_balancing= or the kernel.numa_balancing sysctl\n",
3056                         numabalancing_default ? "Enabling" : "Disabling");
3057                 set_numabalancing_state(numabalancing_default);
3058         }
3059 }
3060
3061 static int __init setup_numabalancing(char *str)
3062 {
3063         int ret = 0;
3064         if (!str)
3065                 goto out;
3066
3067         if (!strcmp(str, "enable")) {
3068                 numabalancing_override = 1;
3069                 ret = 1;
3070         } else if (!strcmp(str, "disable")) {
3071                 numabalancing_override = -1;
3072                 ret = 1;
3073         }
3074 out:
3075         if (!ret)
3076                 pr_warn("Unable to parse numa_balancing=\n");
3077
3078         return ret;
3079 }
3080 __setup("numa_balancing=", setup_numabalancing);
3081 #else
3082 static inline void __init check_numabalancing_enable(void)
3083 {
3084 }
3085 #endif /* CONFIG_NUMA_BALANCING */
3086
3087 void __init numa_policy_init(void)
3088 {
3089         nodemask_t interleave_nodes;
3090         unsigned long largest = 0;
3091         int nid, prefer = 0;
3092
3093         policy_cache = kmem_cache_create("numa_policy",
3094                                          sizeof(struct mempolicy),
3095                                          0, SLAB_PANIC, NULL);
3096
3097         sn_cache = kmem_cache_create("shared_policy_node",
3098                                      sizeof(struct sp_node),
3099                                      0, SLAB_PANIC, NULL);
3100
3101         for_each_node(nid) {
3102                 preferred_node_policy[nid] = (struct mempolicy) {
3103                         .refcnt = ATOMIC_INIT(1),
3104                         .mode = MPOL_PREFERRED,
3105                         .flags = MPOL_F_MOF | MPOL_F_MORON,
3106                         .nodes = nodemask_of_node(nid),
3107                 };
3108         }
3109
3110         /*
3111          * Set interleaving policy for system init. Interleaving is only
3112          * enabled across suitably sized nodes (default is >= 16MB), or
3113          * fall back to the largest node if they're all smaller.
3114          */
3115         nodes_clear(interleave_nodes);
3116         for_each_node_state(nid, N_MEMORY) {
3117                 unsigned long total_pages = node_present_pages(nid);
3118
3119                 /* Preserve the largest node */
3120                 if (largest < total_pages) {
3121                         largest = total_pages;
3122                         prefer = nid;
3123                 }
3124
3125                 /* Interleave this node? */
3126                 if ((total_pages << PAGE_SHIFT) >= (16 << 20))
3127                         node_set(nid, interleave_nodes);
3128         }
3129
3130         /* All too small, use the largest */
3131         if (unlikely(nodes_empty(interleave_nodes)))
3132                 node_set(prefer, interleave_nodes);
3133
3134         if (do_set_mempolicy(MPOL_INTERLEAVE, 0, &interleave_nodes))
3135                 pr_err("%s: interleaving failed\n", __func__);
3136
3137         check_numabalancing_enable();
3138 }
3139
3140 /* Reset policy of current process to default */
3141 void numa_default_policy(void)
3142 {
3143         do_set_mempolicy(MPOL_DEFAULT, 0, NULL);
3144 }
3145
3146 /*
3147  * Parse and format mempolicy from/to strings
3148  */
3149 static const char * const policy_modes[] =
3150 {
3151         [MPOL_DEFAULT]    = "default",
3152         [MPOL_PREFERRED]  = "prefer",
3153         [MPOL_BIND]       = "bind",
3154         [MPOL_INTERLEAVE] = "interleave",
3155         [MPOL_WEIGHTED_INTERLEAVE] = "weighted interleave",
3156         [MPOL_LOCAL]      = "local",
3157         [MPOL_PREFERRED_MANY]  = "prefer (many)",
3158 };
3159
3160 #ifdef CONFIG_TMPFS
3161 /**
3162  * mpol_parse_str - parse string to mempolicy, for tmpfs mpol mount option.
3163  * @str:  string containing mempolicy to parse
3164  * @mpol:  pointer to struct mempolicy pointer, returned on success.
3165  *
3166  * Format of input:
3167  *      <mode>[=<flags>][:<nodelist>]
3168  *
3169  * Return: %0 on success, else %1
3170  */
3171 int mpol_parse_str(char *str, struct mempolicy **mpol)
3172 {
3173         struct mempolicy *new = NULL;
3174         unsigned short mode_flags;
3175         nodemask_t nodes;
3176         char *nodelist = strchr(str, ':');
3177         char *flags = strchr(str, '=');
3178         int err = 1, mode;
3179
3180         if (flags)
3181                 *flags++ = '\0';        /* terminate mode string */
3182
3183         if (nodelist) {
3184                 /* NUL-terminate mode or flags string */
3185                 *nodelist++ = '\0';
3186                 if (nodelist_parse(nodelist, nodes))
3187                         goto out;
3188                 if (!nodes_subset(nodes, node_states[N_MEMORY]))
3189                         goto out;
3190         } else
3191                 nodes_clear(nodes);
3192
3193         mode = match_string(policy_modes, MPOL_MAX, str);
3194         if (mode < 0)
3195                 goto out;
3196
3197         switch (mode) {
3198         case MPOL_PREFERRED:
3199                 /*
3200                  * Insist on a nodelist of one node only, although later
3201                  * we use first_node(nodes) to grab a single node, so here
3202                  * nodelist (or nodes) cannot be empty.
3203                  */
3204                 if (nodelist) {
3205                         char *rest = nodelist;
3206                         while (isdigit(*rest))
3207                                 rest++;
3208                         if (*rest)
3209                                 goto out;
3210                         if (nodes_empty(nodes))
3211                                 goto out;
3212                 }
3213                 break;
3214         case MPOL_INTERLEAVE:
3215         case MPOL_WEIGHTED_INTERLEAVE:
3216                 /*
3217                  * Default to online nodes with memory if no nodelist
3218                  */
3219                 if (!nodelist)
3220                         nodes = node_states[N_MEMORY];
3221                 break;
3222         case MPOL_LOCAL:
3223                 /*
3224                  * Don't allow a nodelist;  mpol_new() checks flags
3225                  */
3226                 if (nodelist)
3227                         goto out;
3228                 break;
3229         case MPOL_DEFAULT:
3230                 /*
3231                  * Insist on a empty nodelist
3232                  */
3233                 if (!nodelist)
3234                         err = 0;
3235                 goto out;
3236         case MPOL_PREFERRED_MANY:
3237         case MPOL_BIND:
3238                 /*
3239                  * Insist on a nodelist
3240                  */
3241                 if (!nodelist)
3242                         goto out;
3243         }
3244
3245         mode_flags = 0;
3246         if (flags) {
3247                 /*
3248                  * Currently, we only support two mutually exclusive
3249                  * mode flags.
3250                  */
3251                 if (!strcmp(flags, "static"))
3252                         mode_flags |= MPOL_F_STATIC_NODES;
3253                 else if (!strcmp(flags, "relative"))
3254                         mode_flags |= MPOL_F_RELATIVE_NODES;
3255                 else
3256                         goto out;
3257         }
3258
3259         new = mpol_new(mode, mode_flags, &nodes);
3260         if (IS_ERR(new))
3261                 goto out;
3262
3263         /*
3264          * Save nodes for mpol_to_str() to show the tmpfs mount options
3265          * for /proc/mounts, /proc/pid/mounts and /proc/pid/mountinfo.
3266          */
3267         if (mode != MPOL_PREFERRED) {
3268                 new->nodes = nodes;
3269         } else if (nodelist) {
3270                 nodes_clear(new->nodes);
3271                 node_set(first_node(nodes), new->nodes);
3272         } else {
3273                 new->mode = MPOL_LOCAL;
3274         }
3275
3276         /*
3277          * Save nodes for contextualization: this will be used to "clone"
3278          * the mempolicy in a specific context [cpuset] at a later time.
3279          */
3280         new->w.user_nodemask = nodes;
3281
3282         err = 0;
3283
3284 out:
3285         /* Restore string for error message */
3286         if (nodelist)
3287                 *--nodelist = ':';
3288         if (flags)
3289                 *--flags = '=';
3290         if (!err)
3291                 *mpol = new;
3292         return err;
3293 }
3294 #endif /* CONFIG_TMPFS */
3295
3296 /**
3297  * mpol_to_str - format a mempolicy structure for printing
3298  * @buffer:  to contain formatted mempolicy string
3299  * @maxlen:  length of @buffer
3300  * @pol:  pointer to mempolicy to be formatted
3301  *
3302  * Convert @pol into a string.  If @buffer is too short, truncate the string.
3303  * Recommend a @maxlen of at least 51 for the longest mode, "weighted
3304  * interleave", plus the longest flag flags, "relative|balancing", and to
3305  * display at least a few node ids.
3306  */
3307 void mpol_to_str(char *buffer, int maxlen, struct mempolicy *pol)
3308 {
3309         char *p = buffer;
3310         nodemask_t nodes = NODE_MASK_NONE;
3311         unsigned short mode = MPOL_DEFAULT;
3312         unsigned short flags = 0;
3313
3314         if (pol &&
3315             pol != &default_policy &&
3316             !(pol >= &preferred_node_policy[0] &&
3317               pol <= &preferred_node_policy[ARRAY_SIZE(preferred_node_policy) - 1])) {
3318                 mode = pol->mode;
3319                 flags = pol->flags;
3320         }
3321
3322         switch (mode) {
3323         case MPOL_DEFAULT:
3324         case MPOL_LOCAL:
3325                 break;
3326         case MPOL_PREFERRED:
3327         case MPOL_PREFERRED_MANY:
3328         case MPOL_BIND:
3329         case MPOL_INTERLEAVE:
3330         case MPOL_WEIGHTED_INTERLEAVE:
3331                 nodes = pol->nodes;
3332                 break;
3333         default:
3334                 WARN_ON_ONCE(1);
3335                 snprintf(p, maxlen, "unknown");
3336                 return;
3337         }
3338
3339         p += snprintf(p, maxlen, "%s", policy_modes[mode]);
3340
3341         if (flags & MPOL_MODE_FLAGS) {
3342                 p += snprintf(p, buffer + maxlen - p, "=");
3343
3344                 /*
3345                  * Static and relative are mutually exclusive.
3346                  */
3347                 if (flags & MPOL_F_STATIC_NODES)
3348                         p += snprintf(p, buffer + maxlen - p, "static");
3349                 else if (flags & MPOL_F_RELATIVE_NODES)
3350                         p += snprintf(p, buffer + maxlen - p, "relative");
3351
3352                 if (flags & MPOL_F_NUMA_BALANCING) {
3353                         if (!is_power_of_2(flags & MPOL_MODE_FLAGS))
3354                                 p += snprintf(p, buffer + maxlen - p, "|");
3355                         p += snprintf(p, buffer + maxlen - p, "balancing");
3356                 }
3357         }
3358
3359         if (!nodes_empty(nodes))
3360                 p += scnprintf(p, buffer + maxlen - p, ":%*pbl",
3361                                nodemask_pr_args(&nodes));
3362 }
3363
3364 #ifdef CONFIG_SYSFS
3365 struct iw_node_attr {
3366         struct kobj_attribute kobj_attr;
3367         int nid;
3368 };
3369
3370 static ssize_t node_show(struct kobject *kobj, struct kobj_attribute *attr,
3371                          char *buf)
3372 {
3373         struct iw_node_attr *node_attr;
3374         u8 weight;
3375
3376         node_attr = container_of(attr, struct iw_node_attr, kobj_attr);
3377         weight = get_il_weight(node_attr->nid);
3378         return sysfs_emit(buf, "%d\n", weight);
3379 }
3380
3381 static ssize_t node_store(struct kobject *kobj, struct kobj_attribute *attr,
3382                           const char *buf, size_t count)
3383 {
3384         struct iw_node_attr *node_attr;
3385         u8 *new;
3386         u8 *old;
3387         u8 weight = 0;
3388
3389         node_attr = container_of(attr, struct iw_node_attr, kobj_attr);
3390         if (count == 0 || sysfs_streq(buf, ""))
3391                 weight = 0;
3392         else if (kstrtou8(buf, 0, &weight))
3393                 return -EINVAL;
3394
3395         new = kzalloc(nr_node_ids, GFP_KERNEL);
3396         if (!new)
3397                 return -ENOMEM;
3398
3399         mutex_lock(&iw_table_lock);
3400         old = rcu_dereference_protected(iw_table,
3401                                         lockdep_is_held(&iw_table_lock));
3402         if (old)
3403                 memcpy(new, old, nr_node_ids);
3404         new[node_attr->nid] = weight;
3405         rcu_assign_pointer(iw_table, new);
3406         mutex_unlock(&iw_table_lock);
3407         synchronize_rcu();
3408         kfree(old);
3409         return count;
3410 }
3411
3412 static struct iw_node_attr **node_attrs;
3413
3414 static void sysfs_wi_node_release(struct iw_node_attr *node_attr,
3415                                   struct kobject *parent)
3416 {
3417         if (!node_attr)
3418                 return;
3419         sysfs_remove_file(parent, &node_attr->kobj_attr.attr);
3420         kfree(node_attr->kobj_attr.attr.name);
3421         kfree(node_attr);
3422 }
3423
3424 static void sysfs_wi_release(struct kobject *wi_kobj)
3425 {
3426         int i;
3427
3428         for (i = 0; i < nr_node_ids; i++)
3429                 sysfs_wi_node_release(node_attrs[i], wi_kobj);
3430         kobject_put(wi_kobj);
3431 }
3432
3433 static const struct kobj_type wi_ktype = {
3434         .sysfs_ops = &kobj_sysfs_ops,
3435         .release = sysfs_wi_release,
3436 };
3437
3438 static int add_weight_node(int nid, struct kobject *wi_kobj)
3439 {
3440         struct iw_node_attr *node_attr;
3441         char *name;
3442
3443         node_attr = kzalloc(sizeof(*node_attr), GFP_KERNEL);
3444         if (!node_attr)
3445                 return -ENOMEM;
3446
3447         name = kasprintf(GFP_KERNEL, "node%d", nid);
3448         if (!name) {
3449                 kfree(node_attr);
3450                 return -ENOMEM;
3451         }
3452
3453         sysfs_attr_init(&node_attr->kobj_attr.attr);
3454         node_attr->kobj_attr.attr.name = name;
3455         node_attr->kobj_attr.attr.mode = 0644;
3456         node_attr->kobj_attr.show = node_show;
3457         node_attr->kobj_attr.store = node_store;
3458         node_attr->nid = nid;
3459
3460         if (sysfs_create_file(wi_kobj, &node_attr->kobj_attr.attr)) {
3461                 kfree(node_attr->kobj_attr.attr.name);
3462                 kfree(node_attr);
3463                 pr_err("failed to add attribute to weighted_interleave\n");
3464                 return -ENOMEM;
3465         }
3466
3467         node_attrs[nid] = node_attr;
3468         return 0;
3469 }
3470
3471 static int add_weighted_interleave_group(struct kobject *root_kobj)
3472 {
3473         struct kobject *wi_kobj;
3474         int nid, err;
3475
3476         wi_kobj = kzalloc(sizeof(struct kobject), GFP_KERNEL);
3477         if (!wi_kobj)
3478                 return -ENOMEM;
3479
3480         err = kobject_init_and_add(wi_kobj, &wi_ktype, root_kobj,
3481                                    "weighted_interleave");
3482         if (err) {
3483                 kfree(wi_kobj);
3484                 return err;
3485         }
3486
3487         for_each_node_state(nid, N_POSSIBLE) {
3488                 err = add_weight_node(nid, wi_kobj);
3489                 if (err) {
3490                         pr_err("failed to add sysfs [node%d]\n", nid);
3491                         break;
3492                 }
3493         }
3494         if (err)
3495                 kobject_put(wi_kobj);
3496         return 0;
3497 }
3498
3499 static void mempolicy_kobj_release(struct kobject *kobj)
3500 {
3501         u8 *old;
3502
3503         mutex_lock(&iw_table_lock);
3504         old = rcu_dereference_protected(iw_table,
3505                                         lockdep_is_held(&iw_table_lock));
3506         rcu_assign_pointer(iw_table, NULL);
3507         mutex_unlock(&iw_table_lock);
3508         synchronize_rcu();
3509         kfree(old);
3510         kfree(node_attrs);
3511         kfree(kobj);
3512 }
3513
3514 static const struct kobj_type mempolicy_ktype = {
3515         .release = mempolicy_kobj_release
3516 };
3517
3518 static int __init mempolicy_sysfs_init(void)
3519 {
3520         int err;
3521         static struct kobject *mempolicy_kobj;
3522
3523         mempolicy_kobj = kzalloc(sizeof(*mempolicy_kobj), GFP_KERNEL);
3524         if (!mempolicy_kobj) {
3525                 err = -ENOMEM;
3526                 goto err_out;
3527         }
3528
3529         node_attrs = kcalloc(nr_node_ids, sizeof(struct iw_node_attr *),
3530                              GFP_KERNEL);
3531         if (!node_attrs) {
3532                 err = -ENOMEM;
3533                 goto mempol_out;
3534         }
3535
3536         err = kobject_init_and_add(mempolicy_kobj, &mempolicy_ktype, mm_kobj,
3537                                    "mempolicy");
3538         if (err)
3539                 goto node_out;
3540
3541         err = add_weighted_interleave_group(mempolicy_kobj);
3542         if (err) {
3543                 pr_err("mempolicy sysfs structure failed to initialize\n");
3544                 kobject_put(mempolicy_kobj);
3545                 return err;
3546         }
3547
3548         return err;
3549 node_out:
3550         kfree(node_attrs);
3551 mempol_out:
3552         kfree(mempolicy_kobj);
3553 err_out:
3554         pr_err("failed to add mempolicy kobject to the system\n");
3555         return err;
3556 }
3557
3558 late_initcall(mempolicy_sysfs_init);
3559 #endif /* CONFIG_SYSFS */