Merge tag 'spi-fix-v6.9-rc6' of git://git.kernel.org/pub/scm/linux/kernel/git/broonie/spi
[linux-block.git] / tools / testing / selftests / mm / protection_keys.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Tests Memory Protection Keys (see Documentation/core-api/protection-keys.rst)
4  *
5  * There are examples in here of:
6  *  * how to set protection keys on memory
7  *  * how to set/clear bits in pkey registers (the rights register)
8  *  * how to handle SEGV_PKUERR signals and extract pkey-relevant
9  *    information from the siginfo
10  *
11  * Things to add:
12  *      make sure KSM and KSM COW breaking works
13  *      prefault pages in at malloc, or not
14  *      protect MPX bounds tables with protection keys?
15  *      make sure VMA splitting/merging is working correctly
16  *      OOMs can destroy mm->mmap (see exit_mmap()), so make sure it is immune to pkeys
17  *      look for pkey "leaks" where it is still set on a VMA but "freed" back to the kernel
18  *      do a plain mprotect() to a mprotect_pkey() area and make sure the pkey sticks
19  *
20  * Compile like this:
21  *      gcc -mxsave      -o protection_keys    -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
22  *      gcc -mxsave -m32 -o protection_keys_32 -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
23  */
24 #define _GNU_SOURCE
25 #define __SANE_USERSPACE_TYPES__
26 #include <errno.h>
27 #include <linux/elf.h>
28 #include <linux/futex.h>
29 #include <time.h>
30 #include <sys/time.h>
31 #include <sys/syscall.h>
32 #include <string.h>
33 #include <stdio.h>
34 #include <stdint.h>
35 #include <stdbool.h>
36 #include <signal.h>
37 #include <assert.h>
38 #include <stdlib.h>
39 #include <ucontext.h>
40 #include <sys/mman.h>
41 #include <sys/types.h>
42 #include <sys/wait.h>
43 #include <sys/stat.h>
44 #include <fcntl.h>
45 #include <unistd.h>
46 #include <sys/ptrace.h>
47 #include <setjmp.h>
48
49 #include "pkey-helpers.h"
50
51 int iteration_nr = 1;
52 int test_nr;
53
54 u64 shadow_pkey_reg;
55 int dprint_in_signal;
56 char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
57
58 void cat_into_file(char *str, char *file)
59 {
60         int fd = open(file, O_RDWR);
61         int ret;
62
63         dprintf2("%s(): writing '%s' to '%s'\n", __func__, str, file);
64         /*
65          * these need to be raw because they are called under
66          * pkey_assert()
67          */
68         if (fd < 0) {
69                 fprintf(stderr, "error opening '%s'\n", str);
70                 perror("error: ");
71                 exit(__LINE__);
72         }
73
74         ret = write(fd, str, strlen(str));
75         if (ret != strlen(str)) {
76                 perror("write to file failed");
77                 fprintf(stderr, "filename: '%s' str: '%s'\n", file, str);
78                 exit(__LINE__);
79         }
80         close(fd);
81 }
82
83 #if CONTROL_TRACING > 0
84 static int warned_tracing;
85 int tracing_root_ok(void)
86 {
87         if (geteuid() != 0) {
88                 if (!warned_tracing)
89                         fprintf(stderr, "WARNING: not run as root, "
90                                         "can not do tracing control\n");
91                 warned_tracing = 1;
92                 return 0;
93         }
94         return 1;
95 }
96 #endif
97
98 void tracing_on(void)
99 {
100 #if CONTROL_TRACING > 0
101 #define TRACEDIR "/sys/kernel/tracing"
102         char pidstr[32];
103
104         if (!tracing_root_ok())
105                 return;
106
107         sprintf(pidstr, "%d", getpid());
108         cat_into_file("0", TRACEDIR "/tracing_on");
109         cat_into_file("\n", TRACEDIR "/trace");
110         if (1) {
111                 cat_into_file("function_graph", TRACEDIR "/current_tracer");
112                 cat_into_file("1", TRACEDIR "/options/funcgraph-proc");
113         } else {
114                 cat_into_file("nop", TRACEDIR "/current_tracer");
115         }
116         cat_into_file(pidstr, TRACEDIR "/set_ftrace_pid");
117         cat_into_file("1", TRACEDIR "/tracing_on");
118         dprintf1("enabled tracing\n");
119 #endif
120 }
121
122 void tracing_off(void)
123 {
124 #if CONTROL_TRACING > 0
125         if (!tracing_root_ok())
126                 return;
127         cat_into_file("0", "/sys/kernel/tracing/tracing_on");
128 #endif
129 }
130
131 void abort_hooks(void)
132 {
133         fprintf(stderr, "running %s()...\n", __func__);
134         tracing_off();
135 #ifdef SLEEP_ON_ABORT
136         sleep(SLEEP_ON_ABORT);
137 #endif
138 }
139
140 /*
141  * This attempts to have roughly a page of instructions followed by a few
142  * instructions that do a write, and another page of instructions.  That
143  * way, we are pretty sure that the write is in the second page of
144  * instructions and has at least a page of padding behind it.
145  *
146  * *That* lets us be sure to madvise() away the write instruction, which
147  * will then fault, which makes sure that the fault code handles
148  * execute-only memory properly.
149  */
150 #ifdef __powerpc64__
151 /* This way, both 4K and 64K alignment are maintained */
152 __attribute__((__aligned__(65536)))
153 #else
154 __attribute__((__aligned__(PAGE_SIZE)))
155 #endif
156 void lots_o_noops_around_write(int *write_to_me)
157 {
158         dprintf3("running %s()\n", __func__);
159         __page_o_noops();
160         /* Assume this happens in the second page of instructions: */
161         *write_to_me = __LINE__;
162         /* pad out by another page: */
163         __page_o_noops();
164         dprintf3("%s() done\n", __func__);
165 }
166
167 void dump_mem(void *dumpme, int len_bytes)
168 {
169         char *c = (void *)dumpme;
170         int i;
171
172         for (i = 0; i < len_bytes; i += sizeof(u64)) {
173                 u64 *ptr = (u64 *)(c + i);
174                 dprintf1("dump[%03d][@%p]: %016llx\n", i, ptr, *ptr);
175         }
176 }
177
178 static u32 hw_pkey_get(int pkey, unsigned long flags)
179 {
180         u64 pkey_reg = __read_pkey_reg();
181
182         dprintf1("%s(pkey=%d, flags=%lx) = %x / %d\n",
183                         __func__, pkey, flags, 0, 0);
184         dprintf2("%s() raw pkey_reg: %016llx\n", __func__, pkey_reg);
185
186         return (u32) get_pkey_bits(pkey_reg, pkey);
187 }
188
189 static int hw_pkey_set(int pkey, unsigned long rights, unsigned long flags)
190 {
191         u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
192         u64 old_pkey_reg = __read_pkey_reg();
193         u64 new_pkey_reg;
194
195         /* make sure that 'rights' only contains the bits we expect: */
196         assert(!(rights & ~mask));
197
198         /* modify bits accordingly in old pkey_reg and assign it */
199         new_pkey_reg = set_pkey_bits(old_pkey_reg, pkey, rights);
200
201         __write_pkey_reg(new_pkey_reg);
202
203         dprintf3("%s(pkey=%d, rights=%lx, flags=%lx) = %x"
204                 " pkey_reg now: %016llx old_pkey_reg: %016llx\n",
205                 __func__, pkey, rights, flags, 0, __read_pkey_reg(),
206                 old_pkey_reg);
207         return 0;
208 }
209
210 void pkey_disable_set(int pkey, int flags)
211 {
212         unsigned long syscall_flags = 0;
213         int ret;
214         int pkey_rights;
215         u64 orig_pkey_reg = read_pkey_reg();
216
217         dprintf1("START->%s(%d, 0x%x)\n", __func__,
218                 pkey, flags);
219         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
220
221         pkey_rights = hw_pkey_get(pkey, syscall_flags);
222
223         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
224                         pkey, pkey, pkey_rights);
225
226         pkey_assert(pkey_rights >= 0);
227
228         pkey_rights |= flags;
229
230         ret = hw_pkey_set(pkey, pkey_rights, syscall_flags);
231         assert(!ret);
232         /* pkey_reg and flags have the same format */
233         shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
234         dprintf1("%s(%d) shadow: 0x%016llx\n",
235                 __func__, pkey, shadow_pkey_reg);
236
237         pkey_assert(ret >= 0);
238
239         pkey_rights = hw_pkey_get(pkey, syscall_flags);
240         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
241                         pkey, pkey, pkey_rights);
242
243         dprintf1("%s(%d) pkey_reg: 0x%016llx\n",
244                 __func__, pkey, read_pkey_reg());
245         if (flags)
246                 pkey_assert(read_pkey_reg() >= orig_pkey_reg);
247         dprintf1("END<---%s(%d, 0x%x)\n", __func__,
248                 pkey, flags);
249 }
250
251 void pkey_disable_clear(int pkey, int flags)
252 {
253         unsigned long syscall_flags = 0;
254         int ret;
255         int pkey_rights = hw_pkey_get(pkey, syscall_flags);
256         u64 orig_pkey_reg = read_pkey_reg();
257
258         pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
259
260         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
261                         pkey, pkey, pkey_rights);
262         pkey_assert(pkey_rights >= 0);
263
264         pkey_rights &= ~flags;
265
266         ret = hw_pkey_set(pkey, pkey_rights, 0);
267         shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, pkey, pkey_rights);
268         pkey_assert(ret >= 0);
269
270         pkey_rights = hw_pkey_get(pkey, syscall_flags);
271         dprintf1("%s(%d) hw_pkey_get(%d): %x\n", __func__,
272                         pkey, pkey, pkey_rights);
273
274         dprintf1("%s(%d) pkey_reg: 0x%016llx\n", __func__,
275                         pkey, read_pkey_reg());
276         if (flags)
277                 assert(read_pkey_reg() <= orig_pkey_reg);
278 }
279
280 void pkey_write_allow(int pkey)
281 {
282         pkey_disable_clear(pkey, PKEY_DISABLE_WRITE);
283 }
284 void pkey_write_deny(int pkey)
285 {
286         pkey_disable_set(pkey, PKEY_DISABLE_WRITE);
287 }
288 void pkey_access_allow(int pkey)
289 {
290         pkey_disable_clear(pkey, PKEY_DISABLE_ACCESS);
291 }
292 void pkey_access_deny(int pkey)
293 {
294         pkey_disable_set(pkey, PKEY_DISABLE_ACCESS);
295 }
296
297 static char *si_code_str(int si_code)
298 {
299         if (si_code == SEGV_MAPERR)
300                 return "SEGV_MAPERR";
301         if (si_code == SEGV_ACCERR)
302                 return "SEGV_ACCERR";
303         if (si_code == SEGV_BNDERR)
304                 return "SEGV_BNDERR";
305         if (si_code == SEGV_PKUERR)
306                 return "SEGV_PKUERR";
307         return "UNKNOWN";
308 }
309
310 int pkey_faults;
311 int last_si_pkey = -1;
312 void signal_handler(int signum, siginfo_t *si, void *vucontext)
313 {
314         ucontext_t *uctxt = vucontext;
315         int trapno;
316         unsigned long ip;
317         char *fpregs;
318 #if defined(__i386__) || defined(__x86_64__) /* arch */
319         u32 *pkey_reg_ptr;
320         int pkey_reg_offset;
321 #endif /* arch */
322         u64 siginfo_pkey;
323         u32 *si_pkey_ptr;
324
325         dprint_in_signal = 1;
326         dprintf1(">>>>===============SIGSEGV============================\n");
327         dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
328                         __func__, __LINE__,
329                         __read_pkey_reg(), shadow_pkey_reg);
330
331         trapno = uctxt->uc_mcontext.gregs[REG_TRAPNO];
332         ip = uctxt->uc_mcontext.gregs[REG_IP_IDX];
333         fpregs = (char *) uctxt->uc_mcontext.fpregs;
334
335         dprintf2("%s() trapno: %d ip: 0x%016lx info->si_code: %s/%d\n",
336                         __func__, trapno, ip, si_code_str(si->si_code),
337                         si->si_code);
338
339 #if defined(__i386__) || defined(__x86_64__) /* arch */
340 #ifdef __i386__
341         /*
342          * 32-bit has some extra padding so that userspace can tell whether
343          * the XSTATE header is present in addition to the "legacy" FPU
344          * state.  We just assume that it is here.
345          */
346         fpregs += 0x70;
347 #endif /* i386 */
348         pkey_reg_offset = pkey_reg_xstate_offset();
349         pkey_reg_ptr = (void *)(&fpregs[pkey_reg_offset]);
350
351         /*
352          * If we got a PKEY fault, we *HAVE* to have at least one bit set in
353          * here.
354          */
355         dprintf1("pkey_reg_xstate_offset: %d\n", pkey_reg_xstate_offset());
356         if (DEBUG_LEVEL > 4)
357                 dump_mem(pkey_reg_ptr - 128, 256);
358         pkey_assert(*pkey_reg_ptr);
359 #endif /* arch */
360
361         dprintf1("siginfo: %p\n", si);
362         dprintf1(" fpregs: %p\n", fpregs);
363
364         if ((si->si_code == SEGV_MAPERR) ||
365             (si->si_code == SEGV_ACCERR) ||
366             (si->si_code == SEGV_BNDERR)) {
367                 printf("non-PK si_code, exiting...\n");
368                 exit(4);
369         }
370
371         si_pkey_ptr = siginfo_get_pkey_ptr(si);
372         dprintf1("si_pkey_ptr: %p\n", si_pkey_ptr);
373         dump_mem((u8 *)si_pkey_ptr - 8, 24);
374         siginfo_pkey = *si_pkey_ptr;
375         pkey_assert(siginfo_pkey < NR_PKEYS);
376         last_si_pkey = siginfo_pkey;
377
378         /*
379          * need __read_pkey_reg() version so we do not do shadow_pkey_reg
380          * checking
381          */
382         dprintf1("signal pkey_reg from  pkey_reg: %016llx\n",
383                         __read_pkey_reg());
384         dprintf1("pkey from siginfo: %016llx\n", siginfo_pkey);
385 #if defined(__i386__) || defined(__x86_64__) /* arch */
386         dprintf1("signal pkey_reg from xsave: %08x\n", *pkey_reg_ptr);
387         *(u64 *)pkey_reg_ptr = 0x00000000;
388         dprintf1("WARNING: set PKEY_REG=0 to allow faulting instruction to continue\n");
389 #elif defined(__powerpc64__) /* arch */
390         /* restore access and let the faulting instruction continue */
391         pkey_access_allow(siginfo_pkey);
392 #endif /* arch */
393         pkey_faults++;
394         dprintf1("<<<<==================================================\n");
395         dprint_in_signal = 0;
396 }
397
398 int wait_all_children(void)
399 {
400         int status;
401         return waitpid(-1, &status, 0);
402 }
403
404 void sig_chld(int x)
405 {
406         dprint_in_signal = 1;
407         dprintf2("[%d] SIGCHLD: %d\n", getpid(), x);
408         dprint_in_signal = 0;
409 }
410
411 void setup_sigsegv_handler(void)
412 {
413         int r, rs;
414         struct sigaction newact;
415         struct sigaction oldact;
416
417         /* #PF is mapped to sigsegv */
418         int signum  = SIGSEGV;
419
420         newact.sa_handler = 0;
421         newact.sa_sigaction = signal_handler;
422
423         /*sigset_t - signals to block while in the handler */
424         /* get the old signal mask. */
425         rs = sigprocmask(SIG_SETMASK, 0, &newact.sa_mask);
426         pkey_assert(rs == 0);
427
428         /* call sa_sigaction, not sa_handler*/
429         newact.sa_flags = SA_SIGINFO;
430
431         newact.sa_restorer = 0;  /* void(*)(), obsolete */
432         r = sigaction(signum, &newact, &oldact);
433         r = sigaction(SIGALRM, &newact, &oldact);
434         pkey_assert(r == 0);
435 }
436
437 void setup_handlers(void)
438 {
439         signal(SIGCHLD, &sig_chld);
440         setup_sigsegv_handler();
441 }
442
443 pid_t fork_lazy_child(void)
444 {
445         pid_t forkret;
446
447         forkret = fork();
448         pkey_assert(forkret >= 0);
449         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
450
451         if (!forkret) {
452                 /* in the child */
453                 while (1) {
454                         dprintf1("child sleeping...\n");
455                         sleep(30);
456                 }
457         }
458         return forkret;
459 }
460
461 int sys_mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
462                 unsigned long pkey)
463 {
464         int sret;
465
466         dprintf2("%s(0x%p, %zx, prot=%lx, pkey=%lx)\n", __func__,
467                         ptr, size, orig_prot, pkey);
468
469         errno = 0;
470         sret = syscall(__NR_pkey_mprotect, ptr, size, orig_prot, pkey);
471         if (errno) {
472                 dprintf2("SYS_mprotect_key sret: %d\n", sret);
473                 dprintf2("SYS_mprotect_key prot: 0x%lx\n", orig_prot);
474                 dprintf2("SYS_mprotect_key failed, errno: %d\n", errno);
475                 if (DEBUG_LEVEL >= 2)
476                         perror("SYS_mprotect_pkey");
477         }
478         return sret;
479 }
480
481 int sys_pkey_alloc(unsigned long flags, unsigned long init_val)
482 {
483         int ret = syscall(SYS_pkey_alloc, flags, init_val);
484         dprintf1("%s(flags=%lx, init_val=%lx) syscall ret: %d errno: %d\n",
485                         __func__, flags, init_val, ret, errno);
486         return ret;
487 }
488
489 int alloc_pkey(void)
490 {
491         int ret;
492         unsigned long init_val = 0x0;
493
494         dprintf1("%s()::%d, pkey_reg: 0x%016llx shadow: %016llx\n",
495                         __func__, __LINE__, __read_pkey_reg(), shadow_pkey_reg);
496         ret = sys_pkey_alloc(0, init_val);
497         /*
498          * pkey_alloc() sets PKEY register, so we need to reflect it in
499          * shadow_pkey_reg:
500          */
501         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
502                         " shadow: 0x%016llx\n",
503                         __func__, __LINE__, ret, __read_pkey_reg(),
504                         shadow_pkey_reg);
505         if (ret > 0) {
506                 /* clear both the bits: */
507                 shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
508                                                 ~PKEY_MASK);
509                 dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
510                                 " shadow: 0x%016llx\n",
511                                 __func__,
512                                 __LINE__, ret, __read_pkey_reg(),
513                                 shadow_pkey_reg);
514                 /*
515                  * move the new state in from init_val
516                  * (remember, we cheated and init_val == pkey_reg format)
517                  */
518                 shadow_pkey_reg = set_pkey_bits(shadow_pkey_reg, ret,
519                                                 init_val);
520         }
521         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
522                         " shadow: 0x%016llx\n",
523                         __func__, __LINE__, ret, __read_pkey_reg(),
524                         shadow_pkey_reg);
525         dprintf1("%s()::%d errno: %d\n", __func__, __LINE__, errno);
526         /* for shadow checking: */
527         read_pkey_reg();
528         dprintf4("%s()::%d, ret: %d pkey_reg: 0x%016llx"
529                  " shadow: 0x%016llx\n",
530                 __func__, __LINE__, ret, __read_pkey_reg(),
531                 shadow_pkey_reg);
532         return ret;
533 }
534
535 int sys_pkey_free(unsigned long pkey)
536 {
537         int ret = syscall(SYS_pkey_free, pkey);
538         dprintf1("%s(pkey=%ld) syscall ret: %d\n", __func__, pkey, ret);
539         return ret;
540 }
541
542 /*
543  * I had a bug where pkey bits could be set by mprotect() but
544  * not cleared.  This ensures we get lots of random bit sets
545  * and clears on the vma and pte pkey bits.
546  */
547 int alloc_random_pkey(void)
548 {
549         int max_nr_pkey_allocs;
550         int ret;
551         int i;
552         int alloced_pkeys[NR_PKEYS];
553         int nr_alloced = 0;
554         int random_index;
555         memset(alloced_pkeys, 0, sizeof(alloced_pkeys));
556
557         /* allocate every possible key and make a note of which ones we got */
558         max_nr_pkey_allocs = NR_PKEYS;
559         for (i = 0; i < max_nr_pkey_allocs; i++) {
560                 int new_pkey = alloc_pkey();
561                 if (new_pkey < 0)
562                         break;
563                 alloced_pkeys[nr_alloced++] = new_pkey;
564         }
565
566         pkey_assert(nr_alloced > 0);
567         /* select a random one out of the allocated ones */
568         random_index = rand() % nr_alloced;
569         ret = alloced_pkeys[random_index];
570         /* now zero it out so we don't free it next */
571         alloced_pkeys[random_index] = 0;
572
573         /* go through the allocated ones that we did not want and free them */
574         for (i = 0; i < nr_alloced; i++) {
575                 int free_ret;
576                 if (!alloced_pkeys[i])
577                         continue;
578                 free_ret = sys_pkey_free(alloced_pkeys[i]);
579                 pkey_assert(!free_ret);
580         }
581         dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
582                          " shadow: 0x%016llx\n", __func__,
583                         __LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
584         return ret;
585 }
586
587 int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
588                 unsigned long pkey)
589 {
590         int nr_iterations = random() % 100;
591         int ret;
592
593         while (0) {
594                 int rpkey = alloc_random_pkey();
595                 ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
596                 dprintf1("sys_mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
597                                 ptr, size, orig_prot, pkey, ret);
598                 if (nr_iterations-- < 0)
599                         break;
600
601                 dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
602                         " shadow: 0x%016llx\n",
603                         __func__, __LINE__, ret, __read_pkey_reg(),
604                         shadow_pkey_reg);
605                 sys_pkey_free(rpkey);
606                 dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
607                         " shadow: 0x%016llx\n",
608                         __func__, __LINE__, ret, __read_pkey_reg(),
609                         shadow_pkey_reg);
610         }
611         pkey_assert(pkey < NR_PKEYS);
612
613         ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
614         dprintf1("mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
615                         ptr, size, orig_prot, pkey, ret);
616         pkey_assert(!ret);
617         dprintf1("%s()::%d, ret: %d pkey_reg: 0x%016llx"
618                         " shadow: 0x%016llx\n", __func__,
619                         __LINE__, ret, __read_pkey_reg(), shadow_pkey_reg);
620         return ret;
621 }
622
623 struct pkey_malloc_record {
624         void *ptr;
625         long size;
626         int prot;
627 };
628 struct pkey_malloc_record *pkey_malloc_records;
629 struct pkey_malloc_record *pkey_last_malloc_record;
630 long nr_pkey_malloc_records;
631 void record_pkey_malloc(void *ptr, long size, int prot)
632 {
633         long i;
634         struct pkey_malloc_record *rec = NULL;
635
636         for (i = 0; i < nr_pkey_malloc_records; i++) {
637                 rec = &pkey_malloc_records[i];
638                 /* find a free record */
639                 if (rec)
640                         break;
641         }
642         if (!rec) {
643                 /* every record is full */
644                 size_t old_nr_records = nr_pkey_malloc_records;
645                 size_t new_nr_records = (nr_pkey_malloc_records * 2 + 1);
646                 size_t new_size = new_nr_records * sizeof(struct pkey_malloc_record);
647                 dprintf2("new_nr_records: %zd\n", new_nr_records);
648                 dprintf2("new_size: %zd\n", new_size);
649                 pkey_malloc_records = realloc(pkey_malloc_records, new_size);
650                 pkey_assert(pkey_malloc_records != NULL);
651                 rec = &pkey_malloc_records[nr_pkey_malloc_records];
652                 /*
653                  * realloc() does not initialize memory, so zero it from
654                  * the first new record all the way to the end.
655                  */
656                 for (i = 0; i < new_nr_records - old_nr_records; i++)
657                         memset(rec + i, 0, sizeof(*rec));
658         }
659         dprintf3("filling malloc record[%d/%p]: {%p, %ld}\n",
660                 (int)(rec - pkey_malloc_records), rec, ptr, size);
661         rec->ptr = ptr;
662         rec->size = size;
663         rec->prot = prot;
664         pkey_last_malloc_record = rec;
665         nr_pkey_malloc_records++;
666 }
667
668 void free_pkey_malloc(void *ptr)
669 {
670         long i;
671         int ret;
672         dprintf3("%s(%p)\n", __func__, ptr);
673         for (i = 0; i < nr_pkey_malloc_records; i++) {
674                 struct pkey_malloc_record *rec = &pkey_malloc_records[i];
675                 dprintf4("looking for ptr %p at record[%ld/%p]: {%p, %ld}\n",
676                                 ptr, i, rec, rec->ptr, rec->size);
677                 if ((ptr <  rec->ptr) ||
678                     (ptr >= rec->ptr + rec->size))
679                         continue;
680
681                 dprintf3("found ptr %p at record[%ld/%p]: {%p, %ld}\n",
682                                 ptr, i, rec, rec->ptr, rec->size);
683                 nr_pkey_malloc_records--;
684                 ret = munmap(rec->ptr, rec->size);
685                 dprintf3("munmap ret: %d\n", ret);
686                 pkey_assert(!ret);
687                 dprintf3("clearing rec->ptr, rec: %p\n", rec);
688                 rec->ptr = NULL;
689                 dprintf3("done clearing rec->ptr, rec: %p\n", rec);
690                 return;
691         }
692         pkey_assert(false);
693 }
694
695
696 void *malloc_pkey_with_mprotect(long size, int prot, u16 pkey)
697 {
698         void *ptr;
699         int ret;
700
701         read_pkey_reg();
702         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
703                         size, prot, pkey);
704         pkey_assert(pkey < NR_PKEYS);
705         ptr = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
706         pkey_assert(ptr != (void *)-1);
707         ret = mprotect_pkey((void *)ptr, PAGE_SIZE, prot, pkey);
708         pkey_assert(!ret);
709         record_pkey_malloc(ptr, size, prot);
710         read_pkey_reg();
711
712         dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, ptr);
713         return ptr;
714 }
715
716 void *malloc_pkey_anon_huge(long size, int prot, u16 pkey)
717 {
718         int ret;
719         void *ptr;
720
721         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
722                         size, prot, pkey);
723         /*
724          * Guarantee we can fit at least one huge page in the resulting
725          * allocation by allocating space for 2:
726          */
727         size = ALIGN_UP(size, HPAGE_SIZE * 2);
728         ptr = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
729         pkey_assert(ptr != (void *)-1);
730         record_pkey_malloc(ptr, size, prot);
731         mprotect_pkey(ptr, size, prot, pkey);
732
733         dprintf1("unaligned ptr: %p\n", ptr);
734         ptr = ALIGN_PTR_UP(ptr, HPAGE_SIZE);
735         dprintf1("  aligned ptr: %p\n", ptr);
736         ret = madvise(ptr, HPAGE_SIZE, MADV_HUGEPAGE);
737         dprintf1("MADV_HUGEPAGE ret: %d\n", ret);
738         ret = madvise(ptr, HPAGE_SIZE, MADV_WILLNEED);
739         dprintf1("MADV_WILLNEED ret: %d\n", ret);
740         memset(ptr, 0, HPAGE_SIZE);
741
742         dprintf1("mmap()'d thp for pkey %d @ %p\n", pkey, ptr);
743         return ptr;
744 }
745
746 int hugetlb_setup_ok;
747 #define SYSFS_FMT_NR_HUGE_PAGES "/sys/kernel/mm/hugepages/hugepages-%ldkB/nr_hugepages"
748 #define GET_NR_HUGE_PAGES 10
749 void setup_hugetlbfs(void)
750 {
751         int err;
752         int fd;
753         char buf[256];
754         long hpagesz_kb;
755         long hpagesz_mb;
756
757         if (geteuid() != 0) {
758                 fprintf(stderr, "WARNING: not run as root, can not do hugetlb test\n");
759                 return;
760         }
761
762         cat_into_file(__stringify(GET_NR_HUGE_PAGES), "/proc/sys/vm/nr_hugepages");
763
764         /*
765          * Now go make sure that we got the pages and that they
766          * are PMD-level pages. Someone might have made PUD-level
767          * pages the default.
768          */
769         hpagesz_kb = HPAGE_SIZE / 1024;
770         hpagesz_mb = hpagesz_kb / 1024;
771         sprintf(buf, SYSFS_FMT_NR_HUGE_PAGES, hpagesz_kb);
772         fd = open(buf, O_RDONLY);
773         if (fd < 0) {
774                 fprintf(stderr, "opening sysfs %ldM hugetlb config: %s\n",
775                         hpagesz_mb, strerror(errno));
776                 return;
777         }
778
779         /* -1 to guarantee leaving the trailing \0 */
780         err = read(fd, buf, sizeof(buf)-1);
781         close(fd);
782         if (err <= 0) {
783                 fprintf(stderr, "reading sysfs %ldM hugetlb config: %s\n",
784                         hpagesz_mb, strerror(errno));
785                 return;
786         }
787
788         if (atoi(buf) != GET_NR_HUGE_PAGES) {
789                 fprintf(stderr, "could not confirm %ldM pages, got: '%s' expected %d\n",
790                         hpagesz_mb, buf, GET_NR_HUGE_PAGES);
791                 return;
792         }
793
794         hugetlb_setup_ok = 1;
795 }
796
797 void *malloc_pkey_hugetlb(long size, int prot, u16 pkey)
798 {
799         void *ptr;
800         int flags = MAP_ANONYMOUS|MAP_PRIVATE|MAP_HUGETLB;
801
802         if (!hugetlb_setup_ok)
803                 return PTR_ERR_ENOTSUP;
804
805         dprintf1("doing %s(%ld, %x, %x)\n", __func__, size, prot, pkey);
806         size = ALIGN_UP(size, HPAGE_SIZE * 2);
807         pkey_assert(pkey < NR_PKEYS);
808         ptr = mmap(NULL, size, PROT_NONE, flags, -1, 0);
809         pkey_assert(ptr != (void *)-1);
810         mprotect_pkey(ptr, size, prot, pkey);
811
812         record_pkey_malloc(ptr, size, prot);
813
814         dprintf1("mmap()'d hugetlbfs for pkey %d @ %p\n", pkey, ptr);
815         return ptr;
816 }
817
818 void *malloc_pkey_mmap_dax(long size, int prot, u16 pkey)
819 {
820         void *ptr;
821         int fd;
822
823         dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
824                         size, prot, pkey);
825         pkey_assert(pkey < NR_PKEYS);
826         fd = open("/dax/foo", O_RDWR);
827         pkey_assert(fd >= 0);
828
829         ptr = mmap(0, size, prot, MAP_SHARED, fd, 0);
830         pkey_assert(ptr != (void *)-1);
831
832         mprotect_pkey(ptr, size, prot, pkey);
833
834         record_pkey_malloc(ptr, size, prot);
835
836         dprintf1("mmap()'d for pkey %d @ %p\n", pkey, ptr);
837         close(fd);
838         return ptr;
839 }
840
841 void *(*pkey_malloc[])(long size, int prot, u16 pkey) = {
842
843         malloc_pkey_with_mprotect,
844         malloc_pkey_with_mprotect_subpage,
845         malloc_pkey_anon_huge,
846         malloc_pkey_hugetlb
847 /* can not do direct with the pkey_mprotect() API:
848         malloc_pkey_mmap_direct,
849         malloc_pkey_mmap_dax,
850 */
851 };
852
853 void *malloc_pkey(long size, int prot, u16 pkey)
854 {
855         void *ret;
856         static int malloc_type;
857         int nr_malloc_types = ARRAY_SIZE(pkey_malloc);
858
859         pkey_assert(pkey < NR_PKEYS);
860
861         while (1) {
862                 pkey_assert(malloc_type < nr_malloc_types);
863
864                 ret = pkey_malloc[malloc_type](size, prot, pkey);
865                 pkey_assert(ret != (void *)-1);
866
867                 malloc_type++;
868                 if (malloc_type >= nr_malloc_types)
869                         malloc_type = (random()%nr_malloc_types);
870
871                 /* try again if the malloc_type we tried is unsupported */
872                 if (ret == PTR_ERR_ENOTSUP)
873                         continue;
874
875                 break;
876         }
877
878         dprintf3("%s(%ld, prot=%x, pkey=%x) returning: %p\n", __func__,
879                         size, prot, pkey, ret);
880         return ret;
881 }
882
883 int last_pkey_faults;
884 #define UNKNOWN_PKEY -2
885 void expected_pkey_fault(int pkey)
886 {
887         dprintf2("%s(): last_pkey_faults: %d pkey_faults: %d\n",
888                         __func__, last_pkey_faults, pkey_faults);
889         dprintf2("%s(%d): last_si_pkey: %d\n", __func__, pkey, last_si_pkey);
890         pkey_assert(last_pkey_faults + 1 == pkey_faults);
891
892        /*
893         * For exec-only memory, we do not know the pkey in
894         * advance, so skip this check.
895         */
896         if (pkey != UNKNOWN_PKEY)
897                 pkey_assert(last_si_pkey == pkey);
898
899 #if defined(__i386__) || defined(__x86_64__) /* arch */
900         /*
901          * The signal handler shold have cleared out PKEY register to let the
902          * test program continue.  We now have to restore it.
903          */
904         if (__read_pkey_reg() != 0)
905 #else /* arch */
906         if (__read_pkey_reg() != shadow_pkey_reg)
907 #endif /* arch */
908                 pkey_assert(0);
909
910         __write_pkey_reg(shadow_pkey_reg);
911         dprintf1("%s() set pkey_reg=%016llx to restore state after signal "
912                        "nuked it\n", __func__, shadow_pkey_reg);
913         last_pkey_faults = pkey_faults;
914         last_si_pkey = -1;
915 }
916
917 #define do_not_expect_pkey_fault(msg)   do {                    \
918         if (last_pkey_faults != pkey_faults)                    \
919                 dprintf0("unexpected PKey fault: %s\n", msg);   \
920         pkey_assert(last_pkey_faults == pkey_faults);           \
921 } while (0)
922
923 int test_fds[10] = { -1 };
924 int nr_test_fds;
925 void __save_test_fd(int fd)
926 {
927         pkey_assert(fd >= 0);
928         pkey_assert(nr_test_fds < ARRAY_SIZE(test_fds));
929         test_fds[nr_test_fds] = fd;
930         nr_test_fds++;
931 }
932
933 int get_test_read_fd(void)
934 {
935         int test_fd = open("/etc/passwd", O_RDONLY);
936         __save_test_fd(test_fd);
937         return test_fd;
938 }
939
940 void close_test_fds(void)
941 {
942         int i;
943
944         for (i = 0; i < nr_test_fds; i++) {
945                 if (test_fds[i] < 0)
946                         continue;
947                 close(test_fds[i]);
948                 test_fds[i] = -1;
949         }
950         nr_test_fds = 0;
951 }
952
953 #define barrier() __asm__ __volatile__("": : :"memory")
954 __attribute__((noinline)) int read_ptr(int *ptr)
955 {
956         /*
957          * Keep GCC from optimizing this away somehow
958          */
959         barrier();
960         return *ptr;
961 }
962
963 void test_pkey_alloc_free_attach_pkey0(int *ptr, u16 pkey)
964 {
965         int i, err;
966         int max_nr_pkey_allocs;
967         int alloced_pkeys[NR_PKEYS];
968         int nr_alloced = 0;
969         long size;
970
971         pkey_assert(pkey_last_malloc_record);
972         size = pkey_last_malloc_record->size;
973         /*
974          * This is a bit of a hack.  But mprotect() requires
975          * huge-page-aligned sizes when operating on hugetlbfs.
976          * So, make sure that we use something that's a multiple
977          * of a huge page when we can.
978          */
979         if (size >= HPAGE_SIZE)
980                 size = HPAGE_SIZE;
981
982         /* allocate every possible key and make sure key-0 never got allocated */
983         max_nr_pkey_allocs = NR_PKEYS;
984         for (i = 0; i < max_nr_pkey_allocs; i++) {
985                 int new_pkey = alloc_pkey();
986                 pkey_assert(new_pkey != 0);
987
988                 if (new_pkey < 0)
989                         break;
990                 alloced_pkeys[nr_alloced++] = new_pkey;
991         }
992         /* free all the allocated keys */
993         for (i = 0; i < nr_alloced; i++) {
994                 int free_ret;
995
996                 if (!alloced_pkeys[i])
997                         continue;
998                 free_ret = sys_pkey_free(alloced_pkeys[i]);
999                 pkey_assert(!free_ret);
1000         }
1001
1002         /* attach key-0 in various modes */
1003         err = sys_mprotect_pkey(ptr, size, PROT_READ, 0);
1004         pkey_assert(!err);
1005         err = sys_mprotect_pkey(ptr, size, PROT_WRITE, 0);
1006         pkey_assert(!err);
1007         err = sys_mprotect_pkey(ptr, size, PROT_EXEC, 0);
1008         pkey_assert(!err);
1009         err = sys_mprotect_pkey(ptr, size, PROT_READ|PROT_WRITE, 0);
1010         pkey_assert(!err);
1011         err = sys_mprotect_pkey(ptr, size, PROT_READ|PROT_WRITE|PROT_EXEC, 0);
1012         pkey_assert(!err);
1013 }
1014
1015 void test_read_of_write_disabled_region(int *ptr, u16 pkey)
1016 {
1017         int ptr_contents;
1018
1019         dprintf1("disabling write access to PKEY[1], doing read\n");
1020         pkey_write_deny(pkey);
1021         ptr_contents = read_ptr(ptr);
1022         dprintf1("*ptr: %d\n", ptr_contents);
1023         dprintf1("\n");
1024 }
1025 void test_read_of_access_disabled_region(int *ptr, u16 pkey)
1026 {
1027         int ptr_contents;
1028
1029         dprintf1("disabling access to PKEY[%02d], doing read @ %p\n", pkey, ptr);
1030         read_pkey_reg();
1031         pkey_access_deny(pkey);
1032         ptr_contents = read_ptr(ptr);
1033         dprintf1("*ptr: %d\n", ptr_contents);
1034         expected_pkey_fault(pkey);
1035 }
1036
1037 void test_read_of_access_disabled_region_with_page_already_mapped(int *ptr,
1038                 u16 pkey)
1039 {
1040         int ptr_contents;
1041
1042         dprintf1("disabling access to PKEY[%02d], doing read @ %p\n",
1043                                 pkey, ptr);
1044         ptr_contents = read_ptr(ptr);
1045         dprintf1("reading ptr before disabling the read : %d\n",
1046                         ptr_contents);
1047         read_pkey_reg();
1048         pkey_access_deny(pkey);
1049         ptr_contents = read_ptr(ptr);
1050         dprintf1("*ptr: %d\n", ptr_contents);
1051         expected_pkey_fault(pkey);
1052 }
1053
1054 void test_write_of_write_disabled_region_with_page_already_mapped(int *ptr,
1055                 u16 pkey)
1056 {
1057         *ptr = __LINE__;
1058         dprintf1("disabling write access; after accessing the page, "
1059                 "to PKEY[%02d], doing write\n", pkey);
1060         pkey_write_deny(pkey);
1061         *ptr = __LINE__;
1062         expected_pkey_fault(pkey);
1063 }
1064
1065 void test_write_of_write_disabled_region(int *ptr, u16 pkey)
1066 {
1067         dprintf1("disabling write access to PKEY[%02d], doing write\n", pkey);
1068         pkey_write_deny(pkey);
1069         *ptr = __LINE__;
1070         expected_pkey_fault(pkey);
1071 }
1072 void test_write_of_access_disabled_region(int *ptr, u16 pkey)
1073 {
1074         dprintf1("disabling access to PKEY[%02d], doing write\n", pkey);
1075         pkey_access_deny(pkey);
1076         *ptr = __LINE__;
1077         expected_pkey_fault(pkey);
1078 }
1079
1080 void test_write_of_access_disabled_region_with_page_already_mapped(int *ptr,
1081                         u16 pkey)
1082 {
1083         *ptr = __LINE__;
1084         dprintf1("disabling access; after accessing the page, "
1085                 " to PKEY[%02d], doing write\n", pkey);
1086         pkey_access_deny(pkey);
1087         *ptr = __LINE__;
1088         expected_pkey_fault(pkey);
1089 }
1090
1091 void test_kernel_write_of_access_disabled_region(int *ptr, u16 pkey)
1092 {
1093         int ret;
1094         int test_fd = get_test_read_fd();
1095
1096         dprintf1("disabling access to PKEY[%02d], "
1097                  "having kernel read() to buffer\n", pkey);
1098         pkey_access_deny(pkey);
1099         ret = read(test_fd, ptr, 1);
1100         dprintf1("read ret: %d\n", ret);
1101         pkey_assert(ret);
1102 }
1103 void test_kernel_write_of_write_disabled_region(int *ptr, u16 pkey)
1104 {
1105         int ret;
1106         int test_fd = get_test_read_fd();
1107
1108         pkey_write_deny(pkey);
1109         ret = read(test_fd, ptr, 100);
1110         dprintf1("read ret: %d\n", ret);
1111         if (ret < 0 && (DEBUG_LEVEL > 0))
1112                 perror("verbose read result (OK for this to be bad)");
1113         pkey_assert(ret);
1114 }
1115
1116 void test_kernel_gup_of_access_disabled_region(int *ptr, u16 pkey)
1117 {
1118         int pipe_ret, vmsplice_ret;
1119         struct iovec iov;
1120         int pipe_fds[2];
1121
1122         pipe_ret = pipe(pipe_fds);
1123
1124         pkey_assert(pipe_ret == 0);
1125         dprintf1("disabling access to PKEY[%02d], "
1126                  "having kernel vmsplice from buffer\n", pkey);
1127         pkey_access_deny(pkey);
1128         iov.iov_base = ptr;
1129         iov.iov_len = PAGE_SIZE;
1130         vmsplice_ret = vmsplice(pipe_fds[1], &iov, 1, SPLICE_F_GIFT);
1131         dprintf1("vmsplice() ret: %d\n", vmsplice_ret);
1132         pkey_assert(vmsplice_ret == -1);
1133
1134         close(pipe_fds[0]);
1135         close(pipe_fds[1]);
1136 }
1137
1138 void test_kernel_gup_write_to_write_disabled_region(int *ptr, u16 pkey)
1139 {
1140         int ignored = 0xdada;
1141         int futex_ret;
1142         int some_int = __LINE__;
1143
1144         dprintf1("disabling write to PKEY[%02d], "
1145                  "doing futex gunk in buffer\n", pkey);
1146         *ptr = some_int;
1147         pkey_write_deny(pkey);
1148         futex_ret = syscall(SYS_futex, ptr, FUTEX_WAIT, some_int-1, NULL,
1149                         &ignored, ignored);
1150         if (DEBUG_LEVEL > 0)
1151                 perror("futex");
1152         dprintf1("futex() ret: %d\n", futex_ret);
1153 }
1154
1155 /* Assumes that all pkeys other than 'pkey' are unallocated */
1156 void test_pkey_syscalls_on_non_allocated_pkey(int *ptr, u16 pkey)
1157 {
1158         int err;
1159         int i;
1160
1161         /* Note: 0 is the default pkey, so don't mess with it */
1162         for (i = 1; i < NR_PKEYS; i++) {
1163                 if (pkey == i)
1164                         continue;
1165
1166                 dprintf1("trying get/set/free to non-allocated pkey: %2d\n", i);
1167                 err = sys_pkey_free(i);
1168                 pkey_assert(err);
1169
1170                 err = sys_pkey_free(i);
1171                 pkey_assert(err);
1172
1173                 err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, i);
1174                 pkey_assert(err);
1175         }
1176 }
1177
1178 /* Assumes that all pkeys other than 'pkey' are unallocated */
1179 void test_pkey_syscalls_bad_args(int *ptr, u16 pkey)
1180 {
1181         int err;
1182         int bad_pkey = NR_PKEYS+99;
1183
1184         /* pass a known-invalid pkey in: */
1185         err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, bad_pkey);
1186         pkey_assert(err);
1187 }
1188
1189 void become_child(void)
1190 {
1191         pid_t forkret;
1192
1193         forkret = fork();
1194         pkey_assert(forkret >= 0);
1195         dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
1196
1197         if (!forkret) {
1198                 /* in the child */
1199                 return;
1200         }
1201         exit(0);
1202 }
1203
1204 /* Assumes that all pkeys other than 'pkey' are unallocated */
1205 void test_pkey_alloc_exhaust(int *ptr, u16 pkey)
1206 {
1207         int err;
1208         int allocated_pkeys[NR_PKEYS] = {0};
1209         int nr_allocated_pkeys = 0;
1210         int i;
1211
1212         for (i = 0; i < NR_PKEYS*3; i++) {
1213                 int new_pkey;
1214                 dprintf1("%s() alloc loop: %d\n", __func__, i);
1215                 new_pkey = alloc_pkey();
1216                 dprintf4("%s()::%d, err: %d pkey_reg: 0x%016llx"
1217                                 " shadow: 0x%016llx\n",
1218                                 __func__, __LINE__, err, __read_pkey_reg(),
1219                                 shadow_pkey_reg);
1220                 read_pkey_reg(); /* for shadow checking */
1221                 dprintf2("%s() errno: %d ENOSPC: %d\n", __func__, errno, ENOSPC);
1222                 if ((new_pkey == -1) && (errno == ENOSPC)) {
1223                         dprintf2("%s() failed to allocate pkey after %d tries\n",
1224                                 __func__, nr_allocated_pkeys);
1225                 } else {
1226                         /*
1227                          * Ensure the number of successes never
1228                          * exceeds the number of keys supported
1229                          * in the hardware.
1230                          */
1231                         pkey_assert(nr_allocated_pkeys < NR_PKEYS);
1232                         allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
1233                 }
1234
1235                 /*
1236                  * Make sure that allocation state is properly
1237                  * preserved across fork().
1238                  */
1239                 if (i == NR_PKEYS*2)
1240                         become_child();
1241         }
1242
1243         dprintf3("%s()::%d\n", __func__, __LINE__);
1244
1245         /*
1246          * On x86:
1247          * There are 16 pkeys supported in hardware.  Three are
1248          * allocated by the time we get here:
1249          *   1. The default key (0)
1250          *   2. One possibly consumed by an execute-only mapping.
1251          *   3. One allocated by the test code and passed in via
1252          *      'pkey' to this function.
1253          * Ensure that we can allocate at least another 13 (16-3).
1254          *
1255          * On powerpc:
1256          * There are either 5, 28, 29 or 32 pkeys supported in
1257          * hardware depending on the page size (4K or 64K) and
1258          * platform (powernv or powervm). Four are allocated by
1259          * the time we get here. These include pkey-0, pkey-1,
1260          * exec-only pkey and the one allocated by the test code.
1261          * Ensure that we can allocate the remaining.
1262          */
1263         pkey_assert(i >= (NR_PKEYS - get_arch_reserved_keys() - 1));
1264
1265         for (i = 0; i < nr_allocated_pkeys; i++) {
1266                 err = sys_pkey_free(allocated_pkeys[i]);
1267                 pkey_assert(!err);
1268                 read_pkey_reg(); /* for shadow checking */
1269         }
1270 }
1271
1272 void arch_force_pkey_reg_init(void)
1273 {
1274 #if defined(__i386__) || defined(__x86_64__) /* arch */
1275         u64 *buf;
1276
1277         /*
1278          * All keys should be allocated and set to allow reads and
1279          * writes, so the register should be all 0.  If not, just
1280          * skip the test.
1281          */
1282         if (read_pkey_reg())
1283                 return;
1284
1285         /*
1286          * Just allocate an absurd about of memory rather than
1287          * doing the XSAVE size enumeration dance.
1288          */
1289         buf = mmap(NULL, 1*MB, PROT_READ|PROT_WRITE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
1290
1291         /* These __builtins require compiling with -mxsave */
1292
1293         /* XSAVE to build a valid buffer: */
1294         __builtin_ia32_xsave(buf, XSTATE_PKEY);
1295         /* Clear XSTATE_BV[PKRU]: */
1296         buf[XSTATE_BV_OFFSET/sizeof(u64)] &= ~XSTATE_PKEY;
1297         /* XRSTOR will likely get PKRU back to the init state: */
1298         __builtin_ia32_xrstor(buf, XSTATE_PKEY);
1299
1300         munmap(buf, 1*MB);
1301 #endif
1302 }
1303
1304
1305 /*
1306  * This is mostly useless on ppc for now.  But it will not
1307  * hurt anything and should give some better coverage as
1308  * a long-running test that continually checks the pkey
1309  * register.
1310  */
1311 void test_pkey_init_state(int *ptr, u16 pkey)
1312 {
1313         int err;
1314         int allocated_pkeys[NR_PKEYS] = {0};
1315         int nr_allocated_pkeys = 0;
1316         int i;
1317
1318         for (i = 0; i < NR_PKEYS; i++) {
1319                 int new_pkey = alloc_pkey();
1320
1321                 if (new_pkey < 0)
1322                         continue;
1323                 allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
1324         }
1325
1326         dprintf3("%s()::%d\n", __func__, __LINE__);
1327
1328         arch_force_pkey_reg_init();
1329
1330         /*
1331          * Loop for a bit, hoping to get exercise the kernel
1332          * context switch code.
1333          */
1334         for (i = 0; i < 1000000; i++)
1335                 read_pkey_reg();
1336
1337         for (i = 0; i < nr_allocated_pkeys; i++) {
1338                 err = sys_pkey_free(allocated_pkeys[i]);
1339                 pkey_assert(!err);
1340                 read_pkey_reg(); /* for shadow checking */
1341         }
1342 }
1343
1344 /*
1345  * pkey 0 is special.  It is allocated by default, so you do not
1346  * have to call pkey_alloc() to use it first.  Make sure that it
1347  * is usable.
1348  */
1349 void test_mprotect_with_pkey_0(int *ptr, u16 pkey)
1350 {
1351         long size;
1352         int prot;
1353
1354         assert(pkey_last_malloc_record);
1355         size = pkey_last_malloc_record->size;
1356         /*
1357          * This is a bit of a hack.  But mprotect() requires
1358          * huge-page-aligned sizes when operating on hugetlbfs.
1359          * So, make sure that we use something that's a multiple
1360          * of a huge page when we can.
1361          */
1362         if (size >= HPAGE_SIZE)
1363                 size = HPAGE_SIZE;
1364         prot = pkey_last_malloc_record->prot;
1365
1366         /* Use pkey 0 */
1367         mprotect_pkey(ptr, size, prot, 0);
1368
1369         /* Make sure that we can set it back to the original pkey. */
1370         mprotect_pkey(ptr, size, prot, pkey);
1371 }
1372
1373 void test_ptrace_of_child(int *ptr, u16 pkey)
1374 {
1375         __attribute__((__unused__)) int peek_result;
1376         pid_t child_pid;
1377         void *ignored = 0;
1378         long ret;
1379         int status;
1380         /*
1381          * This is the "control" for our little expermient.  Make sure
1382          * we can always access it when ptracing.
1383          */
1384         int *plain_ptr_unaligned = malloc(HPAGE_SIZE);
1385         int *plain_ptr = ALIGN_PTR_UP(plain_ptr_unaligned, PAGE_SIZE);
1386
1387         /*
1388          * Fork a child which is an exact copy of this process, of course.
1389          * That means we can do all of our tests via ptrace() and then plain
1390          * memory access and ensure they work differently.
1391          */
1392         child_pid = fork_lazy_child();
1393         dprintf1("[%d] child pid: %d\n", getpid(), child_pid);
1394
1395         ret = ptrace(PTRACE_ATTACH, child_pid, ignored, ignored);
1396         if (ret)
1397                 perror("attach");
1398         dprintf1("[%d] attach ret: %ld %d\n", getpid(), ret, __LINE__);
1399         pkey_assert(ret != -1);
1400         ret = waitpid(child_pid, &status, WUNTRACED);
1401         if ((ret != child_pid) || !(WIFSTOPPED(status))) {
1402                 fprintf(stderr, "weird waitpid result %ld stat %x\n",
1403                                 ret, status);
1404                 pkey_assert(0);
1405         }
1406         dprintf2("waitpid ret: %ld\n", ret);
1407         dprintf2("waitpid status: %d\n", status);
1408
1409         pkey_access_deny(pkey);
1410         pkey_write_deny(pkey);
1411
1412         /* Write access, untested for now:
1413         ret = ptrace(PTRACE_POKEDATA, child_pid, peek_at, data);
1414         pkey_assert(ret != -1);
1415         dprintf1("poke at %p: %ld\n", peek_at, ret);
1416         */
1417
1418         /*
1419          * Try to access the pkey-protected "ptr" via ptrace:
1420          */
1421         ret = ptrace(PTRACE_PEEKDATA, child_pid, ptr, ignored);
1422         /* expect it to work, without an error: */
1423         pkey_assert(ret != -1);
1424         /* Now access from the current task, and expect an exception: */
1425         peek_result = read_ptr(ptr);
1426         expected_pkey_fault(pkey);
1427
1428         /*
1429          * Try to access the NON-pkey-protected "plain_ptr" via ptrace:
1430          */
1431         ret = ptrace(PTRACE_PEEKDATA, child_pid, plain_ptr, ignored);
1432         /* expect it to work, without an error: */
1433         pkey_assert(ret != -1);
1434         /* Now access from the current task, and expect NO exception: */
1435         peek_result = read_ptr(plain_ptr);
1436         do_not_expect_pkey_fault("read plain pointer after ptrace");
1437
1438         ret = ptrace(PTRACE_DETACH, child_pid, ignored, 0);
1439         pkey_assert(ret != -1);
1440
1441         ret = kill(child_pid, SIGKILL);
1442         pkey_assert(ret != -1);
1443
1444         wait(&status);
1445
1446         free(plain_ptr_unaligned);
1447 }
1448
1449 void *get_pointer_to_instructions(void)
1450 {
1451         void *p1;
1452
1453         p1 = ALIGN_PTR_UP(&lots_o_noops_around_write, PAGE_SIZE);
1454         dprintf3("&lots_o_noops: %p\n", &lots_o_noops_around_write);
1455         /* lots_o_noops_around_write should be page-aligned already */
1456         assert(p1 == &lots_o_noops_around_write);
1457
1458         /* Point 'p1' at the *second* page of the function: */
1459         p1 += PAGE_SIZE;
1460
1461         /*
1462          * Try to ensure we fault this in on next touch to ensure
1463          * we get an instruction fault as opposed to a data one
1464          */
1465         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1466
1467         return p1;
1468 }
1469
1470 void test_executing_on_unreadable_memory(int *ptr, u16 pkey)
1471 {
1472         void *p1;
1473         int scratch;
1474         int ptr_contents;
1475         int ret;
1476
1477         p1 = get_pointer_to_instructions();
1478         lots_o_noops_around_write(&scratch);
1479         ptr_contents = read_ptr(p1);
1480         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1481
1482         ret = mprotect_pkey(p1, PAGE_SIZE, PROT_EXEC, (u64)pkey);
1483         pkey_assert(!ret);
1484         pkey_access_deny(pkey);
1485
1486         dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
1487
1488         /*
1489          * Make sure this is an *instruction* fault
1490          */
1491         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1492         lots_o_noops_around_write(&scratch);
1493         do_not_expect_pkey_fault("executing on PROT_EXEC memory");
1494         expect_fault_on_read_execonly_key(p1, pkey);
1495 }
1496
1497 void test_implicit_mprotect_exec_only_memory(int *ptr, u16 pkey)
1498 {
1499         void *p1;
1500         int scratch;
1501         int ptr_contents;
1502         int ret;
1503
1504         dprintf1("%s() start\n", __func__);
1505
1506         p1 = get_pointer_to_instructions();
1507         lots_o_noops_around_write(&scratch);
1508         ptr_contents = read_ptr(p1);
1509         dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1510
1511         /* Use a *normal* mprotect(), not mprotect_pkey(): */
1512         ret = mprotect(p1, PAGE_SIZE, PROT_EXEC);
1513         pkey_assert(!ret);
1514
1515         /*
1516          * Reset the shadow, assuming that the above mprotect()
1517          * correctly changed PKRU, but to an unknown value since
1518          * the actual allocated pkey is unknown.
1519          */
1520         shadow_pkey_reg = __read_pkey_reg();
1521
1522         dprintf2("pkey_reg: %016llx\n", read_pkey_reg());
1523
1524         /* Make sure this is an *instruction* fault */
1525         madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1526         lots_o_noops_around_write(&scratch);
1527         do_not_expect_pkey_fault("executing on PROT_EXEC memory");
1528         expect_fault_on_read_execonly_key(p1, UNKNOWN_PKEY);
1529
1530         /*
1531          * Put the memory back to non-PROT_EXEC.  Should clear the
1532          * exec-only pkey off the VMA and allow it to be readable
1533          * again.  Go to PROT_NONE first to check for a kernel bug
1534          * that did not clear the pkey when doing PROT_NONE.
1535          */
1536         ret = mprotect(p1, PAGE_SIZE, PROT_NONE);
1537         pkey_assert(!ret);
1538
1539         ret = mprotect(p1, PAGE_SIZE, PROT_READ|PROT_EXEC);
1540         pkey_assert(!ret);
1541         ptr_contents = read_ptr(p1);
1542         do_not_expect_pkey_fault("plain read on recently PROT_EXEC area");
1543 }
1544
1545 #if defined(__i386__) || defined(__x86_64__)
1546 void test_ptrace_modifies_pkru(int *ptr, u16 pkey)
1547 {
1548         u32 new_pkru;
1549         pid_t child;
1550         int status, ret;
1551         int pkey_offset = pkey_reg_xstate_offset();
1552         size_t xsave_size = cpu_max_xsave_size();
1553         void *xsave;
1554         u32 *pkey_register;
1555         u64 *xstate_bv;
1556         struct iovec iov;
1557
1558         new_pkru = ~read_pkey_reg();
1559         /* Don't make PROT_EXEC mappings inaccessible */
1560         new_pkru &= ~3;
1561
1562         child = fork();
1563         pkey_assert(child >= 0);
1564         dprintf3("[%d] fork() ret: %d\n", getpid(), child);
1565         if (!child) {
1566                 ptrace(PTRACE_TRACEME, 0, 0, 0);
1567                 /* Stop and allow the tracer to modify PKRU directly */
1568                 raise(SIGSTOP);
1569
1570                 /*
1571                  * need __read_pkey_reg() version so we do not do shadow_pkey_reg
1572                  * checking
1573                  */
1574                 if (__read_pkey_reg() != new_pkru)
1575                         exit(1);
1576
1577                 /* Stop and allow the tracer to clear XSTATE_BV for PKRU */
1578                 raise(SIGSTOP);
1579
1580                 if (__read_pkey_reg() != 0)
1581                         exit(1);
1582
1583                 /* Stop and allow the tracer to examine PKRU */
1584                 raise(SIGSTOP);
1585
1586                 exit(0);
1587         }
1588
1589         pkey_assert(child == waitpid(child, &status, 0));
1590         dprintf3("[%d] waitpid(%d) status: %x\n", getpid(), child, status);
1591         pkey_assert(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
1592
1593         xsave = (void *)malloc(xsave_size);
1594         pkey_assert(xsave > 0);
1595
1596         /* Modify the PKRU register directly */
1597         iov.iov_base = xsave;
1598         iov.iov_len = xsave_size;
1599         ret = ptrace(PTRACE_GETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1600         pkey_assert(ret == 0);
1601
1602         pkey_register = (u32 *)(xsave + pkey_offset);
1603         pkey_assert(*pkey_register == read_pkey_reg());
1604
1605         *pkey_register = new_pkru;
1606
1607         ret = ptrace(PTRACE_SETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1608         pkey_assert(ret == 0);
1609
1610         /* Test that the modification is visible in ptrace before any execution */
1611         memset(xsave, 0xCC, xsave_size);
1612         ret = ptrace(PTRACE_GETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1613         pkey_assert(ret == 0);
1614         pkey_assert(*pkey_register == new_pkru);
1615
1616         /* Execute the tracee */
1617         ret = ptrace(PTRACE_CONT, child, 0, 0);
1618         pkey_assert(ret == 0);
1619
1620         /* Test that the tracee saw the PKRU value change */
1621         pkey_assert(child == waitpid(child, &status, 0));
1622         dprintf3("[%d] waitpid(%d) status: %x\n", getpid(), child, status);
1623         pkey_assert(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
1624
1625         /* Test that the modification is visible in ptrace after execution */
1626         memset(xsave, 0xCC, xsave_size);
1627         ret = ptrace(PTRACE_GETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1628         pkey_assert(ret == 0);
1629         pkey_assert(*pkey_register == new_pkru);
1630
1631         /* Clear the PKRU bit from XSTATE_BV */
1632         xstate_bv = (u64 *)(xsave + 512);
1633         *xstate_bv &= ~(1 << 9);
1634
1635         ret = ptrace(PTRACE_SETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1636         pkey_assert(ret == 0);
1637
1638         /* Test that the modification is visible in ptrace before any execution */
1639         memset(xsave, 0xCC, xsave_size);
1640         ret = ptrace(PTRACE_GETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1641         pkey_assert(ret == 0);
1642         pkey_assert(*pkey_register == 0);
1643
1644         ret = ptrace(PTRACE_CONT, child, 0, 0);
1645         pkey_assert(ret == 0);
1646
1647         /* Test that the tracee saw the PKRU value go to 0 */
1648         pkey_assert(child == waitpid(child, &status, 0));
1649         dprintf3("[%d] waitpid(%d) status: %x\n", getpid(), child, status);
1650         pkey_assert(WIFSTOPPED(status) && WSTOPSIG(status) == SIGSTOP);
1651
1652         /* Test that the modification is visible in ptrace after execution */
1653         memset(xsave, 0xCC, xsave_size);
1654         ret = ptrace(PTRACE_GETREGSET, child, (void *)NT_X86_XSTATE, &iov);
1655         pkey_assert(ret == 0);
1656         pkey_assert(*pkey_register == 0);
1657
1658         ret = ptrace(PTRACE_CONT, child, 0, 0);
1659         pkey_assert(ret == 0);
1660         pkey_assert(child == waitpid(child, &status, 0));
1661         dprintf3("[%d] waitpid(%d) status: %x\n", getpid(), child, status);
1662         pkey_assert(WIFEXITED(status));
1663         pkey_assert(WEXITSTATUS(status) == 0);
1664         free(xsave);
1665 }
1666 #endif
1667
1668 void test_mprotect_pkey_on_unsupported_cpu(int *ptr, u16 pkey)
1669 {
1670         int size = PAGE_SIZE;
1671         int sret;
1672
1673         if (cpu_has_pkeys()) {
1674                 dprintf1("SKIP: %s: no CPU support\n", __func__);
1675                 return;
1676         }
1677
1678         sret = syscall(__NR_pkey_mprotect, ptr, size, PROT_READ, pkey);
1679         pkey_assert(sret < 0);
1680 }
1681
1682 void (*pkey_tests[])(int *ptr, u16 pkey) = {
1683         test_read_of_write_disabled_region,
1684         test_read_of_access_disabled_region,
1685         test_read_of_access_disabled_region_with_page_already_mapped,
1686         test_write_of_write_disabled_region,
1687         test_write_of_write_disabled_region_with_page_already_mapped,
1688         test_write_of_access_disabled_region,
1689         test_write_of_access_disabled_region_with_page_already_mapped,
1690         test_kernel_write_of_access_disabled_region,
1691         test_kernel_write_of_write_disabled_region,
1692         test_kernel_gup_of_access_disabled_region,
1693         test_kernel_gup_write_to_write_disabled_region,
1694         test_executing_on_unreadable_memory,
1695         test_implicit_mprotect_exec_only_memory,
1696         test_mprotect_with_pkey_0,
1697         test_ptrace_of_child,
1698         test_pkey_init_state,
1699         test_pkey_syscalls_on_non_allocated_pkey,
1700         test_pkey_syscalls_bad_args,
1701         test_pkey_alloc_exhaust,
1702         test_pkey_alloc_free_attach_pkey0,
1703 #if defined(__i386__) || defined(__x86_64__)
1704         test_ptrace_modifies_pkru,
1705 #endif
1706 };
1707
1708 void run_tests_once(void)
1709 {
1710         int *ptr;
1711         int prot = PROT_READ|PROT_WRITE;
1712
1713         for (test_nr = 0; test_nr < ARRAY_SIZE(pkey_tests); test_nr++) {
1714                 int pkey;
1715                 int orig_pkey_faults = pkey_faults;
1716
1717                 dprintf1("======================\n");
1718                 dprintf1("test %d preparing...\n", test_nr);
1719
1720                 tracing_on();
1721                 pkey = alloc_random_pkey();
1722                 dprintf1("test %d starting with pkey: %d\n", test_nr, pkey);
1723                 ptr = malloc_pkey(PAGE_SIZE, prot, pkey);
1724                 dprintf1("test %d starting...\n", test_nr);
1725                 pkey_tests[test_nr](ptr, pkey);
1726                 dprintf1("freeing test memory: %p\n", ptr);
1727                 free_pkey_malloc(ptr);
1728                 sys_pkey_free(pkey);
1729
1730                 dprintf1("pkey_faults: %d\n", pkey_faults);
1731                 dprintf1("orig_pkey_faults: %d\n", orig_pkey_faults);
1732
1733                 tracing_off();
1734                 close_test_fds();
1735
1736                 printf("test %2d PASSED (iteration %d)\n", test_nr, iteration_nr);
1737                 dprintf1("======================\n\n");
1738         }
1739         iteration_nr++;
1740 }
1741
1742 void pkey_setup_shadow(void)
1743 {
1744         shadow_pkey_reg = __read_pkey_reg();
1745 }
1746
1747 int main(void)
1748 {
1749         int nr_iterations = 22;
1750         int pkeys_supported = is_pkeys_supported();
1751
1752         srand((unsigned int)time(NULL));
1753
1754         setup_handlers();
1755
1756         printf("has pkeys: %d\n", pkeys_supported);
1757
1758         if (!pkeys_supported) {
1759                 int size = PAGE_SIZE;
1760                 int *ptr;
1761
1762                 printf("running PKEY tests for unsupported CPU/OS\n");
1763
1764                 ptr  = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
1765                 assert(ptr != (void *)-1);
1766                 test_mprotect_pkey_on_unsupported_cpu(ptr, 1);
1767                 exit(0);
1768         }
1769
1770         pkey_setup_shadow();
1771         printf("startup pkey_reg: %016llx\n", read_pkey_reg());
1772         setup_hugetlbfs();
1773
1774         while (nr_iterations-- > 0)
1775                 run_tests_once();
1776
1777         printf("done (all tests OK)\n");
1778         return 0;
1779 }