mm/gup: change GUP fast to use flags rather than a write 'bool'
[linux-2.6-block.git] / arch / sparc / mm / gup.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Lockless get_user_pages_fast for sparc, cribbed from powerpc
4  *
5  * Copyright (C) 2008 Nick Piggin
6  * Copyright (C) 2008 Novell Inc.
7  */
8
9 #include <linux/sched.h>
10 #include <linux/mm.h>
11 #include <linux/vmstat.h>
12 #include <linux/pagemap.h>
13 #include <linux/rwsem.h>
14 #include <asm/pgtable.h>
15 #include <asm/adi.h>
16
17 /*
18  * The performance critical leaf functions are made noinline otherwise gcc
19  * inlines everything into a single function which results in too much
20  * register pressure.
21  */
22 static noinline int gup_pte_range(pmd_t pmd, unsigned long addr,
23                 unsigned long end, int write, struct page **pages, int *nr)
24 {
25         unsigned long mask, result;
26         pte_t *ptep;
27
28         if (tlb_type == hypervisor) {
29                 result = _PAGE_PRESENT_4V|_PAGE_P_4V;
30                 if (write)
31                         result |= _PAGE_WRITE_4V;
32         } else {
33                 result = _PAGE_PRESENT_4U|_PAGE_P_4U;
34                 if (write)
35                         result |= _PAGE_WRITE_4U;
36         }
37         mask = result | _PAGE_SPECIAL;
38
39         ptep = pte_offset_kernel(&pmd, addr);
40         do {
41                 struct page *page, *head;
42                 pte_t pte = *ptep;
43
44                 if ((pte_val(pte) & mask) != result)
45                         return 0;
46                 VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
47
48                 /* The hugepage case is simplified on sparc64 because
49                  * we encode the sub-page pfn offsets into the
50                  * hugepage PTEs.  We could optimize this in the future
51                  * use page_cache_add_speculative() for the hugepage case.
52                  */
53                 page = pte_page(pte);
54                 head = compound_head(page);
55                 if (!page_cache_get_speculative(head))
56                         return 0;
57                 if (unlikely(pte_val(pte) != pte_val(*ptep))) {
58                         put_page(head);
59                         return 0;
60                 }
61
62                 pages[*nr] = page;
63                 (*nr)++;
64         } while (ptep++, addr += PAGE_SIZE, addr != end);
65
66         return 1;
67 }
68
69 static int gup_huge_pmd(pmd_t *pmdp, pmd_t pmd, unsigned long addr,
70                         unsigned long end, int write, struct page **pages,
71                         int *nr)
72 {
73         struct page *head, *page;
74         int refs;
75
76         if (!(pmd_val(pmd) & _PAGE_VALID))
77                 return 0;
78
79         if (write && !pmd_write(pmd))
80                 return 0;
81
82         refs = 0;
83         page = pmd_page(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
84         head = compound_head(page);
85         do {
86                 VM_BUG_ON(compound_head(page) != head);
87                 pages[*nr] = page;
88                 (*nr)++;
89                 page++;
90                 refs++;
91         } while (addr += PAGE_SIZE, addr != end);
92
93         if (!page_cache_add_speculative(head, refs)) {
94                 *nr -= refs;
95                 return 0;
96         }
97
98         if (unlikely(pmd_val(pmd) != pmd_val(*pmdp))) {
99                 *nr -= refs;
100                 while (refs--)
101                         put_page(head);
102                 return 0;
103         }
104
105         return 1;
106 }
107
108 static int gup_huge_pud(pud_t *pudp, pud_t pud, unsigned long addr,
109                         unsigned long end, int write, struct page **pages,
110                         int *nr)
111 {
112         struct page *head, *page;
113         int refs;
114
115         if (!(pud_val(pud) & _PAGE_VALID))
116                 return 0;
117
118         if (write && !pud_write(pud))
119                 return 0;
120
121         refs = 0;
122         page = pud_page(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
123         head = compound_head(page);
124         do {
125                 VM_BUG_ON(compound_head(page) != head);
126                 pages[*nr] = page;
127                 (*nr)++;
128                 page++;
129                 refs++;
130         } while (addr += PAGE_SIZE, addr != end);
131
132         if (!page_cache_add_speculative(head, refs)) {
133                 *nr -= refs;
134                 return 0;
135         }
136
137         if (unlikely(pud_val(pud) != pud_val(*pudp))) {
138                 *nr -= refs;
139                 while (refs--)
140                         put_page(head);
141                 return 0;
142         }
143
144         return 1;
145 }
146
147 static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
148                 int write, struct page **pages, int *nr)
149 {
150         unsigned long next;
151         pmd_t *pmdp;
152
153         pmdp = pmd_offset(&pud, addr);
154         do {
155                 pmd_t pmd = *pmdp;
156
157                 next = pmd_addr_end(addr, end);
158                 if (pmd_none(pmd))
159                         return 0;
160                 if (unlikely(pmd_large(pmd))) {
161                         if (!gup_huge_pmd(pmdp, pmd, addr, next,
162                                           write, pages, nr))
163                                 return 0;
164                 } else if (!gup_pte_range(pmd, addr, next, write,
165                                           pages, nr))
166                         return 0;
167         } while (pmdp++, addr = next, addr != end);
168
169         return 1;
170 }
171
172 static int gup_pud_range(pgd_t pgd, unsigned long addr, unsigned long end,
173                 int write, struct page **pages, int *nr)
174 {
175         unsigned long next;
176         pud_t *pudp;
177
178         pudp = pud_offset(&pgd, addr);
179         do {
180                 pud_t pud = *pudp;
181
182                 next = pud_addr_end(addr, end);
183                 if (pud_none(pud))
184                         return 0;
185                 if (unlikely(pud_large(pud))) {
186                         if (!gup_huge_pud(pudp, pud, addr, next,
187                                           write, pages, nr))
188                                 return 0;
189                 } else if (!gup_pmd_range(pud, addr, next, write, pages, nr))
190                         return 0;
191         } while (pudp++, addr = next, addr != end);
192
193         return 1;
194 }
195
196 /*
197  * Note a difference with get_user_pages_fast: this always returns the
198  * number of pages pinned, 0 if no pages were pinned.
199  */
200 int __get_user_pages_fast(unsigned long start, int nr_pages, int write,
201                           struct page **pages)
202 {
203         struct mm_struct *mm = current->mm;
204         unsigned long addr, len, end;
205         unsigned long next, flags;
206         pgd_t *pgdp;
207         int nr = 0;
208
209 #ifdef CONFIG_SPARC64
210         if (adi_capable()) {
211                 long addr = start;
212
213                 /* If userspace has passed a versioned address, kernel
214                  * will not find it in the VMAs since it does not store
215                  * the version tags in the list of VMAs. Storing version
216                  * tags in list of VMAs is impractical since they can be
217                  * changed any time from userspace without dropping into
218                  * kernel. Any address search in VMAs will be done with
219                  * non-versioned addresses. Ensure the ADI version bits
220                  * are dropped here by sign extending the last bit before
221                  * ADI bits. IOMMU does not implement version tags.
222                  */
223                 addr = (addr << (long)adi_nbits()) >> (long)adi_nbits();
224                 start = addr;
225         }
226 #endif
227         start &= PAGE_MASK;
228         addr = start;
229         len = (unsigned long) nr_pages << PAGE_SHIFT;
230         end = start + len;
231
232         local_irq_save(flags);
233         pgdp = pgd_offset(mm, addr);
234         do {
235                 pgd_t pgd = *pgdp;
236
237                 next = pgd_addr_end(addr, end);
238                 if (pgd_none(pgd))
239                         break;
240                 if (!gup_pud_range(pgd, addr, next, write, pages, &nr))
241                         break;
242         } while (pgdp++, addr = next, addr != end);
243         local_irq_restore(flags);
244
245         return nr;
246 }
247
248 int get_user_pages_fast(unsigned long start, int nr_pages,
249                         unsigned int gup_flags, struct page **pages)
250 {
251         struct mm_struct *mm = current->mm;
252         unsigned long addr, len, end;
253         unsigned long next;
254         pgd_t *pgdp;
255         int nr = 0;
256
257 #ifdef CONFIG_SPARC64
258         if (adi_capable()) {
259                 long addr = start;
260
261                 /* If userspace has passed a versioned address, kernel
262                  * will not find it in the VMAs since it does not store
263                  * the version tags in the list of VMAs. Storing version
264                  * tags in list of VMAs is impractical since they can be
265                  * changed any time from userspace without dropping into
266                  * kernel. Any address search in VMAs will be done with
267                  * non-versioned addresses. Ensure the ADI version bits
268                  * are dropped here by sign extending the last bit before
269                  * ADI bits. IOMMU does not implements version tags,
270                  */
271                 addr = (addr << (long)adi_nbits()) >> (long)adi_nbits();
272                 start = addr;
273         }
274 #endif
275         start &= PAGE_MASK;
276         addr = start;
277         len = (unsigned long) nr_pages << PAGE_SHIFT;
278         end = start + len;
279
280         /*
281          * XXX: batch / limit 'nr', to avoid large irq off latency
282          * needs some instrumenting to determine the common sizes used by
283          * important workloads (eg. DB2), and whether limiting the batch size
284          * will decrease performance.
285          *
286          * It seems like we're in the clear for the moment. Direct-IO is
287          * the main guy that batches up lots of get_user_pages, and even
288          * they are limited to 64-at-a-time which is not so many.
289          */
290         /*
291          * This doesn't prevent pagetable teardown, but does prevent
292          * the pagetables from being freed on sparc.
293          *
294          * So long as we atomically load page table pointers versus teardown,
295          * we can follow the address down to the the page and take a ref on it.
296          */
297         local_irq_disable();
298
299         pgdp = pgd_offset(mm, addr);
300         do {
301                 pgd_t pgd = *pgdp;
302
303                 next = pgd_addr_end(addr, end);
304                 if (pgd_none(pgd))
305                         goto slow;
306                 if (!gup_pud_range(pgd, addr, next, gup_flags & FOLL_WRITE,
307                                    pages, &nr))
308                         goto slow;
309         } while (pgdp++, addr = next, addr != end);
310
311         local_irq_enable();
312
313         VM_BUG_ON(nr != (end - start) >> PAGE_SHIFT);
314         return nr;
315
316         {
317                 int ret;
318
319 slow:
320                 local_irq_enable();
321
322                 /* Try to get the remaining pages with get_user_pages */
323                 start += nr << PAGE_SHIFT;
324                 pages += nr;
325
326                 ret = get_user_pages_unlocked(start,
327                         (end - start) >> PAGE_SHIFT, pages,
328                         gup_flags);
329
330                 /* Have to be a bit careful with return values */
331                 if (nr > 0) {
332                         if (ret < 0)
333                                 ret = nr;
334                         else
335                                 ret += nr;
336                 }
337
338                 return ret;
339         }
340 }