linux/tools/testing/selftests/x86/protection_keys.c
<<
>>
Prefs
   1/*
   2 * Tests x86 Memory Protection Keys (see Documentation/x86/protection-keys.txt)
   3 *
   4 * There are examples in here of:
   5 *  * how to set protection keys on memory
   6 *  * how to set/clear bits in PKRU (the rights register)
   7 *  * how to handle SEGV_PKRU signals and extract pkey-relevant
   8 *    information from the siginfo
   9 *
  10 * Things to add:
  11 *      make sure KSM and KSM COW breaking works
  12 *      prefault pages in at malloc, or not
  13 *      protect MPX bounds tables with protection keys?
  14 *      make sure VMA splitting/merging is working correctly
  15 *      OOMs can destroy mm->mmap (see exit_mmap()), so make sure it is immune to pkeys
  16 *      look for pkey "leaks" where it is still set on a VMA but "freed" back to the kernel
  17 *      do a plain mprotect() to a mprotect_pkey() area and make sure the pkey sticks
  18 *
  19 * Compile like this:
  20 *      gcc      -o protection_keys    -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
  21 *      gcc -m32 -o protection_keys_32 -O2 -g -std=gnu99 -pthread -Wall protection_keys.c -lrt -ldl -lm
  22 */
  23#define _GNU_SOURCE
  24#include <errno.h>
  25#include <linux/futex.h>
  26#include <sys/time.h>
  27#include <sys/syscall.h>
  28#include <string.h>
  29#include <stdio.h>
  30#include <stdint.h>
  31#include <stdbool.h>
  32#include <signal.h>
  33#include <assert.h>
  34#include <stdlib.h>
  35#include <ucontext.h>
  36#include <sys/mman.h>
  37#include <sys/types.h>
  38#include <sys/wait.h>
  39#include <sys/stat.h>
  40#include <fcntl.h>
  41#include <unistd.h>
  42#include <sys/ptrace.h>
  43#include <setjmp.h>
  44
  45#include "pkey-helpers.h"
  46
  47int iteration_nr = 1;
  48int test_nr;
  49
  50unsigned int shadow_pkru;
  51
  52#define HPAGE_SIZE      (1UL<<21)
  53#define ARRAY_SIZE(x) (sizeof(x) / sizeof(*(x)))
  54#define ALIGN_UP(x, align_to)   (((x) + ((align_to)-1)) & ~((align_to)-1))
  55#define ALIGN_DOWN(x, align_to) ((x) & ~((align_to)-1))
  56#define ALIGN_PTR_UP(p, ptr_align_to)   ((typeof(p))ALIGN_UP((unsigned long)(p),        ptr_align_to))
  57#define ALIGN_PTR_DOWN(p, ptr_align_to) ((typeof(p))ALIGN_DOWN((unsigned long)(p),      ptr_align_to))
  58#define __stringify_1(x...)     #x
  59#define __stringify(x...)       __stringify_1(x)
  60
  61#define PTR_ERR_ENOTSUP ((void *)-ENOTSUP)
  62
  63int dprint_in_signal;
  64char dprint_in_signal_buffer[DPRINT_IN_SIGNAL_BUF_SIZE];
  65
  66extern void abort_hooks(void);
  67#define pkey_assert(condition) do {             \
  68        if (!(condition)) {                     \
  69                dprintf0("assert() at %s::%d test_nr: %d iteration: %d\n", \
  70                                __FILE__, __LINE__,     \
  71                                test_nr, iteration_nr); \
  72                dprintf0("errno at assert: %d", errno); \
  73                abort_hooks();                  \
  74                assert(condition);              \
  75        }                                       \
  76} while (0)
  77#define raw_assert(cond) assert(cond)
  78
  79void cat_into_file(char *str, char *file)
  80{
  81        int fd = open(file, O_RDWR);
  82        int ret;
  83
  84        dprintf2("%s(): writing '%s' to '%s'\n", __func__, str, file);
  85        /*
  86         * these need to be raw because they are called under
  87         * pkey_assert()
  88         */
  89        raw_assert(fd >= 0);
  90        ret = write(fd, str, strlen(str));
  91        if (ret != strlen(str)) {
  92                perror("write to file failed");
  93                fprintf(stderr, "filename: '%s' str: '%s'\n", file, str);
  94                raw_assert(0);
  95        }
  96        close(fd);
  97}
  98
  99#if CONTROL_TRACING > 0
 100static int warned_tracing;
 101int tracing_root_ok(void)
 102{
 103        if (geteuid() != 0) {
 104                if (!warned_tracing)
 105                        fprintf(stderr, "WARNING: not run as root, "
 106                                        "can not do tracing control\n");
 107                warned_tracing = 1;
 108                return 0;
 109        }
 110        return 1;
 111}
 112#endif
 113
 114void tracing_on(void)
 115{
 116#if CONTROL_TRACING > 0
 117#define TRACEDIR "/sys/kernel/debug/tracing"
 118        char pidstr[32];
 119
 120        if (!tracing_root_ok())
 121                return;
 122
 123        sprintf(pidstr, "%d", getpid());
 124        cat_into_file("0", TRACEDIR "/tracing_on");
 125        cat_into_file("\n", TRACEDIR "/trace");
 126        if (1) {
 127                cat_into_file("function_graph", TRACEDIR "/current_tracer");
 128                cat_into_file("1", TRACEDIR "/options/funcgraph-proc");
 129        } else {
 130                cat_into_file("nop", TRACEDIR "/current_tracer");
 131        }
 132        cat_into_file(pidstr, TRACEDIR "/set_ftrace_pid");
 133        cat_into_file("1", TRACEDIR "/tracing_on");
 134        dprintf1("enabled tracing\n");
 135#endif
 136}
 137
 138void tracing_off(void)
 139{
 140#if CONTROL_TRACING > 0
 141        if (!tracing_root_ok())
 142                return;
 143        cat_into_file("0", "/sys/kernel/debug/tracing/tracing_on");
 144#endif
 145}
 146
 147void abort_hooks(void)
 148{
 149        fprintf(stderr, "running %s()...\n", __func__);
 150        tracing_off();
 151#ifdef SLEEP_ON_ABORT
 152        sleep(SLEEP_ON_ABORT);
 153#endif
 154}
 155
 156static inline void __page_o_noops(void)
 157{
 158        /* 8-bytes of instruction * 512 bytes = 1 page */
 159        asm(".rept 512 ; nopl 0x7eeeeeee(%eax) ; .endr");
 160}
 161
 162/*
 163 * This attempts to have roughly a page of instructions followed by a few
 164 * instructions that do a write, and another page of instructions.  That
 165 * way, we are pretty sure that the write is in the second page of
 166 * instructions and has at least a page of padding behind it.
 167 *
 168 * *That* lets us be sure to madvise() away the write instruction, which
 169 * will then fault, which makes sure that the fault code handles
 170 * execute-only memory properly.
 171 */
 172__attribute__((__aligned__(PAGE_SIZE)))
 173void lots_o_noops_around_write(int *write_to_me)
 174{
 175        dprintf3("running %s()\n", __func__);
 176        __page_o_noops();
 177        /* Assume this happens in the second page of instructions: */
 178        *write_to_me = __LINE__;
 179        /* pad out by another page: */
 180        __page_o_noops();
 181        dprintf3("%s() done\n", __func__);
 182}
 183
 184/* Define some kernel-like types */
 185#define  u8 uint8_t
 186#define u16 uint16_t
 187#define u32 uint32_t
 188#define u64 uint64_t
 189
 190#ifdef __i386__
 191#define SYS_mprotect_key 380
 192#define SYS_pkey_alloc   381
 193#define SYS_pkey_free    382
 194#define REG_IP_IDX REG_EIP
 195#define si_pkey_offset 0x18
 196#else
 197#define SYS_mprotect_key 329
 198#define SYS_pkey_alloc   330
 199#define SYS_pkey_free    331
 200#define REG_IP_IDX REG_RIP
 201#define si_pkey_offset 0x20
 202#endif
 203
 204void dump_mem(void *dumpme, int len_bytes)
 205{
 206        char *c = (void *)dumpme;
 207        int i;
 208
 209        for (i = 0; i < len_bytes; i += sizeof(u64)) {
 210                u64 *ptr = (u64 *)(c + i);
 211                dprintf1("dump[%03d][@%p]: %016jx\n", i, ptr, *ptr);
 212        }
 213}
 214
 215#define __SI_FAULT      (3 << 16)
 216#define SEGV_BNDERR     (__SI_FAULT|3)  /* failed address bound checks */
 217#define SEGV_PKUERR     (__SI_FAULT|4)
 218
 219static char *si_code_str(int si_code)
 220{
 221        if (si_code & SEGV_MAPERR)
 222                return "SEGV_MAPERR";
 223        if (si_code & SEGV_ACCERR)
 224                return "SEGV_ACCERR";
 225        if (si_code & SEGV_BNDERR)
 226                return "SEGV_BNDERR";
 227        if (si_code & SEGV_PKUERR)
 228                return "SEGV_PKUERR";
 229        return "UNKNOWN";
 230}
 231
 232int pkru_faults;
 233int last_si_pkey = -1;
 234void signal_handler(int signum, siginfo_t *si, void *vucontext)
 235{
 236        ucontext_t *uctxt = vucontext;
 237        int trapno;
 238        unsigned long ip;
 239        char *fpregs;
 240        u32 *pkru_ptr;
 241        u64 si_pkey;
 242        u32 *si_pkey_ptr;
 243        int pkru_offset;
 244        fpregset_t fpregset;
 245
 246        dprint_in_signal = 1;
 247        dprintf1(">>>>===============SIGSEGV============================\n");
 248        dprintf1("%s()::%d, pkru: 0x%x shadow: %x\n", __func__, __LINE__,
 249                        __rdpkru(), shadow_pkru);
 250
 251        trapno = uctxt->uc_mcontext.gregs[REG_TRAPNO];
 252        ip = uctxt->uc_mcontext.gregs[REG_IP_IDX];
 253        fpregset = uctxt->uc_mcontext.fpregs;
 254        fpregs = (void *)fpregset;
 255
 256        dprintf2("%s() trapno: %d ip: 0x%lx info->si_code: %s/%d\n", __func__,
 257                        trapno, ip, si_code_str(si->si_code), si->si_code);
 258#ifdef __i386__
 259        /*
 260         * 32-bit has some extra padding so that userspace can tell whether
 261         * the XSTATE header is present in addition to the "legacy" FPU
 262         * state.  We just assume that it is here.
 263         */
 264        fpregs += 0x70;
 265#endif
 266        pkru_offset = pkru_xstate_offset();
 267        pkru_ptr = (void *)(&fpregs[pkru_offset]);
 268
 269        dprintf1("siginfo: %p\n", si);
 270        dprintf1(" fpregs: %p\n", fpregs);
 271        /*
 272         * If we got a PKRU fault, we *HAVE* to have at least one bit set in
 273         * here.
 274         */
 275        dprintf1("pkru_xstate_offset: %d\n", pkru_xstate_offset());
 276        if (DEBUG_LEVEL > 4)
 277                dump_mem(pkru_ptr - 128, 256);
 278        pkey_assert(*pkru_ptr);
 279
 280        si_pkey_ptr = (u32 *)(((u8 *)si) + si_pkey_offset);
 281        dprintf1("si_pkey_ptr: %p\n", si_pkey_ptr);
 282        dump_mem(si_pkey_ptr - 8, 24);
 283        si_pkey = *si_pkey_ptr;
 284        pkey_assert(si_pkey < NR_PKEYS);
 285        last_si_pkey = si_pkey;
 286
 287        if ((si->si_code == SEGV_MAPERR) ||
 288            (si->si_code == SEGV_ACCERR) ||
 289            (si->si_code == SEGV_BNDERR)) {
 290                printf("non-PK si_code, exiting...\n");
 291                exit(4);
 292        }
 293
 294        dprintf1("signal pkru from xsave: %08x\n", *pkru_ptr);
 295        /* need __rdpkru() version so we do not do shadow_pkru checking */
 296        dprintf1("signal pkru from  pkru: %08x\n", __rdpkru());
 297        dprintf1("si_pkey from siginfo: %jx\n", si_pkey);
 298        *(u64 *)pkru_ptr = 0x00000000;
 299        dprintf1("WARNING: set PRKU=0 to allow faulting instruction to continue\n");
 300        pkru_faults++;
 301        dprintf1("<<<<==================================================\n");
 302        return;
 303        if (trapno == 14) {
 304                fprintf(stderr,
 305                        "ERROR: In signal handler, page fault, trapno = %d, ip = %016lx\n",
 306                        trapno, ip);
 307                fprintf(stderr, "si_addr %p\n", si->si_addr);
 308                fprintf(stderr, "REG_ERR: %lx\n",
 309                                (unsigned long)uctxt->uc_mcontext.gregs[REG_ERR]);
 310                exit(1);
 311        } else {
 312                fprintf(stderr, "unexpected trap %d! at 0x%lx\n", trapno, ip);
 313                fprintf(stderr, "si_addr %p\n", si->si_addr);
 314                fprintf(stderr, "REG_ERR: %lx\n",
 315                                (unsigned long)uctxt->uc_mcontext.gregs[REG_ERR]);
 316                exit(2);
 317        }
 318        dprint_in_signal = 0;
 319}
 320
 321int wait_all_children(void)
 322{
 323        int status;
 324        return waitpid(-1, &status, 0);
 325}
 326
 327void sig_chld(int x)
 328{
 329        dprint_in_signal = 1;
 330        dprintf2("[%d] SIGCHLD: %d\n", getpid(), x);
 331        dprint_in_signal = 0;
 332}
 333
 334void setup_sigsegv_handler(void)
 335{
 336        int r, rs;
 337        struct sigaction newact;
 338        struct sigaction oldact;
 339
 340        /* #PF is mapped to sigsegv */
 341        int signum  = SIGSEGV;
 342
 343        newact.sa_handler = 0;
 344        newact.sa_sigaction = signal_handler;
 345
 346        /*sigset_t - signals to block while in the handler */
 347        /* get the old signal mask. */
 348        rs = sigprocmask(SIG_SETMASK, 0, &newact.sa_mask);
 349        pkey_assert(rs == 0);
 350
 351        /* call sa_sigaction, not sa_handler*/
 352        newact.sa_flags = SA_SIGINFO;
 353
 354        newact.sa_restorer = 0;  /* void(*)(), obsolete */
 355        r = sigaction(signum, &newact, &oldact);
 356        r = sigaction(SIGALRM, &newact, &oldact);
 357        pkey_assert(r == 0);
 358}
 359
 360void setup_handlers(void)
 361{
 362        signal(SIGCHLD, &sig_chld);
 363        setup_sigsegv_handler();
 364}
 365
 366pid_t fork_lazy_child(void)
 367{
 368        pid_t forkret;
 369
 370        forkret = fork();
 371        pkey_assert(forkret >= 0);
 372        dprintf3("[%d] fork() ret: %d\n", getpid(), forkret);
 373
 374        if (!forkret) {
 375                /* in the child */
 376                while (1) {
 377                        dprintf1("child sleeping...\n");
 378                        sleep(30);
 379                }
 380        }
 381        return forkret;
 382}
 383
 384void davecmp(void *_a, void *_b, int len)
 385{
 386        int i;
 387        unsigned long *a = _a;
 388        unsigned long *b = _b;
 389
 390        for (i = 0; i < len / sizeof(*a); i++) {
 391                if (a[i] == b[i])
 392                        continue;
 393
 394                dprintf3("[%3d]: a: %016lx b: %016lx\n", i, a[i], b[i]);
 395        }
 396}
 397
 398void dumpit(char *f)
 399{
 400        int fd = open(f, O_RDONLY);
 401        char buf[100];
 402        int nr_read;
 403
 404        dprintf2("maps fd: %d\n", fd);
 405        do {
 406                nr_read = read(fd, &buf[0], sizeof(buf));
 407                write(1, buf, nr_read);
 408        } while (nr_read > 0);
 409        close(fd);
 410}
 411
 412#define PKEY_DISABLE_ACCESS    0x1
 413#define PKEY_DISABLE_WRITE     0x2
 414
 415u32 pkey_get(int pkey, unsigned long flags)
 416{
 417        u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
 418        u32 pkru = __rdpkru();
 419        u32 shifted_pkru;
 420        u32 masked_pkru;
 421
 422        dprintf1("%s(pkey=%d, flags=%lx) = %x / %d\n",
 423                        __func__, pkey, flags, 0, 0);
 424        dprintf2("%s() raw pkru: %x\n", __func__, pkru);
 425
 426        shifted_pkru = (pkru >> (pkey * PKRU_BITS_PER_PKEY));
 427        dprintf2("%s() shifted_pkru: %x\n", __func__, shifted_pkru);
 428        masked_pkru = shifted_pkru & mask;
 429        dprintf2("%s() masked  pkru: %x\n", __func__, masked_pkru);
 430        /*
 431         * shift down the relevant bits to the lowest two, then
 432         * mask off all the other high bits.
 433         */
 434        return masked_pkru;
 435}
 436
 437int pkey_set(int pkey, unsigned long rights, unsigned long flags)
 438{
 439        u32 mask = (PKEY_DISABLE_ACCESS|PKEY_DISABLE_WRITE);
 440        u32 old_pkru = __rdpkru();
 441        u32 new_pkru;
 442
 443        /* make sure that 'rights' only contains the bits we expect: */
 444        assert(!(rights & ~mask));
 445
 446        /* copy old pkru */
 447        new_pkru = old_pkru;
 448        /* mask out bits from pkey in old value: */
 449        new_pkru &= ~(mask << (pkey * PKRU_BITS_PER_PKEY));
 450        /* OR in new bits for pkey: */
 451        new_pkru |= (rights << (pkey * PKRU_BITS_PER_PKEY));
 452
 453        __wrpkru(new_pkru);
 454
 455        dprintf3("%s(pkey=%d, rights=%lx, flags=%lx) = %x pkru now: %x old_pkru: %x\n",
 456                        __func__, pkey, rights, flags, 0, __rdpkru(), old_pkru);
 457        return 0;
 458}
 459
 460void pkey_disable_set(int pkey, int flags)
 461{
 462        unsigned long syscall_flags = 0;
 463        int ret;
 464        int pkey_rights;
 465        u32 orig_pkru;
 466
 467        dprintf1("START->%s(%d, 0x%x)\n", __func__,
 468                pkey, flags);
 469        pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
 470
 471        pkey_rights = pkey_get(pkey, syscall_flags);
 472
 473        dprintf1("%s(%d) pkey_get(%d): %x\n", __func__,
 474                        pkey, pkey, pkey_rights);
 475        pkey_assert(pkey_rights >= 0);
 476
 477        pkey_rights |= flags;
 478
 479        ret = pkey_set(pkey, pkey_rights, syscall_flags);
 480        assert(!ret);
 481        /*pkru and flags have the same format */
 482        shadow_pkru |= flags << (pkey * 2);
 483        dprintf1("%s(%d) shadow: 0x%x\n", __func__, pkey, shadow_pkru);
 484
 485        pkey_assert(ret >= 0);
 486
 487        pkey_rights = pkey_get(pkey, syscall_flags);
 488        dprintf1("%s(%d) pkey_get(%d): %x\n", __func__,
 489                        pkey, pkey, pkey_rights);
 490
 491        dprintf1("%s(%d) pkru: 0x%x\n", __func__, pkey, rdpkru());
 492        if (flags)
 493                pkey_assert(rdpkru() > orig_pkru);
 494        dprintf1("END<---%s(%d, 0x%x)\n", __func__,
 495                pkey, flags);
 496}
 497
 498void pkey_disable_clear(int pkey, int flags)
 499{
 500        unsigned long syscall_flags = 0;
 501        int ret;
 502        int pkey_rights = pkey_get(pkey, syscall_flags);
 503        u32 orig_pkru = rdpkru();
 504
 505        pkey_assert(flags & (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE));
 506
 507        dprintf1("%s(%d) pkey_get(%d): %x\n", __func__,
 508                        pkey, pkey, pkey_rights);
 509        pkey_assert(pkey_rights >= 0);
 510
 511        pkey_rights |= flags;
 512
 513        ret = pkey_set(pkey, pkey_rights, 0);
 514        /* pkru and flags have the same format */
 515        shadow_pkru &= ~(flags << (pkey * 2));
 516        pkey_assert(ret >= 0);
 517
 518        pkey_rights = pkey_get(pkey, syscall_flags);
 519        dprintf1("%s(%d) pkey_get(%d): %x\n", __func__,
 520                        pkey, pkey, pkey_rights);
 521
 522        dprintf1("%s(%d) pkru: 0x%x\n", __func__, pkey, rdpkru());
 523        if (flags)
 524                assert(rdpkru() > orig_pkru);
 525}
 526
 527void pkey_write_allow(int pkey)
 528{
 529        pkey_disable_clear(pkey, PKEY_DISABLE_WRITE);
 530}
 531void pkey_write_deny(int pkey)
 532{
 533        pkey_disable_set(pkey, PKEY_DISABLE_WRITE);
 534}
 535void pkey_access_allow(int pkey)
 536{
 537        pkey_disable_clear(pkey, PKEY_DISABLE_ACCESS);
 538}
 539void pkey_access_deny(int pkey)
 540{
 541        pkey_disable_set(pkey, PKEY_DISABLE_ACCESS);
 542}
 543
 544int sys_mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
 545                unsigned long pkey)
 546{
 547        int sret;
 548
 549        dprintf2("%s(0x%p, %zx, prot=%lx, pkey=%lx)\n", __func__,
 550                        ptr, size, orig_prot, pkey);
 551
 552        errno = 0;
 553        sret = syscall(SYS_mprotect_key, ptr, size, orig_prot, pkey);
 554        if (errno) {
 555                dprintf2("SYS_mprotect_key sret: %d\n", sret);
 556                dprintf2("SYS_mprotect_key prot: 0x%lx\n", orig_prot);
 557                dprintf2("SYS_mprotect_key failed, errno: %d\n", errno);
 558                if (DEBUG_LEVEL >= 2)
 559                        perror("SYS_mprotect_pkey");
 560        }
 561        return sret;
 562}
 563
 564int sys_pkey_alloc(unsigned long flags, unsigned long init_val)
 565{
 566        int ret = syscall(SYS_pkey_alloc, flags, init_val);
 567        dprintf1("%s(flags=%lx, init_val=%lx) syscall ret: %d errno: %d\n",
 568                        __func__, flags, init_val, ret, errno);
 569        return ret;
 570}
 571
 572int alloc_pkey(void)
 573{
 574        int ret;
 575        unsigned long init_val = 0x0;
 576
 577        dprintf1("alloc_pkey()::%d, pkru: 0x%x shadow: %x\n",
 578                        __LINE__, __rdpkru(), shadow_pkru);
 579        ret = sys_pkey_alloc(0, init_val);
 580        /*
 581         * pkey_alloc() sets PKRU, so we need to reflect it in
 582         * shadow_pkru:
 583         */
 584        dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
 585                        __LINE__, ret, __rdpkru(), shadow_pkru);
 586        if (ret) {
 587                /* clear both the bits: */
 588                shadow_pkru &= ~(0x3      << (ret * 2));
 589                dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
 590                                __LINE__, ret, __rdpkru(), shadow_pkru);
 591                /*
 592                 * move the new state in from init_val
 593                 * (remember, we cheated and init_val == pkru format)
 594                 */
 595                shadow_pkru |=  (init_val << (ret * 2));
 596        }
 597        dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
 598                        __LINE__, ret, __rdpkru(), shadow_pkru);
 599        dprintf1("alloc_pkey()::%d errno: %d\n", __LINE__, errno);
 600        /* for shadow checking: */
 601        rdpkru();
 602        dprintf4("alloc_pkey()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n",
 603                        __LINE__, ret, __rdpkru(), shadow_pkru);
 604        return ret;
 605}
 606
 607int sys_pkey_free(unsigned long pkey)
 608{
 609        int ret = syscall(SYS_pkey_free, pkey);
 610        dprintf1("%s(pkey=%ld) syscall ret: %d\n", __func__, pkey, ret);
 611        return ret;
 612}
 613
 614/*
 615 * I had a bug where pkey bits could be set by mprotect() but
 616 * not cleared.  This ensures we get lots of random bit sets
 617 * and clears on the vma and pte pkey bits.
 618 */
 619int alloc_random_pkey(void)
 620{
 621        int max_nr_pkey_allocs;
 622        int ret;
 623        int i;
 624        int alloced_pkeys[NR_PKEYS];
 625        int nr_alloced = 0;
 626        int random_index;
 627        memset(alloced_pkeys, 0, sizeof(alloced_pkeys));
 628
 629        /* allocate every possible key and make a note of which ones we got */
 630        max_nr_pkey_allocs = NR_PKEYS;
 631        max_nr_pkey_allocs = 1;
 632        for (i = 0; i < max_nr_pkey_allocs; i++) {
 633                int new_pkey = alloc_pkey();
 634                if (new_pkey < 0)
 635                        break;
 636                alloced_pkeys[nr_alloced++] = new_pkey;
 637        }
 638
 639        pkey_assert(nr_alloced > 0);
 640        /* select a random one out of the allocated ones */
 641        random_index = rand() % nr_alloced;
 642        ret = alloced_pkeys[random_index];
 643        /* now zero it out so we don't free it next */
 644        alloced_pkeys[random_index] = 0;
 645
 646        /* go through the allocated ones that we did not want and free them */
 647        for (i = 0; i < nr_alloced; i++) {
 648                int free_ret;
 649                if (!alloced_pkeys[i])
 650                        continue;
 651                free_ret = sys_pkey_free(alloced_pkeys[i]);
 652                pkey_assert(!free_ret);
 653        }
 654        dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
 655                        __LINE__, ret, __rdpkru(), shadow_pkru);
 656        return ret;
 657}
 658
 659int mprotect_pkey(void *ptr, size_t size, unsigned long orig_prot,
 660                unsigned long pkey)
 661{
 662        int nr_iterations = random() % 100;
 663        int ret;
 664
 665        while (0) {
 666                int rpkey = alloc_random_pkey();
 667                ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
 668                dprintf1("sys_mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
 669                                ptr, size, orig_prot, pkey, ret);
 670                if (nr_iterations-- < 0)
 671                        break;
 672
 673                dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
 674                        __LINE__, ret, __rdpkru(), shadow_pkru);
 675                sys_pkey_free(rpkey);
 676                dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
 677                        __LINE__, ret, __rdpkru(), shadow_pkru);
 678        }
 679        pkey_assert(pkey < NR_PKEYS);
 680
 681        ret = sys_mprotect_pkey(ptr, size, orig_prot, pkey);
 682        dprintf1("mprotect_pkey(%p, %zx, prot=0x%lx, pkey=%ld) ret: %d\n",
 683                        ptr, size, orig_prot, pkey, ret);
 684        pkey_assert(!ret);
 685        dprintf1("%s()::%d, ret: %d pkru: 0x%x shadow: 0x%x\n", __func__,
 686                        __LINE__, ret, __rdpkru(), shadow_pkru);
 687        return ret;
 688}
 689
 690struct pkey_malloc_record {
 691        void *ptr;
 692        long size;
 693};
 694struct pkey_malloc_record *pkey_malloc_records;
 695long nr_pkey_malloc_records;
 696void record_pkey_malloc(void *ptr, long size)
 697{
 698        long i;
 699        struct pkey_malloc_record *rec = NULL;
 700
 701        for (i = 0; i < nr_pkey_malloc_records; i++) {
 702                rec = &pkey_malloc_records[i];
 703                /* find a free record */
 704                if (rec)
 705                        break;
 706        }
 707        if (!rec) {
 708                /* every record is full */
 709                size_t old_nr_records = nr_pkey_malloc_records;
 710                size_t new_nr_records = (nr_pkey_malloc_records * 2 + 1);
 711                size_t new_size = new_nr_records * sizeof(struct pkey_malloc_record);
 712                dprintf2("new_nr_records: %zd\n", new_nr_records);
 713                dprintf2("new_size: %zd\n", new_size);
 714                pkey_malloc_records = realloc(pkey_malloc_records, new_size);
 715                pkey_assert(pkey_malloc_records != NULL);
 716                rec = &pkey_malloc_records[nr_pkey_malloc_records];
 717                /*
 718                 * realloc() does not initialize memory, so zero it from
 719                 * the first new record all the way to the end.
 720                 */
 721                for (i = 0; i < new_nr_records - old_nr_records; i++)
 722                        memset(rec + i, 0, sizeof(*rec));
 723        }
 724        dprintf3("filling malloc record[%d/%p]: {%p, %ld}\n",
 725                (int)(rec - pkey_malloc_records), rec, ptr, size);
 726        rec->ptr = ptr;
 727        rec->size = size;
 728        nr_pkey_malloc_records++;
 729}
 730
 731void free_pkey_malloc(void *ptr)
 732{
 733        long i;
 734        int ret;
 735        dprintf3("%s(%p)\n", __func__, ptr);
 736        for (i = 0; i < nr_pkey_malloc_records; i++) {
 737                struct pkey_malloc_record *rec = &pkey_malloc_records[i];
 738                dprintf4("looking for ptr %p at record[%ld/%p]: {%p, %ld}\n",
 739                                ptr, i, rec, rec->ptr, rec->size);
 740                if ((ptr <  rec->ptr) ||
 741                    (ptr >= rec->ptr + rec->size))
 742                        continue;
 743
 744                dprintf3("found ptr %p at record[%ld/%p]: {%p, %ld}\n",
 745                                ptr, i, rec, rec->ptr, rec->size);
 746                nr_pkey_malloc_records--;
 747                ret = munmap(rec->ptr, rec->size);
 748                dprintf3("munmap ret: %d\n", ret);
 749                pkey_assert(!ret);
 750                dprintf3("clearing rec->ptr, rec: %p\n", rec);
 751                rec->ptr = NULL;
 752                dprintf3("done clearing rec->ptr, rec: %p\n", rec);
 753                return;
 754        }
 755        pkey_assert(false);
 756}
 757
 758
 759void *malloc_pkey_with_mprotect(long size, int prot, u16 pkey)
 760{
 761        void *ptr;
 762        int ret;
 763
 764        rdpkru();
 765        dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
 766                        size, prot, pkey);
 767        pkey_assert(pkey < NR_PKEYS);
 768        ptr = mmap(NULL, size, prot, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
 769        pkey_assert(ptr != (void *)-1);
 770        ret = mprotect_pkey((void *)ptr, PAGE_SIZE, prot, pkey);
 771        pkey_assert(!ret);
 772        record_pkey_malloc(ptr, size);
 773        rdpkru();
 774
 775        dprintf1("%s() for pkey %d @ %p\n", __func__, pkey, ptr);
 776        return ptr;
 777}
 778
 779void *malloc_pkey_anon_huge(long size, int prot, u16 pkey)
 780{
 781        int ret;
 782        void *ptr;
 783
 784        dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
 785                        size, prot, pkey);
 786        /*
 787         * Guarantee we can fit at least one huge page in the resulting
 788         * allocation by allocating space for 2:
 789         */
 790        size = ALIGN_UP(size, HPAGE_SIZE * 2);
 791        ptr = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
 792        pkey_assert(ptr != (void *)-1);
 793        record_pkey_malloc(ptr, size);
 794        mprotect_pkey(ptr, size, prot, pkey);
 795
 796        dprintf1("unaligned ptr: %p\n", ptr);
 797        ptr = ALIGN_PTR_UP(ptr, HPAGE_SIZE);
 798        dprintf1("  aligned ptr: %p\n", ptr);
 799        ret = madvise(ptr, HPAGE_SIZE, MADV_HUGEPAGE);
 800        dprintf1("MADV_HUGEPAGE ret: %d\n", ret);
 801        ret = madvise(ptr, HPAGE_SIZE, MADV_WILLNEED);
 802        dprintf1("MADV_WILLNEED ret: %d\n", ret);
 803        memset(ptr, 0, HPAGE_SIZE);
 804
 805        dprintf1("mmap()'d thp for pkey %d @ %p\n", pkey, ptr);
 806        return ptr;
 807}
 808
 809int hugetlb_setup_ok;
 810#define GET_NR_HUGE_PAGES 10
 811void setup_hugetlbfs(void)
 812{
 813        int err;
 814        int fd;
 815        int validated_nr_pages;
 816        int i;
 817        char buf[] = "123";
 818
 819        if (geteuid() != 0) {
 820                fprintf(stderr, "WARNING: not run as root, can not do hugetlb test\n");
 821                return;
 822        }
 823
 824        cat_into_file(__stringify(GET_NR_HUGE_PAGES), "/proc/sys/vm/nr_hugepages");
 825
 826        /*
 827         * Now go make sure that we got the pages and that they
 828         * are 2M pages.  Someone might have made 1G the default.
 829         */
 830        fd = open("/sys/kernel/mm/hugepages/hugepages-2048kB/nr_hugepages", O_RDONLY);
 831        if (fd < 0) {
 832                perror("opening sysfs 2M hugetlb config");
 833                return;
 834        }
 835
 836        /* -1 to guarantee leaving the trailing \0 */
 837        err = read(fd, buf, sizeof(buf)-1);
 838        close(fd);
 839        if (err <= 0) {
 840                perror("reading sysfs 2M hugetlb config");
 841                return;
 842        }
 843
 844        if (atoi(buf) != GET_NR_HUGE_PAGES) {
 845                fprintf(stderr, "could not confirm 2M pages, got: '%s' expected %d\n",
 846                        buf, GET_NR_HUGE_PAGES);
 847                return;
 848        }
 849
 850        hugetlb_setup_ok = 1;
 851}
 852
 853void *malloc_pkey_hugetlb(long size, int prot, u16 pkey)
 854{
 855        void *ptr;
 856        int flags = MAP_ANONYMOUS|MAP_PRIVATE|MAP_HUGETLB;
 857
 858        if (!hugetlb_setup_ok)
 859                return PTR_ERR_ENOTSUP;
 860
 861        dprintf1("doing %s(%ld, %x, %x)\n", __func__, size, prot, pkey);
 862        size = ALIGN_UP(size, HPAGE_SIZE * 2);
 863        pkey_assert(pkey < NR_PKEYS);
 864        ptr = mmap(NULL, size, PROT_NONE, flags, -1, 0);
 865        pkey_assert(ptr != (void *)-1);
 866        mprotect_pkey(ptr, size, prot, pkey);
 867
 868        record_pkey_malloc(ptr, size);
 869
 870        dprintf1("mmap()'d hugetlbfs for pkey %d @ %p\n", pkey, ptr);
 871        return ptr;
 872}
 873
 874void *malloc_pkey_mmap_dax(long size, int prot, u16 pkey)
 875{
 876        void *ptr;
 877        int fd;
 878
 879        dprintf1("doing %s(size=%ld, prot=0x%x, pkey=%d)\n", __func__,
 880                        size, prot, pkey);
 881        pkey_assert(pkey < NR_PKEYS);
 882        fd = open("/dax/foo", O_RDWR);
 883        pkey_assert(fd >= 0);
 884
 885        ptr = mmap(0, size, prot, MAP_SHARED, fd, 0);
 886        pkey_assert(ptr != (void *)-1);
 887
 888        mprotect_pkey(ptr, size, prot, pkey);
 889
 890        record_pkey_malloc(ptr, size);
 891
 892        dprintf1("mmap()'d for pkey %d @ %p\n", pkey, ptr);
 893        close(fd);
 894        return ptr;
 895}
 896
 897void *(*pkey_malloc[])(long size, int prot, u16 pkey) = {
 898
 899        malloc_pkey_with_mprotect,
 900        malloc_pkey_anon_huge,
 901        malloc_pkey_hugetlb
 902/* can not do direct with the pkey_mprotect() API:
 903        malloc_pkey_mmap_direct,
 904        malloc_pkey_mmap_dax,
 905*/
 906};
 907
 908void *malloc_pkey(long size, int prot, u16 pkey)
 909{
 910        void *ret;
 911        static int malloc_type;
 912        int nr_malloc_types = ARRAY_SIZE(pkey_malloc);
 913
 914        pkey_assert(pkey < NR_PKEYS);
 915
 916        while (1) {
 917                pkey_assert(malloc_type < nr_malloc_types);
 918
 919                ret = pkey_malloc[malloc_type](size, prot, pkey);
 920                pkey_assert(ret != (void *)-1);
 921
 922                malloc_type++;
 923                if (malloc_type >= nr_malloc_types)
 924                        malloc_type = (random()%nr_malloc_types);
 925
 926                /* try again if the malloc_type we tried is unsupported */
 927                if (ret == PTR_ERR_ENOTSUP)
 928                        continue;
 929
 930                break;
 931        }
 932
 933        dprintf3("%s(%ld, prot=%x, pkey=%x) returning: %p\n", __func__,
 934                        size, prot, pkey, ret);
 935        return ret;
 936}
 937
 938int last_pkru_faults;
 939void expected_pk_fault(int pkey)
 940{
 941        dprintf2("%s(): last_pkru_faults: %d pkru_faults: %d\n",
 942                        __func__, last_pkru_faults, pkru_faults);
 943        dprintf2("%s(%d): last_si_pkey: %d\n", __func__, pkey, last_si_pkey);
 944        pkey_assert(last_pkru_faults + 1 == pkru_faults);
 945        pkey_assert(last_si_pkey == pkey);
 946        /*
 947         * The signal handler shold have cleared out PKRU to let the
 948         * test program continue.  We now have to restore it.
 949         */
 950        if (__rdpkru() != 0)
 951                pkey_assert(0);
 952
 953        __wrpkru(shadow_pkru);
 954        dprintf1("%s() set PKRU=%x to restore state after signal nuked it\n",
 955                        __func__, shadow_pkru);
 956        last_pkru_faults = pkru_faults;
 957        last_si_pkey = -1;
 958}
 959
 960void do_not_expect_pk_fault(void)
 961{
 962        pkey_assert(last_pkru_faults == pkru_faults);
 963}
 964
 965int test_fds[10] = { -1 };
 966int nr_test_fds;
 967void __save_test_fd(int fd)
 968{
 969        pkey_assert(fd >= 0);
 970        pkey_assert(nr_test_fds < ARRAY_SIZE(test_fds));
 971        test_fds[nr_test_fds] = fd;
 972        nr_test_fds++;
 973}
 974
 975int get_test_read_fd(void)
 976{
 977        int test_fd = open("/etc/passwd", O_RDONLY);
 978        __save_test_fd(test_fd);
 979        return test_fd;
 980}
 981
 982void close_test_fds(void)
 983{
 984        int i;
 985
 986        for (i = 0; i < nr_test_fds; i++) {
 987                if (test_fds[i] < 0)
 988                        continue;
 989                close(test_fds[i]);
 990                test_fds[i] = -1;
 991        }
 992        nr_test_fds = 0;
 993}
 994
 995#define barrier() __asm__ __volatile__("": : :"memory")
 996__attribute__((noinline)) int read_ptr(int *ptr)
 997{
 998        /*
 999         * Keep GCC from optimizing this away somehow
1000         */
1001        barrier();
1002        return *ptr;
1003}
1004
1005void test_read_of_write_disabled_region(int *ptr, u16 pkey)
1006{
1007        int ptr_contents;
1008
1009        dprintf1("disabling write access to PKEY[1], doing read\n");
1010        pkey_write_deny(pkey);
1011        ptr_contents = read_ptr(ptr);
1012        dprintf1("*ptr: %d\n", ptr_contents);
1013        dprintf1("\n");
1014}
1015void test_read_of_access_disabled_region(int *ptr, u16 pkey)
1016{
1017        int ptr_contents;
1018
1019        dprintf1("disabling access to PKEY[%02d], doing read @ %p\n", pkey, ptr);
1020        rdpkru();
1021        pkey_access_deny(pkey);
1022        ptr_contents = read_ptr(ptr);
1023        dprintf1("*ptr: %d\n", ptr_contents);
1024        expected_pk_fault(pkey);
1025}
1026void test_write_of_write_disabled_region(int *ptr, u16 pkey)
1027{
1028        dprintf1("disabling write access to PKEY[%02d], doing write\n", pkey);
1029        pkey_write_deny(pkey);
1030        *ptr = __LINE__;
1031        expected_pk_fault(pkey);
1032}
1033void test_write_of_access_disabled_region(int *ptr, u16 pkey)
1034{
1035        dprintf1("disabling access to PKEY[%02d], doing write\n", pkey);
1036        pkey_access_deny(pkey);
1037        *ptr = __LINE__;
1038        expected_pk_fault(pkey);
1039}
1040void test_kernel_write_of_access_disabled_region(int *ptr, u16 pkey)
1041{
1042        int ret;
1043        int test_fd = get_test_read_fd();
1044
1045        dprintf1("disabling access to PKEY[%02d], "
1046                 "having kernel read() to buffer\n", pkey);
1047        pkey_access_deny(pkey);
1048        ret = read(test_fd, ptr, 1);
1049        dprintf1("read ret: %d\n", ret);
1050        pkey_assert(ret);
1051}
1052void test_kernel_write_of_write_disabled_region(int *ptr, u16 pkey)
1053{
1054        int ret;
1055        int test_fd = get_test_read_fd();
1056
1057        pkey_write_deny(pkey);
1058        ret = read(test_fd, ptr, 100);
1059        dprintf1("read ret: %d\n", ret);
1060        if (ret < 0 && (DEBUG_LEVEL > 0))
1061                perror("verbose read result (OK for this to be bad)");
1062        pkey_assert(ret);
1063}
1064
1065void test_kernel_gup_of_access_disabled_region(int *ptr, u16 pkey)
1066{
1067        int pipe_ret, vmsplice_ret;
1068        struct iovec iov;
1069        int pipe_fds[2];
1070
1071        pipe_ret = pipe(pipe_fds);
1072
1073        pkey_assert(pipe_ret == 0);
1074        dprintf1("disabling access to PKEY[%02d], "
1075                 "having kernel vmsplice from buffer\n", pkey);
1076        pkey_access_deny(pkey);
1077        iov.iov_base = ptr;
1078        iov.iov_len = PAGE_SIZE;
1079        vmsplice_ret = vmsplice(pipe_fds[1], &iov, 1, SPLICE_F_GIFT);
1080        dprintf1("vmsplice() ret: %d\n", vmsplice_ret);
1081        pkey_assert(vmsplice_ret == -1);
1082
1083        close(pipe_fds[0]);
1084        close(pipe_fds[1]);
1085}
1086
1087void test_kernel_gup_write_to_write_disabled_region(int *ptr, u16 pkey)
1088{
1089        int ignored = 0xdada;
1090        int futex_ret;
1091        int some_int = __LINE__;
1092
1093        dprintf1("disabling write to PKEY[%02d], "
1094                 "doing futex gunk in buffer\n", pkey);
1095        *ptr = some_int;
1096        pkey_write_deny(pkey);
1097        futex_ret = syscall(SYS_futex, ptr, FUTEX_WAIT, some_int-1, NULL,
1098                        &ignored, ignored);
1099        if (DEBUG_LEVEL > 0)
1100                perror("futex");
1101        dprintf1("futex() ret: %d\n", futex_ret);
1102}
1103
1104/* Assumes that all pkeys other than 'pkey' are unallocated */
1105void test_pkey_syscalls_on_non_allocated_pkey(int *ptr, u16 pkey)
1106{
1107        int err;
1108        int i;
1109
1110        /* Note: 0 is the default pkey, so don't mess with it */
1111        for (i = 1; i < NR_PKEYS; i++) {
1112                if (pkey == i)
1113                        continue;
1114
1115                dprintf1("trying get/set/free to non-allocated pkey: %2d\n", i);
1116                err = sys_pkey_free(i);
1117                pkey_assert(err);
1118
1119                /* not enforced when pkey_get() is not a syscall
1120                err = pkey_get(i, 0);
1121                pkey_assert(err < 0);
1122                */
1123
1124                err = sys_pkey_free(i);
1125                pkey_assert(err);
1126
1127                err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, i);
1128                pkey_assert(err);
1129        }
1130}
1131
1132/* Assumes that all pkeys other than 'pkey' are unallocated */
1133void test_pkey_syscalls_bad_args(int *ptr, u16 pkey)
1134{
1135        int err;
1136        int bad_flag = (PKEY_DISABLE_ACCESS | PKEY_DISABLE_WRITE) + 1;
1137        int bad_pkey = NR_PKEYS+99;
1138
1139        /* not enforced when pkey_get() is not a syscall
1140        err = pkey_get(bad_pkey, bad_flag);
1141        pkey_assert(err < 0);
1142        */
1143
1144        /* pass a known-invalid pkey in: */
1145        err = sys_mprotect_pkey(ptr, PAGE_SIZE, PROT_READ, bad_pkey);
1146        pkey_assert(err);
1147}
1148
1149/* Assumes that all pkeys other than 'pkey' are unallocated */
1150void test_pkey_alloc_exhaust(int *ptr, u16 pkey)
1151{
1152        unsigned long flags;
1153        unsigned long init_val;
1154        int err;
1155        int allocated_pkeys[NR_PKEYS] = {0};
1156        int nr_allocated_pkeys = 0;
1157        int i;
1158
1159        for (i = 0; i < NR_PKEYS*2; i++) {
1160                int new_pkey;
1161                dprintf1("%s() alloc loop: %d\n", __func__, i);
1162                new_pkey = alloc_pkey();
1163                dprintf4("%s()::%d, err: %d pkru: 0x%x shadow: 0x%x\n", __func__,
1164                                __LINE__, err, __rdpkru(), shadow_pkru);
1165                rdpkru(); /* for shadow checking */
1166                dprintf2("%s() errno: %d ENOSPC: %d\n", __func__, errno, ENOSPC);
1167                if ((new_pkey == -1) && (errno == ENOSPC)) {
1168                        dprintf2("%s() failed to allocate pkey after %d tries\n",
1169                                __func__, nr_allocated_pkeys);
1170                        break;
1171                }
1172                pkey_assert(nr_allocated_pkeys < NR_PKEYS);
1173                allocated_pkeys[nr_allocated_pkeys++] = new_pkey;
1174        }
1175
1176        dprintf3("%s()::%d\n", __func__, __LINE__);
1177
1178        /*
1179         * ensure it did not reach the end of the loop without
1180         * failure:
1181         */
1182        pkey_assert(i < NR_PKEYS*2);
1183
1184        /*
1185         * There are 16 pkeys supported in hardware.  One is taken
1186         * up for the default (0) and another can be taken up by
1187         * an execute-only mapping.  Ensure that we can allocate
1188         * at least 14 (16-2).
1189         */
1190        pkey_assert(i >= NR_PKEYS-2);
1191
1192        for (i = 0; i < nr_allocated_pkeys; i++) {
1193                err = sys_pkey_free(allocated_pkeys[i]);
1194                pkey_assert(!err);
1195                rdpkru(); /* for shadow checking */
1196        }
1197}
1198
1199void test_ptrace_of_child(int *ptr, u16 pkey)
1200{
1201        __attribute__((__unused__)) int peek_result;
1202        pid_t child_pid;
1203        void *ignored = 0;
1204        long ret;
1205        int status;
1206        /*
1207         * This is the "control" for our little expermient.  Make sure
1208         * we can always access it when ptracing.
1209         */
1210        int *plain_ptr_unaligned = malloc(HPAGE_SIZE);
1211        int *plain_ptr = ALIGN_PTR_UP(plain_ptr_unaligned, PAGE_SIZE);
1212
1213        /*
1214         * Fork a child which is an exact copy of this process, of course.
1215         * That means we can do all of our tests via ptrace() and then plain
1216         * memory access and ensure they work differently.
1217         */
1218        child_pid = fork_lazy_child();
1219        dprintf1("[%d] child pid: %d\n", getpid(), child_pid);
1220
1221        ret = ptrace(PTRACE_ATTACH, child_pid, ignored, ignored);
1222        if (ret)
1223                perror("attach");
1224        dprintf1("[%d] attach ret: %ld %d\n", getpid(), ret, __LINE__);
1225        pkey_assert(ret != -1);
1226        ret = waitpid(child_pid, &status, WUNTRACED);
1227        if ((ret != child_pid) || !(WIFSTOPPED(status))) {
1228                fprintf(stderr, "weird waitpid result %ld stat %x\n",
1229                                ret, status);
1230                pkey_assert(0);
1231        }
1232        dprintf2("waitpid ret: %ld\n", ret);
1233        dprintf2("waitpid status: %d\n", status);
1234
1235        pkey_access_deny(pkey);
1236        pkey_write_deny(pkey);
1237
1238        /* Write access, untested for now:
1239        ret = ptrace(PTRACE_POKEDATA, child_pid, peek_at, data);
1240        pkey_assert(ret != -1);
1241        dprintf1("poke at %p: %ld\n", peek_at, ret);
1242        */
1243
1244        /*
1245         * Try to access the pkey-protected "ptr" via ptrace:
1246         */
1247        ret = ptrace(PTRACE_PEEKDATA, child_pid, ptr, ignored);
1248        /* expect it to work, without an error: */
1249        pkey_assert(ret != -1);
1250        /* Now access from the current task, and expect an exception: */
1251        peek_result = read_ptr(ptr);
1252        expected_pk_fault(pkey);
1253
1254        /*
1255         * Try to access the NON-pkey-protected "plain_ptr" via ptrace:
1256         */
1257        ret = ptrace(PTRACE_PEEKDATA, child_pid, plain_ptr, ignored);
1258        /* expect it to work, without an error: */
1259        pkey_assert(ret != -1);
1260        /* Now access from the current task, and expect NO exception: */
1261        peek_result = read_ptr(plain_ptr);
1262        do_not_expect_pk_fault();
1263
1264        ret = ptrace(PTRACE_DETACH, child_pid, ignored, 0);
1265        pkey_assert(ret != -1);
1266
1267        ret = kill(child_pid, SIGKILL);
1268        pkey_assert(ret != -1);
1269
1270        wait(&status);
1271
1272        free(plain_ptr_unaligned);
1273}
1274
1275void test_executing_on_unreadable_memory(int *ptr, u16 pkey)
1276{
1277        void *p1;
1278        int scratch;
1279        int ptr_contents;
1280        int ret;
1281
1282        p1 = ALIGN_PTR_UP(&lots_o_noops_around_write, PAGE_SIZE);
1283        dprintf3("&lots_o_noops: %p\n", &lots_o_noops_around_write);
1284        /* lots_o_noops_around_write should be page-aligned already */
1285        assert(p1 == &lots_o_noops_around_write);
1286
1287        /* Point 'p1' at the *second* page of the function: */
1288        p1 += PAGE_SIZE;
1289
1290        madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1291        lots_o_noops_around_write(&scratch);
1292        ptr_contents = read_ptr(p1);
1293        dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1294
1295        ret = mprotect_pkey(p1, PAGE_SIZE, PROT_EXEC, (u64)pkey);
1296        pkey_assert(!ret);
1297        pkey_access_deny(pkey);
1298
1299        dprintf2("pkru: %x\n", rdpkru());
1300
1301        /*
1302         * Make sure this is an *instruction* fault
1303         */
1304        madvise(p1, PAGE_SIZE, MADV_DONTNEED);
1305        lots_o_noops_around_write(&scratch);
1306        do_not_expect_pk_fault();
1307        ptr_contents = read_ptr(p1);
1308        dprintf2("ptr (%p) contents@%d: %x\n", p1, __LINE__, ptr_contents);
1309        expected_pk_fault(pkey);
1310}
1311
1312void test_mprotect_pkey_on_unsupported_cpu(int *ptr, u16 pkey)
1313{
1314        int size = PAGE_SIZE;
1315        int sret;
1316
1317        if (cpu_has_pku()) {
1318                dprintf1("SKIP: %s: no CPU support\n", __func__);
1319                return;
1320        }
1321
1322        sret = syscall(SYS_mprotect_key, ptr, size, PROT_READ, pkey);
1323        pkey_assert(sret < 0);
1324}
1325
1326void (*pkey_tests[])(int *ptr, u16 pkey) = {
1327        test_read_of_write_disabled_region,
1328        test_read_of_access_disabled_region,
1329        test_write_of_write_disabled_region,
1330        test_write_of_access_disabled_region,
1331        test_kernel_write_of_access_disabled_region,
1332        test_kernel_write_of_write_disabled_region,
1333        test_kernel_gup_of_access_disabled_region,
1334        test_kernel_gup_write_to_write_disabled_region,
1335        test_executing_on_unreadable_memory,
1336        test_ptrace_of_child,
1337        test_pkey_syscalls_on_non_allocated_pkey,
1338        test_pkey_syscalls_bad_args,
1339        test_pkey_alloc_exhaust,
1340};
1341
1342void run_tests_once(void)
1343{
1344        int *ptr;
1345        int prot = PROT_READ|PROT_WRITE;
1346
1347        for (test_nr = 0; test_nr < ARRAY_SIZE(pkey_tests); test_nr++) {
1348                int pkey;
1349                int orig_pkru_faults = pkru_faults;
1350
1351                dprintf1("======================\n");
1352                dprintf1("test %d preparing...\n", test_nr);
1353
1354                tracing_on();
1355                pkey = alloc_random_pkey();
1356                dprintf1("test %d starting with pkey: %d\n", test_nr, pkey);
1357                ptr = malloc_pkey(PAGE_SIZE, prot, pkey);
1358                dprintf1("test %d starting...\n", test_nr);
1359                pkey_tests[test_nr](ptr, pkey);
1360                dprintf1("freeing test memory: %p\n", ptr);
1361                free_pkey_malloc(ptr);
1362                sys_pkey_free(pkey);
1363
1364                dprintf1("pkru_faults: %d\n", pkru_faults);
1365                dprintf1("orig_pkru_faults: %d\n", orig_pkru_faults);
1366
1367                tracing_off();
1368                close_test_fds();
1369
1370                printf("test %2d PASSED (itertation %d)\n", test_nr, iteration_nr);
1371                dprintf1("======================\n\n");
1372        }
1373        iteration_nr++;
1374}
1375
1376void pkey_setup_shadow(void)
1377{
1378        shadow_pkru = __rdpkru();
1379}
1380
1381int main(void)
1382{
1383        int nr_iterations = 22;
1384
1385        setup_handlers();
1386
1387        printf("has pku: %d\n", cpu_has_pku());
1388
1389        if (!cpu_has_pku()) {
1390                int size = PAGE_SIZE;
1391                int *ptr;
1392
1393                printf("running PKEY tests for unsupported CPU/OS\n");
1394
1395                ptr  = mmap(NULL, size, PROT_NONE, MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
1396                assert(ptr != (void *)-1);
1397                test_mprotect_pkey_on_unsupported_cpu(ptr, 1);
1398                exit(0);
1399        }
1400
1401        pkey_setup_shadow();
1402        printf("startup pkru: %x\n", rdpkru());
1403        setup_hugetlbfs();
1404
1405        while (nr_iterations-- > 0)
1406                run_tests_once();
1407
1408        printf("done (all tests OK)\n");
1409        return 0;
1410}
1411