linux/arch/arm64/net/bpf_jit_comp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * BPF JIT compiler for ARM64
   4 *
   5 * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
   6 */
   7
   8#define pr_fmt(fmt) "bpf_jit: " fmt
   9
  10#include <linux/bpf.h>
  11#include <linux/filter.h>
  12#include <linux/printk.h>
  13#include <linux/slab.h>
  14
  15#include <asm/byteorder.h>
  16#include <asm/cacheflush.h>
  17#include <asm/debug-monitors.h>
  18#include <asm/set_memory.h>
  19
  20#include "bpf_jit.h"
  21
  22#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
  23#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
  24#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
  25#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
  26
  27/* Map BPF registers to A64 registers */
  28static const int bpf2a64[] = {
  29        /* return value from in-kernel function, and exit value from eBPF */
  30        [BPF_REG_0] = A64_R(7),
  31        /* arguments from eBPF program to in-kernel function */
  32        [BPF_REG_1] = A64_R(0),
  33        [BPF_REG_2] = A64_R(1),
  34        [BPF_REG_3] = A64_R(2),
  35        [BPF_REG_4] = A64_R(3),
  36        [BPF_REG_5] = A64_R(4),
  37        /* callee saved registers that in-kernel function will preserve */
  38        [BPF_REG_6] = A64_R(19),
  39        [BPF_REG_7] = A64_R(20),
  40        [BPF_REG_8] = A64_R(21),
  41        [BPF_REG_9] = A64_R(22),
  42        /* read-only frame pointer to access stack */
  43        [BPF_REG_FP] = A64_R(25),
  44        /* temporary registers for internal BPF JIT */
  45        [TMP_REG_1] = A64_R(10),
  46        [TMP_REG_2] = A64_R(11),
  47        [TMP_REG_3] = A64_R(12),
  48        /* tail_call_cnt */
  49        [TCALL_CNT] = A64_R(26),
  50        /* temporary register for blinding constants */
  51        [BPF_REG_AX] = A64_R(9),
  52};
  53
  54struct jit_ctx {
  55        const struct bpf_prog *prog;
  56        int idx;
  57        int epilogue_offset;
  58        int *offset;
  59        __le32 *image;
  60        u32 stack_size;
  61};
  62
  63static inline void emit(const u32 insn, struct jit_ctx *ctx)
  64{
  65        if (ctx->image != NULL)
  66                ctx->image[ctx->idx] = cpu_to_le32(insn);
  67
  68        ctx->idx++;
  69}
  70
  71static inline void emit_a64_mov_i(const int is64, const int reg,
  72                                  const s32 val, struct jit_ctx *ctx)
  73{
  74        u16 hi = val >> 16;
  75        u16 lo = val & 0xffff;
  76
  77        if (hi & 0x8000) {
  78                if (hi == 0xffff) {
  79                        emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
  80                } else {
  81                        emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
  82                        if (lo != 0xffff)
  83                                emit(A64_MOVK(is64, reg, lo, 0), ctx);
  84                }
  85        } else {
  86                emit(A64_MOVZ(is64, reg, lo, 0), ctx);
  87                if (hi)
  88                        emit(A64_MOVK(is64, reg, hi, 16), ctx);
  89        }
  90}
  91
  92static int i64_i16_blocks(const u64 val, bool inverse)
  93{
  94        return (((val >>  0) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
  95               (((val >> 16) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
  96               (((val >> 32) & 0xffff) != (inverse ? 0xffff : 0x0000)) +
  97               (((val >> 48) & 0xffff) != (inverse ? 0xffff : 0x0000));
  98}
  99
 100static inline void emit_a64_mov_i64(const int reg, const u64 val,
 101                                    struct jit_ctx *ctx)
 102{
 103        u64 nrm_tmp = val, rev_tmp = ~val;
 104        bool inverse;
 105        int shift;
 106
 107        if (!(nrm_tmp >> 32))
 108                return emit_a64_mov_i(0, reg, (u32)val, ctx);
 109
 110        inverse = i64_i16_blocks(nrm_tmp, true) < i64_i16_blocks(nrm_tmp, false);
 111        shift = max(round_down((inverse ? (fls64(rev_tmp) - 1) :
 112                                          (fls64(nrm_tmp) - 1)), 16), 0);
 113        if (inverse)
 114                emit(A64_MOVN(1, reg, (rev_tmp >> shift) & 0xffff, shift), ctx);
 115        else
 116                emit(A64_MOVZ(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
 117        shift -= 16;
 118        while (shift >= 0) {
 119                if (((nrm_tmp >> shift) & 0xffff) != (inverse ? 0xffff : 0x0000))
 120                        emit(A64_MOVK(1, reg, (nrm_tmp >> shift) & 0xffff, shift), ctx);
 121                shift -= 16;
 122        }
 123}
 124
 125/*
 126 * Kernel addresses in the vmalloc space use at most 48 bits, and the
 127 * remaining bits are guaranteed to be 0x1. So we can compose the address
 128 * with a fixed length movn/movk/movk sequence.
 129 */
 130static inline void emit_addr_mov_i64(const int reg, const u64 val,
 131                                     struct jit_ctx *ctx)
 132{
 133        u64 tmp = val;
 134        int shift = 0;
 135
 136        emit(A64_MOVN(1, reg, ~tmp & 0xffff, shift), ctx);
 137        while (shift < 32) {
 138                tmp >>= 16;
 139                shift += 16;
 140                emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
 141        }
 142}
 143
 144static inline int bpf2a64_offset(int bpf_to, int bpf_from,
 145                                 const struct jit_ctx *ctx)
 146{
 147        int to = ctx->offset[bpf_to];
 148        /* -1 to account for the Branch instruction */
 149        int from = ctx->offset[bpf_from] - 1;
 150
 151        return to - from;
 152}
 153
 154static void jit_fill_hole(void *area, unsigned int size)
 155{
 156        __le32 *ptr;
 157        /* We are guaranteed to have aligned memory. */
 158        for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
 159                *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
 160}
 161
 162static inline int epilogue_offset(const struct jit_ctx *ctx)
 163{
 164        int to = ctx->epilogue_offset;
 165        int from = ctx->idx;
 166
 167        return to - from;
 168}
 169
 170static bool is_addsub_imm(u32 imm)
 171{
 172        /* Either imm12 or shifted imm12. */
 173        return !(imm & ~0xfff) || !(imm & ~0xfff000);
 174}
 175
 176/* Stack must be multiples of 16B */
 177#define STACK_ALIGN(sz) (((sz) + 15) & ~15)
 178
 179/* Tail call offset to jump into */
 180#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
 181#define PROLOGUE_OFFSET 8
 182#else
 183#define PROLOGUE_OFFSET 7
 184#endif
 185
 186static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
 187{
 188        const struct bpf_prog *prog = ctx->prog;
 189        const u8 r6 = bpf2a64[BPF_REG_6];
 190        const u8 r7 = bpf2a64[BPF_REG_7];
 191        const u8 r8 = bpf2a64[BPF_REG_8];
 192        const u8 r9 = bpf2a64[BPF_REG_9];
 193        const u8 fp = bpf2a64[BPF_REG_FP];
 194        const u8 tcc = bpf2a64[TCALL_CNT];
 195        const int idx0 = ctx->idx;
 196        int cur_offset;
 197
 198        /*
 199         * BPF prog stack layout
 200         *
 201         *                         high
 202         * original A64_SP =>   0:+-----+ BPF prologue
 203         *                        |FP/LR|
 204         * current A64_FP =>  -16:+-----+
 205         *                        | ... | callee saved registers
 206         * BPF fp register => -64:+-----+ <= (BPF_FP)
 207         *                        |     |
 208         *                        | ... | BPF prog stack
 209         *                        |     |
 210         *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
 211         *                        |RSVD | padding
 212         * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
 213         *                        |     |
 214         *                        | ... | Function call stack
 215         *                        |     |
 216         *                        +-----+
 217         *                          low
 218         *
 219         */
 220
 221        /* BTI landing pad */
 222        if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
 223                emit(A64_BTI_C, ctx);
 224
 225        /* Save FP and LR registers to stay align with ARM64 AAPCS */
 226        emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
 227        emit(A64_MOV(1, A64_FP, A64_SP), ctx);
 228
 229        /* Save callee-saved registers */
 230        emit(A64_PUSH(r6, r7, A64_SP), ctx);
 231        emit(A64_PUSH(r8, r9, A64_SP), ctx);
 232        emit(A64_PUSH(fp, tcc, A64_SP), ctx);
 233
 234        /* Set up BPF prog stack base register */
 235        emit(A64_MOV(1, fp, A64_SP), ctx);
 236
 237        if (!ebpf_from_cbpf) {
 238                /* Initialize tail_call_cnt */
 239                emit(A64_MOVZ(1, tcc, 0, 0), ctx);
 240
 241                cur_offset = ctx->idx - idx0;
 242                if (cur_offset != PROLOGUE_OFFSET) {
 243                        pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
 244                                    cur_offset, PROLOGUE_OFFSET);
 245                        return -1;
 246                }
 247
 248                /* BTI landing pad for the tail call, done with a BR */
 249                if (IS_ENABLED(CONFIG_ARM64_BTI_KERNEL))
 250                        emit(A64_BTI_J, ctx);
 251        }
 252
 253        ctx->stack_size = STACK_ALIGN(prog->aux->stack_depth);
 254
 255        /* Set up function call stack */
 256        emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
 257        return 0;
 258}
 259
 260static int out_offset = -1; /* initialized on the first pass of build_body() */
 261static int emit_bpf_tail_call(struct jit_ctx *ctx)
 262{
 263        /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
 264        const u8 r2 = bpf2a64[BPF_REG_2];
 265        const u8 r3 = bpf2a64[BPF_REG_3];
 266
 267        const u8 tmp = bpf2a64[TMP_REG_1];
 268        const u8 prg = bpf2a64[TMP_REG_2];
 269        const u8 tcc = bpf2a64[TCALL_CNT];
 270        const int idx0 = ctx->idx;
 271#define cur_offset (ctx->idx - idx0)
 272#define jmp_offset (out_offset - (cur_offset))
 273        size_t off;
 274
 275        /* if (index >= array->map.max_entries)
 276         *     goto out;
 277         */
 278        off = offsetof(struct bpf_array, map.max_entries);
 279        emit_a64_mov_i64(tmp, off, ctx);
 280        emit(A64_LDR32(tmp, r2, tmp), ctx);
 281        emit(A64_MOV(0, r3, r3), ctx);
 282        emit(A64_CMP(0, r3, tmp), ctx);
 283        emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
 284
 285        /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 286         *     goto out;
 287         * tail_call_cnt++;
 288         */
 289        emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
 290        emit(A64_CMP(1, tcc, tmp), ctx);
 291        emit(A64_B_(A64_COND_HI, jmp_offset), ctx);
 292        emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
 293
 294        /* prog = array->ptrs[index];
 295         * if (prog == NULL)
 296         *     goto out;
 297         */
 298        off = offsetof(struct bpf_array, ptrs);
 299        emit_a64_mov_i64(tmp, off, ctx);
 300        emit(A64_ADD(1, tmp, r2, tmp), ctx);
 301        emit(A64_LSL(1, prg, r3, 3), ctx);
 302        emit(A64_LDR64(prg, tmp, prg), ctx);
 303        emit(A64_CBZ(1, prg, jmp_offset), ctx);
 304
 305        /* goto *(prog->bpf_func + prologue_offset); */
 306        off = offsetof(struct bpf_prog, bpf_func);
 307        emit_a64_mov_i64(tmp, off, ctx);
 308        emit(A64_LDR64(tmp, prg, tmp), ctx);
 309        emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
 310        emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
 311        emit(A64_BR(tmp), ctx);
 312
 313        /* out: */
 314        if (out_offset == -1)
 315                out_offset = cur_offset;
 316        if (cur_offset != out_offset) {
 317                pr_err_once("tail_call out_offset = %d, expected %d!\n",
 318                            cur_offset, out_offset);
 319                return -1;
 320        }
 321        return 0;
 322#undef cur_offset
 323#undef jmp_offset
 324}
 325
 326static void build_epilogue(struct jit_ctx *ctx)
 327{
 328        const u8 r0 = bpf2a64[BPF_REG_0];
 329        const u8 r6 = bpf2a64[BPF_REG_6];
 330        const u8 r7 = bpf2a64[BPF_REG_7];
 331        const u8 r8 = bpf2a64[BPF_REG_8];
 332        const u8 r9 = bpf2a64[BPF_REG_9];
 333        const u8 fp = bpf2a64[BPF_REG_FP];
 334
 335        /* We're done with BPF stack */
 336        emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
 337
 338        /* Restore fs (x25) and x26 */
 339        emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
 340
 341        /* Restore callee-saved register */
 342        emit(A64_POP(r8, r9, A64_SP), ctx);
 343        emit(A64_POP(r6, r7, A64_SP), ctx);
 344
 345        /* Restore FP/LR registers */
 346        emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
 347
 348        /* Set return value */
 349        emit(A64_MOV(1, A64_R(0), r0), ctx);
 350
 351        emit(A64_RET(A64_LR), ctx);
 352}
 353
 354/* JITs an eBPF instruction.
 355 * Returns:
 356 * 0  - successfully JITed an 8-byte eBPF instruction.
 357 * >0 - successfully JITed a 16-byte eBPF instruction.
 358 * <0 - failed to JIT.
 359 */
 360static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
 361                      bool extra_pass)
 362{
 363        const u8 code = insn->code;
 364        const u8 dst = bpf2a64[insn->dst_reg];
 365        const u8 src = bpf2a64[insn->src_reg];
 366        const u8 tmp = bpf2a64[TMP_REG_1];
 367        const u8 tmp2 = bpf2a64[TMP_REG_2];
 368        const u8 tmp3 = bpf2a64[TMP_REG_3];
 369        const s16 off = insn->off;
 370        const s32 imm = insn->imm;
 371        const int i = insn - ctx->prog->insnsi;
 372        const bool is64 = BPF_CLASS(code) == BPF_ALU64 ||
 373                          BPF_CLASS(code) == BPF_JMP;
 374        const bool isdw = BPF_SIZE(code) == BPF_DW;
 375        u8 jmp_cond, reg;
 376        s32 jmp_offset;
 377        u32 a64_insn;
 378
 379#define check_imm(bits, imm) do {                               \
 380        if ((((imm) > 0) && ((imm) >> (bits))) ||               \
 381            (((imm) < 0) && (~(imm) >> (bits)))) {              \
 382                pr_info("[%2d] imm=%d(0x%x) out of range\n",    \
 383                        i, imm, imm);                           \
 384                return -EINVAL;                                 \
 385        }                                                       \
 386} while (0)
 387#define check_imm19(imm) check_imm(19, imm)
 388#define check_imm26(imm) check_imm(26, imm)
 389
 390        switch (code) {
 391        /* dst = src */
 392        case BPF_ALU | BPF_MOV | BPF_X:
 393        case BPF_ALU64 | BPF_MOV | BPF_X:
 394                emit(A64_MOV(is64, dst, src), ctx);
 395                break;
 396        /* dst = dst OP src */
 397        case BPF_ALU | BPF_ADD | BPF_X:
 398        case BPF_ALU64 | BPF_ADD | BPF_X:
 399                emit(A64_ADD(is64, dst, dst, src), ctx);
 400                break;
 401        case BPF_ALU | BPF_SUB | BPF_X:
 402        case BPF_ALU64 | BPF_SUB | BPF_X:
 403                emit(A64_SUB(is64, dst, dst, src), ctx);
 404                break;
 405        case BPF_ALU | BPF_AND | BPF_X:
 406        case BPF_ALU64 | BPF_AND | BPF_X:
 407                emit(A64_AND(is64, dst, dst, src), ctx);
 408                break;
 409        case BPF_ALU | BPF_OR | BPF_X:
 410        case BPF_ALU64 | BPF_OR | BPF_X:
 411                emit(A64_ORR(is64, dst, dst, src), ctx);
 412                break;
 413        case BPF_ALU | BPF_XOR | BPF_X:
 414        case BPF_ALU64 | BPF_XOR | BPF_X:
 415                emit(A64_EOR(is64, dst, dst, src), ctx);
 416                break;
 417        case BPF_ALU | BPF_MUL | BPF_X:
 418        case BPF_ALU64 | BPF_MUL | BPF_X:
 419                emit(A64_MUL(is64, dst, dst, src), ctx);
 420                break;
 421        case BPF_ALU | BPF_DIV | BPF_X:
 422        case BPF_ALU64 | BPF_DIV | BPF_X:
 423        case BPF_ALU | BPF_MOD | BPF_X:
 424        case BPF_ALU64 | BPF_MOD | BPF_X:
 425                switch (BPF_OP(code)) {
 426                case BPF_DIV:
 427                        emit(A64_UDIV(is64, dst, dst, src), ctx);
 428                        break;
 429                case BPF_MOD:
 430                        emit(A64_UDIV(is64, tmp, dst, src), ctx);
 431                        emit(A64_MSUB(is64, dst, dst, tmp, src), ctx);
 432                        break;
 433                }
 434                break;
 435        case BPF_ALU | BPF_LSH | BPF_X:
 436        case BPF_ALU64 | BPF_LSH | BPF_X:
 437                emit(A64_LSLV(is64, dst, dst, src), ctx);
 438                break;
 439        case BPF_ALU | BPF_RSH | BPF_X:
 440        case BPF_ALU64 | BPF_RSH | BPF_X:
 441                emit(A64_LSRV(is64, dst, dst, src), ctx);
 442                break;
 443        case BPF_ALU | BPF_ARSH | BPF_X:
 444        case BPF_ALU64 | BPF_ARSH | BPF_X:
 445                emit(A64_ASRV(is64, dst, dst, src), ctx);
 446                break;
 447        /* dst = -dst */
 448        case BPF_ALU | BPF_NEG:
 449        case BPF_ALU64 | BPF_NEG:
 450                emit(A64_NEG(is64, dst, dst), ctx);
 451                break;
 452        /* dst = BSWAP##imm(dst) */
 453        case BPF_ALU | BPF_END | BPF_FROM_LE:
 454        case BPF_ALU | BPF_END | BPF_FROM_BE:
 455#ifdef CONFIG_CPU_BIG_ENDIAN
 456                if (BPF_SRC(code) == BPF_FROM_BE)
 457                        goto emit_bswap_uxt;
 458#else /* !CONFIG_CPU_BIG_ENDIAN */
 459                if (BPF_SRC(code) == BPF_FROM_LE)
 460                        goto emit_bswap_uxt;
 461#endif
 462                switch (imm) {
 463                case 16:
 464                        emit(A64_REV16(is64, dst, dst), ctx);
 465                        /* zero-extend 16 bits into 64 bits */
 466                        emit(A64_UXTH(is64, dst, dst), ctx);
 467                        break;
 468                case 32:
 469                        emit(A64_REV32(is64, dst, dst), ctx);
 470                        /* upper 32 bits already cleared */
 471                        break;
 472                case 64:
 473                        emit(A64_REV64(dst, dst), ctx);
 474                        break;
 475                }
 476                break;
 477emit_bswap_uxt:
 478                switch (imm) {
 479                case 16:
 480                        /* zero-extend 16 bits into 64 bits */
 481                        emit(A64_UXTH(is64, dst, dst), ctx);
 482                        break;
 483                case 32:
 484                        /* zero-extend 32 bits into 64 bits */
 485                        emit(A64_UXTW(is64, dst, dst), ctx);
 486                        break;
 487                case 64:
 488                        /* nop */
 489                        break;
 490                }
 491                break;
 492        /* dst = imm */
 493        case BPF_ALU | BPF_MOV | BPF_K:
 494        case BPF_ALU64 | BPF_MOV | BPF_K:
 495                emit_a64_mov_i(is64, dst, imm, ctx);
 496                break;
 497        /* dst = dst OP imm */
 498        case BPF_ALU | BPF_ADD | BPF_K:
 499        case BPF_ALU64 | BPF_ADD | BPF_K:
 500                if (is_addsub_imm(imm)) {
 501                        emit(A64_ADD_I(is64, dst, dst, imm), ctx);
 502                } else if (is_addsub_imm(-imm)) {
 503                        emit(A64_SUB_I(is64, dst, dst, -imm), ctx);
 504                } else {
 505                        emit_a64_mov_i(is64, tmp, imm, ctx);
 506                        emit(A64_ADD(is64, dst, dst, tmp), ctx);
 507                }
 508                break;
 509        case BPF_ALU | BPF_SUB | BPF_K:
 510        case BPF_ALU64 | BPF_SUB | BPF_K:
 511                if (is_addsub_imm(imm)) {
 512                        emit(A64_SUB_I(is64, dst, dst, imm), ctx);
 513                } else if (is_addsub_imm(-imm)) {
 514                        emit(A64_ADD_I(is64, dst, dst, -imm), ctx);
 515                } else {
 516                        emit_a64_mov_i(is64, tmp, imm, ctx);
 517                        emit(A64_SUB(is64, dst, dst, tmp), ctx);
 518                }
 519                break;
 520        case BPF_ALU | BPF_AND | BPF_K:
 521        case BPF_ALU64 | BPF_AND | BPF_K:
 522                a64_insn = A64_AND_I(is64, dst, dst, imm);
 523                if (a64_insn != AARCH64_BREAK_FAULT) {
 524                        emit(a64_insn, ctx);
 525                } else {
 526                        emit_a64_mov_i(is64, tmp, imm, ctx);
 527                        emit(A64_AND(is64, dst, dst, tmp), ctx);
 528                }
 529                break;
 530        case BPF_ALU | BPF_OR | BPF_K:
 531        case BPF_ALU64 | BPF_OR | BPF_K:
 532                a64_insn = A64_ORR_I(is64, dst, dst, imm);
 533                if (a64_insn != AARCH64_BREAK_FAULT) {
 534                        emit(a64_insn, ctx);
 535                } else {
 536                        emit_a64_mov_i(is64, tmp, imm, ctx);
 537                        emit(A64_ORR(is64, dst, dst, tmp), ctx);
 538                }
 539                break;
 540        case BPF_ALU | BPF_XOR | BPF_K:
 541        case BPF_ALU64 | BPF_XOR | BPF_K:
 542                a64_insn = A64_EOR_I(is64, dst, dst, imm);
 543                if (a64_insn != AARCH64_BREAK_FAULT) {
 544                        emit(a64_insn, ctx);
 545                } else {
 546                        emit_a64_mov_i(is64, tmp, imm, ctx);
 547                        emit(A64_EOR(is64, dst, dst, tmp), ctx);
 548                }
 549                break;
 550        case BPF_ALU | BPF_MUL | BPF_K:
 551        case BPF_ALU64 | BPF_MUL | BPF_K:
 552                emit_a64_mov_i(is64, tmp, imm, ctx);
 553                emit(A64_MUL(is64, dst, dst, tmp), ctx);
 554                break;
 555        case BPF_ALU | BPF_DIV | BPF_K:
 556        case BPF_ALU64 | BPF_DIV | BPF_K:
 557                emit_a64_mov_i(is64, tmp, imm, ctx);
 558                emit(A64_UDIV(is64, dst, dst, tmp), ctx);
 559                break;
 560        case BPF_ALU | BPF_MOD | BPF_K:
 561        case BPF_ALU64 | BPF_MOD | BPF_K:
 562                emit_a64_mov_i(is64, tmp2, imm, ctx);
 563                emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
 564                emit(A64_MSUB(is64, dst, dst, tmp, tmp2), ctx);
 565                break;
 566        case BPF_ALU | BPF_LSH | BPF_K:
 567        case BPF_ALU64 | BPF_LSH | BPF_K:
 568                emit(A64_LSL(is64, dst, dst, imm), ctx);
 569                break;
 570        case BPF_ALU | BPF_RSH | BPF_K:
 571        case BPF_ALU64 | BPF_RSH | BPF_K:
 572                emit(A64_LSR(is64, dst, dst, imm), ctx);
 573                break;
 574        case BPF_ALU | BPF_ARSH | BPF_K:
 575        case BPF_ALU64 | BPF_ARSH | BPF_K:
 576                emit(A64_ASR(is64, dst, dst, imm), ctx);
 577                break;
 578
 579        /* JUMP off */
 580        case BPF_JMP | BPF_JA:
 581                jmp_offset = bpf2a64_offset(i + off, i, ctx);
 582                check_imm26(jmp_offset);
 583                emit(A64_B(jmp_offset), ctx);
 584                break;
 585        /* IF (dst COND src) JUMP off */
 586        case BPF_JMP | BPF_JEQ | BPF_X:
 587        case BPF_JMP | BPF_JGT | BPF_X:
 588        case BPF_JMP | BPF_JLT | BPF_X:
 589        case BPF_JMP | BPF_JGE | BPF_X:
 590        case BPF_JMP | BPF_JLE | BPF_X:
 591        case BPF_JMP | BPF_JNE | BPF_X:
 592        case BPF_JMP | BPF_JSGT | BPF_X:
 593        case BPF_JMP | BPF_JSLT | BPF_X:
 594        case BPF_JMP | BPF_JSGE | BPF_X:
 595        case BPF_JMP | BPF_JSLE | BPF_X:
 596        case BPF_JMP32 | BPF_JEQ | BPF_X:
 597        case BPF_JMP32 | BPF_JGT | BPF_X:
 598        case BPF_JMP32 | BPF_JLT | BPF_X:
 599        case BPF_JMP32 | BPF_JGE | BPF_X:
 600        case BPF_JMP32 | BPF_JLE | BPF_X:
 601        case BPF_JMP32 | BPF_JNE | BPF_X:
 602        case BPF_JMP32 | BPF_JSGT | BPF_X:
 603        case BPF_JMP32 | BPF_JSLT | BPF_X:
 604        case BPF_JMP32 | BPF_JSGE | BPF_X:
 605        case BPF_JMP32 | BPF_JSLE | BPF_X:
 606                emit(A64_CMP(is64, dst, src), ctx);
 607emit_cond_jmp:
 608                jmp_offset = bpf2a64_offset(i + off, i, ctx);
 609                check_imm19(jmp_offset);
 610                switch (BPF_OP(code)) {
 611                case BPF_JEQ:
 612                        jmp_cond = A64_COND_EQ;
 613                        break;
 614                case BPF_JGT:
 615                        jmp_cond = A64_COND_HI;
 616                        break;
 617                case BPF_JLT:
 618                        jmp_cond = A64_COND_CC;
 619                        break;
 620                case BPF_JGE:
 621                        jmp_cond = A64_COND_CS;
 622                        break;
 623                case BPF_JLE:
 624                        jmp_cond = A64_COND_LS;
 625                        break;
 626                case BPF_JSET:
 627                case BPF_JNE:
 628                        jmp_cond = A64_COND_NE;
 629                        break;
 630                case BPF_JSGT:
 631                        jmp_cond = A64_COND_GT;
 632                        break;
 633                case BPF_JSLT:
 634                        jmp_cond = A64_COND_LT;
 635                        break;
 636                case BPF_JSGE:
 637                        jmp_cond = A64_COND_GE;
 638                        break;
 639                case BPF_JSLE:
 640                        jmp_cond = A64_COND_LE;
 641                        break;
 642                default:
 643                        return -EFAULT;
 644                }
 645                emit(A64_B_(jmp_cond, jmp_offset), ctx);
 646                break;
 647        case BPF_JMP | BPF_JSET | BPF_X:
 648        case BPF_JMP32 | BPF_JSET | BPF_X:
 649                emit(A64_TST(is64, dst, src), ctx);
 650                goto emit_cond_jmp;
 651        /* IF (dst COND imm) JUMP off */
 652        case BPF_JMP | BPF_JEQ | BPF_K:
 653        case BPF_JMP | BPF_JGT | BPF_K:
 654        case BPF_JMP | BPF_JLT | BPF_K:
 655        case BPF_JMP | BPF_JGE | BPF_K:
 656        case BPF_JMP | BPF_JLE | BPF_K:
 657        case BPF_JMP | BPF_JNE | BPF_K:
 658        case BPF_JMP | BPF_JSGT | BPF_K:
 659        case BPF_JMP | BPF_JSLT | BPF_K:
 660        case BPF_JMP | BPF_JSGE | BPF_K:
 661        case BPF_JMP | BPF_JSLE | BPF_K:
 662        case BPF_JMP32 | BPF_JEQ | BPF_K:
 663        case BPF_JMP32 | BPF_JGT | BPF_K:
 664        case BPF_JMP32 | BPF_JLT | BPF_K:
 665        case BPF_JMP32 | BPF_JGE | BPF_K:
 666        case BPF_JMP32 | BPF_JLE | BPF_K:
 667        case BPF_JMP32 | BPF_JNE | BPF_K:
 668        case BPF_JMP32 | BPF_JSGT | BPF_K:
 669        case BPF_JMP32 | BPF_JSLT | BPF_K:
 670        case BPF_JMP32 | BPF_JSGE | BPF_K:
 671        case BPF_JMP32 | BPF_JSLE | BPF_K:
 672                if (is_addsub_imm(imm)) {
 673                        emit(A64_CMP_I(is64, dst, imm), ctx);
 674                } else if (is_addsub_imm(-imm)) {
 675                        emit(A64_CMN_I(is64, dst, -imm), ctx);
 676                } else {
 677                        emit_a64_mov_i(is64, tmp, imm, ctx);
 678                        emit(A64_CMP(is64, dst, tmp), ctx);
 679                }
 680                goto emit_cond_jmp;
 681        case BPF_JMP | BPF_JSET | BPF_K:
 682        case BPF_JMP32 | BPF_JSET | BPF_K:
 683                a64_insn = A64_TST_I(is64, dst, imm);
 684                if (a64_insn != AARCH64_BREAK_FAULT) {
 685                        emit(a64_insn, ctx);
 686                } else {
 687                        emit_a64_mov_i(is64, tmp, imm, ctx);
 688                        emit(A64_TST(is64, dst, tmp), ctx);
 689                }
 690                goto emit_cond_jmp;
 691        /* function call */
 692        case BPF_JMP | BPF_CALL:
 693        {
 694                const u8 r0 = bpf2a64[BPF_REG_0];
 695                bool func_addr_fixed;
 696                u64 func_addr;
 697                int ret;
 698
 699                ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
 700                                            &func_addr, &func_addr_fixed);
 701                if (ret < 0)
 702                        return ret;
 703                emit_addr_mov_i64(tmp, func_addr, ctx);
 704                emit(A64_BLR(tmp), ctx);
 705                emit(A64_MOV(1, r0, A64_R(0)), ctx);
 706                break;
 707        }
 708        /* tail call */
 709        case BPF_JMP | BPF_TAIL_CALL:
 710                if (emit_bpf_tail_call(ctx))
 711                        return -EFAULT;
 712                break;
 713        /* function return */
 714        case BPF_JMP | BPF_EXIT:
 715                /* Optimization: when last instruction is EXIT,
 716                   simply fallthrough to epilogue. */
 717                if (i == ctx->prog->len - 1)
 718                        break;
 719                jmp_offset = epilogue_offset(ctx);
 720                check_imm26(jmp_offset);
 721                emit(A64_B(jmp_offset), ctx);
 722                break;
 723
 724        /* dst = imm64 */
 725        case BPF_LD | BPF_IMM | BPF_DW:
 726        {
 727                const struct bpf_insn insn1 = insn[1];
 728                u64 imm64;
 729
 730                imm64 = (u64)insn1.imm << 32 | (u32)imm;
 731                emit_a64_mov_i64(dst, imm64, ctx);
 732
 733                return 1;
 734        }
 735
 736        /* LDX: dst = *(size *)(src + off) */
 737        case BPF_LDX | BPF_MEM | BPF_W:
 738        case BPF_LDX | BPF_MEM | BPF_H:
 739        case BPF_LDX | BPF_MEM | BPF_B:
 740        case BPF_LDX | BPF_MEM | BPF_DW:
 741                emit_a64_mov_i(1, tmp, off, ctx);
 742                switch (BPF_SIZE(code)) {
 743                case BPF_W:
 744                        emit(A64_LDR32(dst, src, tmp), ctx);
 745                        break;
 746                case BPF_H:
 747                        emit(A64_LDRH(dst, src, tmp), ctx);
 748                        break;
 749                case BPF_B:
 750                        emit(A64_LDRB(dst, src, tmp), ctx);
 751                        break;
 752                case BPF_DW:
 753                        emit(A64_LDR64(dst, src, tmp), ctx);
 754                        break;
 755                }
 756                break;
 757
 758        /* ST: *(size *)(dst + off) = imm */
 759        case BPF_ST | BPF_MEM | BPF_W:
 760        case BPF_ST | BPF_MEM | BPF_H:
 761        case BPF_ST | BPF_MEM | BPF_B:
 762        case BPF_ST | BPF_MEM | BPF_DW:
 763                /* Load imm to a register then store it */
 764                emit_a64_mov_i(1, tmp2, off, ctx);
 765                emit_a64_mov_i(1, tmp, imm, ctx);
 766                switch (BPF_SIZE(code)) {
 767                case BPF_W:
 768                        emit(A64_STR32(tmp, dst, tmp2), ctx);
 769                        break;
 770                case BPF_H:
 771                        emit(A64_STRH(tmp, dst, tmp2), ctx);
 772                        break;
 773                case BPF_B:
 774                        emit(A64_STRB(tmp, dst, tmp2), ctx);
 775                        break;
 776                case BPF_DW:
 777                        emit(A64_STR64(tmp, dst, tmp2), ctx);
 778                        break;
 779                }
 780                break;
 781
 782        /* STX: *(size *)(dst + off) = src */
 783        case BPF_STX | BPF_MEM | BPF_W:
 784        case BPF_STX | BPF_MEM | BPF_H:
 785        case BPF_STX | BPF_MEM | BPF_B:
 786        case BPF_STX | BPF_MEM | BPF_DW:
 787                emit_a64_mov_i(1, tmp, off, ctx);
 788                switch (BPF_SIZE(code)) {
 789                case BPF_W:
 790                        emit(A64_STR32(src, dst, tmp), ctx);
 791                        break;
 792                case BPF_H:
 793                        emit(A64_STRH(src, dst, tmp), ctx);
 794                        break;
 795                case BPF_B:
 796                        emit(A64_STRB(src, dst, tmp), ctx);
 797                        break;
 798                case BPF_DW:
 799                        emit(A64_STR64(src, dst, tmp), ctx);
 800                        break;
 801                }
 802                break;
 803
 804        /* STX XADD: lock *(u32 *)(dst + off) += src */
 805        case BPF_STX | BPF_XADD | BPF_W:
 806        /* STX XADD: lock *(u64 *)(dst + off) += src */
 807        case BPF_STX | BPF_XADD | BPF_DW:
 808                if (!off) {
 809                        reg = dst;
 810                } else {
 811                        emit_a64_mov_i(1, tmp, off, ctx);
 812                        emit(A64_ADD(1, tmp, tmp, dst), ctx);
 813                        reg = tmp;
 814                }
 815                if (cpus_have_cap(ARM64_HAS_LSE_ATOMICS)) {
 816                        emit(A64_STADD(isdw, reg, src), ctx);
 817                } else {
 818                        emit(A64_LDXR(isdw, tmp2, reg), ctx);
 819                        emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
 820                        emit(A64_STXR(isdw, tmp2, reg, tmp3), ctx);
 821                        jmp_offset = -3;
 822                        check_imm19(jmp_offset);
 823                        emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
 824                }
 825                break;
 826
 827        default:
 828                pr_err_once("unknown opcode %02x\n", code);
 829                return -EINVAL;
 830        }
 831
 832        return 0;
 833}
 834
 835static int build_body(struct jit_ctx *ctx, bool extra_pass)
 836{
 837        const struct bpf_prog *prog = ctx->prog;
 838        int i;
 839
 840        for (i = 0; i < prog->len; i++) {
 841                const struct bpf_insn *insn = &prog->insnsi[i];
 842                int ret;
 843
 844                ret = build_insn(insn, ctx, extra_pass);
 845                if (ret > 0) {
 846                        i++;
 847                        if (ctx->image == NULL)
 848                                ctx->offset[i] = ctx->idx;
 849                        continue;
 850                }
 851                if (ctx->image == NULL)
 852                        ctx->offset[i] = ctx->idx;
 853                if (ret)
 854                        return ret;
 855        }
 856
 857        return 0;
 858}
 859
 860static int validate_code(struct jit_ctx *ctx)
 861{
 862        int i;
 863
 864        for (i = 0; i < ctx->idx; i++) {
 865                u32 a64_insn = le32_to_cpu(ctx->image[i]);
 866
 867                if (a64_insn == AARCH64_BREAK_FAULT)
 868                        return -1;
 869        }
 870
 871        return 0;
 872}
 873
 874static inline void bpf_flush_icache(void *start, void *end)
 875{
 876        flush_icache_range((unsigned long)start, (unsigned long)end);
 877}
 878
/*
 * JIT state kept in prog->aux->jit_data between two invocations of
 * bpf_int_jit_compile() on the same program.  It is filled in when a
 * program with prog->is_func set finishes its first compilation and is
 * read back (and then freed) on the following extra pass.
 */
struct arm64_jit_data {
        /* binary header returned by bpf_jit_binary_alloc() */
        struct bpf_binary_header *header;
        /* start of the instruction area inside that allocation */
        u8 *image;
        /* JIT context saved from the first compilation */
        struct jit_ctx ctx;
};
 884
 885struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 886{
 887        struct bpf_prog *tmp, *orig_prog = prog;
 888        struct bpf_binary_header *header;
 889        struct arm64_jit_data *jit_data;
 890        bool was_classic = bpf_prog_was_classic(prog);
 891        bool tmp_blinded = false;
 892        bool extra_pass = false;
 893        struct jit_ctx ctx;
 894        int image_size;
 895        u8 *image_ptr;
 896
 897        if (!prog->jit_requested)
 898                return orig_prog;
 899
 900        tmp = bpf_jit_blind_constants(prog);
 901        /* If blinding was requested and we failed during blinding,
 902         * we must fall back to the interpreter.
 903         */
 904        if (IS_ERR(tmp))
 905                return orig_prog;
 906        if (tmp != prog) {
 907                tmp_blinded = true;
 908                prog = tmp;
 909        }
 910
 911        jit_data = prog->aux->jit_data;
 912        if (!jit_data) {
 913                jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
 914                if (!jit_data) {
 915                        prog = orig_prog;
 916                        goto out;
 917                }
 918                prog->aux->jit_data = jit_data;
 919        }
 920        if (jit_data->ctx.offset) {
 921                ctx = jit_data->ctx;
 922                image_ptr = jit_data->image;
 923                header = jit_data->header;
 924                extra_pass = true;
 925                image_size = sizeof(u32) * ctx.idx;
 926                goto skip_init_ctx;
 927        }
 928        memset(&ctx, 0, sizeof(ctx));
 929        ctx.prog = prog;
 930
 931        ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
 932        if (ctx.offset == NULL) {
 933                prog = orig_prog;
 934                goto out_off;
 935        }
 936
 937        /* 1. Initial fake pass to compute ctx->idx. */
 938
 939        /* Fake pass to fill in ctx->offset. */
 940        if (build_body(&ctx, extra_pass)) {
 941                prog = orig_prog;
 942                goto out_off;
 943        }
 944
 945        if (build_prologue(&ctx, was_classic)) {
 946                prog = orig_prog;
 947                goto out_off;
 948        }
 949
 950        ctx.epilogue_offset = ctx.idx;
 951        build_epilogue(&ctx);
 952
 953        /* Now we know the actual image size. */
 954        image_size = sizeof(u32) * ctx.idx;
 955        header = bpf_jit_binary_alloc(image_size, &image_ptr,
 956                                      sizeof(u32), jit_fill_hole);
 957        if (header == NULL) {
 958                prog = orig_prog;
 959                goto out_off;
 960        }
 961
 962        /* 2. Now, the actual pass. */
 963
 964        ctx.image = (__le32 *)image_ptr;
 965skip_init_ctx:
 966        ctx.idx = 0;
 967
 968        build_prologue(&ctx, was_classic);
 969
 970        if (build_body(&ctx, extra_pass)) {
 971                bpf_jit_binary_free(header);
 972                prog = orig_prog;
 973                goto out_off;
 974        }
 975
 976        build_epilogue(&ctx);
 977
 978        /* 3. Extra pass to validate JITed code. */
 979        if (validate_code(&ctx)) {
 980                bpf_jit_binary_free(header);
 981                prog = orig_prog;
 982                goto out_off;
 983        }
 984
 985        /* And we're done. */
 986        if (bpf_jit_enable > 1)
 987                bpf_jit_dump(prog->len, image_size, 2, ctx.image);
 988
 989        bpf_flush_icache(header, ctx.image + ctx.idx);
 990
 991        if (!prog->is_func || extra_pass) {
 992                if (extra_pass && ctx.idx != jit_data->ctx.idx) {
 993                        pr_err_once("multi-func JIT bug %d != %d\n",
 994                                    ctx.idx, jit_data->ctx.idx);
 995                        bpf_jit_binary_free(header);
 996                        prog->bpf_func = NULL;
 997                        prog->jited = 0;
 998                        goto out_off;
 999                }
1000                bpf_jit_binary_lock_ro(header);
1001        } else {
1002                jit_data->ctx = ctx;
1003                jit_data->image = image_ptr;
1004                jit_data->header = header;
1005        }
1006        prog->bpf_func = (void *)ctx.image;
1007        prog->jited = 1;
1008        prog->jited_len = image_size;
1009
1010        if (!prog->is_func || extra_pass) {
1011                bpf_prog_fill_jited_linfo(prog, ctx.offset);
1012out_off:
1013                kfree(ctx.offset);
1014                kfree(jit_data);
1015                prog->aux->jit_data = NULL;
1016        }
1017out:
1018        if (tmp_blinded)
1019                bpf_jit_prog_release_other(prog, prog == orig_prog ?
1020                                           tmp : orig_prog);
1021        return prog;
1022}
1023
1024void *bpf_jit_alloc_exec(unsigned long size)
1025{
1026        return __vmalloc_node_range(size, PAGE_SIZE, BPF_JIT_REGION_START,
1027                                    BPF_JIT_REGION_END, GFP_KERNEL,
1028                                    PAGE_KERNEL, 0, NUMA_NO_NODE,
1029                                    __builtin_return_address(0));
1030}
1031
/* Release memory obtained from bpf_jit_alloc_exec() by passing it to vfree(). */
void bpf_jit_free_exec(void *addr)
{
        /*
         * Plain call: a return statement carrying an expression is not
         * permitted in a function returning void.
         */
        vfree(addr);
}
1036