linux/arch/arm64/net/bpf_jit_comp.c
   1/*
   2 * BPF JIT compiler for ARM64
   3 *
   4 * Copyright (C) 2014-2016 Zi Shen Lim <zlim.lnx@gmail.com>
   5 *
   6 * This program is free software; you can redistribute it and/or modify
   7 * it under the terms of the GNU General Public License version 2 as
   8 * published by the Free Software Foundation.
   9 *
  10 * This program is distributed in the hope that it will be useful,
  11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
  12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  13 * GNU General Public License for more details.
  14 *
  15 * You should have received a copy of the GNU General Public License
  16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
  17 */
  18
  19#define pr_fmt(fmt) "bpf_jit: " fmt
  20
  21#include <linux/bpf.h>
  22#include <linux/filter.h>
  23#include <linux/printk.h>
  24#include <linux/skbuff.h>
  25#include <linux/slab.h>
  26
  27#include <asm/byteorder.h>
  28#include <asm/cacheflush.h>
  29#include <asm/debug-monitors.h>
  30#include <asm/set_memory.h>
  31
  32#include "bpf_jit.h"
  33
  34#define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
  35#define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
  36#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
  37#define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
  38
  39/* Map BPF registers to A64 registers */
  40static const int bpf2a64[] = {
  41        /* return value from in-kernel function, and exit value from eBPF */
  42        [BPF_REG_0] = A64_R(7),
  43        /* arguments from eBPF program to in-kernel function */
  44        [BPF_REG_1] = A64_R(0),
  45        [BPF_REG_2] = A64_R(1),
  46        [BPF_REG_3] = A64_R(2),
  47        [BPF_REG_4] = A64_R(3),
  48        [BPF_REG_5] = A64_R(4),
  49        /* callee saved registers that in-kernel function will preserve */
  50        [BPF_REG_6] = A64_R(19),
  51        [BPF_REG_7] = A64_R(20),
  52        [BPF_REG_8] = A64_R(21),
  53        [BPF_REG_9] = A64_R(22),
  54        /* read-only frame pointer to access stack */
  55        [BPF_REG_FP] = A64_R(25),
  56        /* temporary registers for internal BPF JIT */
  57        [TMP_REG_1] = A64_R(10),
  58        [TMP_REG_2] = A64_R(11),
  59        [TMP_REG_3] = A64_R(12),
  60        /* tail_call_cnt */
  61        [TCALL_CNT] = A64_R(26),
  62        /* temporary register for blinding constants */
  63        [BPF_REG_AX] = A64_R(9),
  64};
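/*
 * Note on the mapping above: BPF argument registers R1-R5 sit directly in
 * the AAPCS64 argument registers x0-x4, so calls into in-kernel helpers
 * need no argument shuffling, only a move of the return value from x0
 * back into R0 (x7).  R6-R9, the frame pointer and tail_call_cnt live in
 * callee-saved registers (x19-x22, x25, x26) so they survive such calls,
 * while the JIT temporaries and BPF_REG_AX use caller-saved x9-x12.
 */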
  65
  66struct jit_ctx {
  67        const struct bpf_prog *prog;
  68        int idx;
  69        int epilogue_offset;
  70        int *offset;
  71        __le32 *image;
  72        u32 stack_size;
  73};
  74
  75static inline void emit(const u32 insn, struct jit_ctx *ctx)
  76{
  77        if (ctx->image != NULL)
  78                ctx->image[ctx->idx] = cpu_to_le32(insn);
  79
  80        ctx->idx++;
  81}
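/*
 * emit() serves both JIT passes: on the initial sizing pass ctx->image is
 * NULL, so only ctx->idx advances and the instruction count is gathered;
 * on the code generation pass the encoded instruction is also written,
 * little-endian, into the slot that ctx->idx indexes.
 */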
  82
  83static inline void emit_a64_mov_i64(const int reg, const u64 val,
  84                                    struct jit_ctx *ctx)
  85{
  86        u64 tmp = val;
  87        int shift = 0;
  88
  89        emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
  90        tmp >>= 16;
  91        shift += 16;
  92        while (tmp) {
  93                if (tmp & 0xffff)
  94                        emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
  95                tmp >>= 16;
  96                shift += 16;
  97        }
  98}
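/*
 * Worked example (for illustration only): the constant is materialised
 * 16 bits at a time and all-zero halfwords after the first are skipped,
 * so the length of the sequence depends on the value.  For
 * val = 0x1234000056780000 this emits
 *
 *	movz	reg, #0x0
 *	movk	reg, #0x5678, lsl #16
 *	movk	reg, #0x1234, lsl #48
 */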
  99
 100static inline void emit_addr_mov_i64(const int reg, const u64 val,
 101                                     struct jit_ctx *ctx)
 102{
 103        u64 tmp = val;
 104        int shift = 0;
 105
 106        emit(A64_MOVZ(1, reg, tmp & 0xffff, shift), ctx);
 107        for (; shift < 48;) {
 108                tmp >>= 16;
 109                shift += 16;
 110                emit(A64_MOVK(1, reg, tmp & 0xffff, shift), ctx);
 111        }
 112}
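/*
 * Unlike emit_a64_mov_i64() above, this helper always emits exactly four
 * instructions (MOVZ plus MOVK at shifts 16, 32 and 48) regardless of the
 * value.  The fixed length matters for the BPF_JMP | BPF_CALL case in
 * build_insn(): when ctx->prog->is_func is set, the callee address may
 * still change before the extra pass, so the sequence must not change
 * size between passes.
 */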
 113
 114static inline void emit_a64_mov_i(const int is64, const int reg,
 115                                  const s32 val, struct jit_ctx *ctx)
 116{
 117        u16 hi = val >> 16;
 118        u16 lo = val & 0xffff;
 119
 120        if (hi & 0x8000) {
 121                if (hi == 0xffff) {
 122                        emit(A64_MOVN(is64, reg, (u16)~lo, 0), ctx);
 123                } else {
 124                        emit(A64_MOVN(is64, reg, (u16)~hi, 16), ctx);
 125                        emit(A64_MOVK(is64, reg, lo, 0), ctx);
 126                }
 127        } else {
 128                emit(A64_MOVZ(is64, reg, lo, 0), ctx);
 129                if (hi)
 130                        emit(A64_MOVK(is64, reg, hi, 16), ctx);
 131        }
 132}
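/*
 * Example (for illustration): a 32-bit immediate whose upper halfword is
 * all ones needs only a single MOVN.  For imm = -5 (0xfffffffb), lo is
 * 0xfffb, so MOVN reg, #0x4 is emitted and the CPU stores ~0x4, i.e.
 * 0x...fffb.  Other negative values use MOVN for the high halfword
 * followed by MOVK for the low one.
 */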
 133
 134static inline int bpf2a64_offset(int bpf_to, int bpf_from,
 135                                 const struct jit_ctx *ctx)
 136{
 137        int to = ctx->offset[bpf_to];
 138        /* -1 to account for the Branch instruction */
 139        int from = ctx->offset[bpf_from] - 1;
 140
 141        return to - from;
 142}
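/*
 * Branch offsets are measured in A64 instructions.  build_body() records
 * in ctx->offset[i] the index of the first A64 instruction emitted after
 * BPF instruction i; since a BPF jump with offset "off" targets
 * instruction i + off + 1, ctx->offset[i + off] is the start of the
 * target, while ctx->offset[i] - 1 is the branch instruction itself,
 * which the PC-relative A64 branch is encoded against.
 */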
 143
 144static void jit_fill_hole(void *area, unsigned int size)
 145{
 146        __le32 *ptr;
 147        /* We are guaranteed to have aligned memory. */
 148        for (ptr = area; size >= sizeof(u32); size -= sizeof(u32))
 149                *ptr++ = cpu_to_le32(AARCH64_BREAK_FAULT);
 150}
 151
 152static inline int epilogue_offset(const struct jit_ctx *ctx)
 153{
 154        int to = ctx->epilogue_offset;
 155        int from = ctx->idx;
 156
 157        return to - from;
 158}
 159
 160/* Stack must be a multiple of 16 bytes */
 161#define STACK_ALIGN(sz) (((sz) + 15) & ~15)
 162
 163/* Tail call offset to jump into */
 164#define PROLOGUE_OFFSET 7
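/*
 * A tail call enters the target program at bpf_func + 4 * PROLOGUE_OFFSET
 * (see emit_bpf_tail_call()), i.e. right at the SP adjustment that follows
 * the seven prologue instructions counted above.  FP/LR, the callee-saved
 * registers and tail_call_cnt are thus inherited from the current program
 * instead of being saved or reset again.
 */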
 165
 166static int build_prologue(struct jit_ctx *ctx)
 167{
 168        const struct bpf_prog *prog = ctx->prog;
 169        const u8 r6 = bpf2a64[BPF_REG_6];
 170        const u8 r7 = bpf2a64[BPF_REG_7];
 171        const u8 r8 = bpf2a64[BPF_REG_8];
 172        const u8 r9 = bpf2a64[BPF_REG_9];
 173        const u8 fp = bpf2a64[BPF_REG_FP];
 174        const u8 tcc = bpf2a64[TCALL_CNT];
 175        const int idx0 = ctx->idx;
 176        int cur_offset;
 177
 178        /*
 179         * BPF prog stack layout
 180         *
 181         *                         high
 182         * original A64_SP =>   0:+-----+ BPF prologue
 183         *                        |FP/LR|
 184         * current A64_FP =>  -16:+-----+
 185         *                        | ... | callee saved registers
 186         * BPF fp register => -64:+-----+ <= (BPF_FP)
 187         *                        |     |
 188         *                        | ... | BPF prog stack
 189         *                        |     |
 190         *                        +-----+ <= (BPF_FP - prog->aux->stack_depth)
 191         *                        |RSVD | JIT scratchpad
 192         * current A64_SP =>      +-----+ <= (BPF_FP - ctx->stack_size)
 193         *                        |     |
 194         *                        | ... | Function call stack
 195         *                        |     |
 196         *                        +-----+
 197         *                          low
 198         *
 199         */
 200
 201        /* Save FP and LR registers to stay aligned with the ARM64 AAPCS */
 202        emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
 203        emit(A64_MOV(1, A64_FP, A64_SP), ctx);
 204
 205        /* Save callee-saved registers */
 206        emit(A64_PUSH(r6, r7, A64_SP), ctx);
 207        emit(A64_PUSH(r8, r9, A64_SP), ctx);
 208        emit(A64_PUSH(fp, tcc, A64_SP), ctx);
 209
 210        /* Set up BPF prog stack base register */
 211        emit(A64_MOV(1, fp, A64_SP), ctx);
 212
 213        /* Initialize tail_call_cnt */
 214        emit(A64_MOVZ(1, tcc, 0, 0), ctx);
 215
 216        cur_offset = ctx->idx - idx0;
 217        if (cur_offset != PROLOGUE_OFFSET) {
 218                pr_err_once("PROLOGUE_OFFSET = %d, expected %d!\n",
 219                            cur_offset, PROLOGUE_OFFSET);
 220                return -1;
 221        }
 222
 223        /* 4 bytes extra for the skb_copy_bits() buffer */
 224        ctx->stack_size = prog->aux->stack_depth + 4;
 225        ctx->stack_size = STACK_ALIGN(ctx->stack_size);
 226
 227        /* Set up function call stack */
 228        emit(A64_SUB_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
 229        return 0;
 230}
 231
 232static int out_offset = -1; /* initialized on the first pass of build_body() */
 233static int emit_bpf_tail_call(struct jit_ctx *ctx)
 234{
 235        /* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
 236        const u8 r2 = bpf2a64[BPF_REG_2];
 237        const u8 r3 = bpf2a64[BPF_REG_3];
 238
 239        const u8 tmp = bpf2a64[TMP_REG_1];
 240        const u8 prg = bpf2a64[TMP_REG_2];
 241        const u8 tcc = bpf2a64[TCALL_CNT];
 242        const int idx0 = ctx->idx;
 243#define cur_offset (ctx->idx - idx0)
 244#define jmp_offset (out_offset - (cur_offset))
 245        size_t off;
 246
 247        /* if (index >= array->map.max_entries)
 248         *     goto out;
 249         */
 250        off = offsetof(struct bpf_array, map.max_entries);
 251        emit_a64_mov_i64(tmp, off, ctx);
 252        emit(A64_LDR32(tmp, r2, tmp), ctx);
 253        emit(A64_MOV(0, r3, r3), ctx);
 254        emit(A64_CMP(0, r3, tmp), ctx);
 255        emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
 256
 257        /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 258         *     goto out;
 259         * tail_call_cnt++;
 260         */
 261        emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
 262        emit(A64_CMP(1, tcc, tmp), ctx);
 263        emit(A64_B_(A64_COND_HI, jmp_offset), ctx);
 264        emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
 265
 266        /* prog = array->ptrs[index];
 267         * if (prog == NULL)
 268         *     goto out;
 269         */
 270        off = offsetof(struct bpf_array, ptrs);
 271        emit_a64_mov_i64(tmp, off, ctx);
 272        emit(A64_ADD(1, tmp, r2, tmp), ctx);
 273        emit(A64_LSL(1, prg, r3, 3), ctx);
 274        emit(A64_LDR64(prg, tmp, prg), ctx);
 275        emit(A64_CBZ(1, prg, jmp_offset), ctx);
 276
 277        /* goto *(prog->bpf_func + prologue_offset); */
 278        off = offsetof(struct bpf_prog, bpf_func);
 279        emit_a64_mov_i64(tmp, off, ctx);
 280        emit(A64_LDR64(tmp, prg, tmp), ctx);
 281        emit(A64_ADD_I(1, tmp, tmp, sizeof(u32) * PROLOGUE_OFFSET), ctx);
 282        emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
 283        emit(A64_BR(tmp), ctx);
 284
 285        /* out: */
 286        if (out_offset == -1)
 287                out_offset = cur_offset;
 288        if (cur_offset != out_offset) {
 289                pr_err_once("tail_call out_offset = %d, expected %d!\n",
 290                            cur_offset, out_offset);
 291                return -1;
 292        }
 293        return 0;
 294#undef cur_offset
 295#undef jmp_offset
 296}
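/*
 * Note on out_offset: it is still -1 during the very first sizing pass,
 * so the branches above initially encode a meaningless offset.  That is
 * harmless because ctx->image is NULL and nothing is written; once the
 * end of the block is reached, out_offset latches the real "out" label
 * position, and the check above guarantees that later passes produce a
 * block of the same size so those offsets remain valid.
 */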
 297
 298static void build_epilogue(struct jit_ctx *ctx)
 299{
 300        const u8 r0 = bpf2a64[BPF_REG_0];
 301        const u8 r6 = bpf2a64[BPF_REG_6];
 302        const u8 r7 = bpf2a64[BPF_REG_7];
 303        const u8 r8 = bpf2a64[BPF_REG_8];
 304        const u8 r9 = bpf2a64[BPF_REG_9];
 305        const u8 fp = bpf2a64[BPF_REG_FP];
 306
 307        /* We're done with BPF stack */
 308        emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
 309
 310        /* Restore fp (x25) and x26 */
 311        emit(A64_POP(fp, A64_R(26), A64_SP), ctx);
 312
 313        /* Restore callee-saved registers */
 314        emit(A64_POP(r8, r9, A64_SP), ctx);
 315        emit(A64_POP(r6, r7, A64_SP), ctx);
 316
 317        /* Restore FP/LR registers */
 318        emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
 319
 320        /* Set return value */
 321        emit(A64_MOV(1, A64_R(0), r0), ctx);
 322
 323        emit(A64_RET(A64_LR), ctx);
 324}
 325
 326/* JITs an eBPF instruction.
 327 * Returns:
 328 * 0  - successfully JITed an 8-byte eBPF instruction.
 329 * >0 - successfully JITed a 16-byte eBPF instruction.
 330 * <0 - failed to JIT.
 331 */
 332static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
 333{
 334        const u8 code = insn->code;
 335        const u8 dst = bpf2a64[insn->dst_reg];
 336        const u8 src = bpf2a64[insn->src_reg];
 337        const u8 tmp = bpf2a64[TMP_REG_1];
 338        const u8 tmp2 = bpf2a64[TMP_REG_2];
 339        const u8 tmp3 = bpf2a64[TMP_REG_3];
 340        const s16 off = insn->off;
 341        const s32 imm = insn->imm;
 342        const int i = insn - ctx->prog->insnsi;
 343        const bool is64 = BPF_CLASS(code) == BPF_ALU64;
 344        const bool isdw = BPF_SIZE(code) == BPF_DW;
 345        u8 jmp_cond;
 346        s32 jmp_offset;
 347
 348#define check_imm(bits, imm) do {                               \
 349        if ((((imm) > 0) && ((imm) >> (bits))) ||               \
 350            (((imm) < 0) && (~(imm) >> (bits)))) {              \
 351                pr_info("[%2d] imm=%d(0x%x) out of range\n",    \
 352                        i, imm, imm);                           \
 353                return -EINVAL;                                 \
 354        }                                                       \
 355} while (0)
 356#define check_imm19(imm) check_imm(19, imm)
 357#define check_imm26(imm) check_imm(26, imm)
 358
 359        switch (code) {
 360        /* dst = src */
 361        case BPF_ALU | BPF_MOV | BPF_X:
 362        case BPF_ALU64 | BPF_MOV | BPF_X:
 363                emit(A64_MOV(is64, dst, src), ctx);
 364                break;
 365        /* dst = dst OP src */
 366        case BPF_ALU | BPF_ADD | BPF_X:
 367        case BPF_ALU64 | BPF_ADD | BPF_X:
 368                emit(A64_ADD(is64, dst, dst, src), ctx);
 369                break;
 370        case BPF_ALU | BPF_SUB | BPF_X:
 371        case BPF_ALU64 | BPF_SUB | BPF_X:
 372                emit(A64_SUB(is64, dst, dst, src), ctx);
 373                break;
 374        case BPF_ALU | BPF_AND | BPF_X:
 375        case BPF_ALU64 | BPF_AND | BPF_X:
 376                emit(A64_AND(is64, dst, dst, src), ctx);
 377                break;
 378        case BPF_ALU | BPF_OR | BPF_X:
 379        case BPF_ALU64 | BPF_OR | BPF_X:
 380                emit(A64_ORR(is64, dst, dst, src), ctx);
 381                break;
 382        case BPF_ALU | BPF_XOR | BPF_X:
 383        case BPF_ALU64 | BPF_XOR | BPF_X:
 384                emit(A64_EOR(is64, dst, dst, src), ctx);
 385                break;
 386        case BPF_ALU | BPF_MUL | BPF_X:
 387        case BPF_ALU64 | BPF_MUL | BPF_X:
 388                emit(A64_MUL(is64, dst, dst, src), ctx);
 389                break;
 390        case BPF_ALU | BPF_DIV | BPF_X:
 391        case BPF_ALU64 | BPF_DIV | BPF_X:
 392        case BPF_ALU | BPF_MOD | BPF_X:
 393        case BPF_ALU64 | BPF_MOD | BPF_X:
 394                switch (BPF_OP(code)) {
 395                case BPF_DIV:
 396                        emit(A64_UDIV(is64, dst, dst, src), ctx);
 397                        break;
 398                case BPF_MOD:
 399                        emit(A64_UDIV(is64, tmp, dst, src), ctx);
 400                        emit(A64_MUL(is64, tmp, tmp, src), ctx);
 401                        emit(A64_SUB(is64, dst, dst, tmp), ctx);
 402                        break;
 403                }
 404                break;
 405        case BPF_ALU | BPF_LSH | BPF_X:
 406        case BPF_ALU64 | BPF_LSH | BPF_X:
 407                emit(A64_LSLV(is64, dst, dst, src), ctx);
 408                break;
 409        case BPF_ALU | BPF_RSH | BPF_X:
 410        case BPF_ALU64 | BPF_RSH | BPF_X:
 411                emit(A64_LSRV(is64, dst, dst, src), ctx);
 412                break;
 413        case BPF_ALU | BPF_ARSH | BPF_X:
 414        case BPF_ALU64 | BPF_ARSH | BPF_X:
 415                emit(A64_ASRV(is64, dst, dst, src), ctx);
 416                break;
 417        /* dst = -dst */
 418        case BPF_ALU | BPF_NEG:
 419        case BPF_ALU64 | BPF_NEG:
 420                emit(A64_NEG(is64, dst, dst), ctx);
 421                break;
 422        /* dst = BSWAP##imm(dst) */
 423        case BPF_ALU | BPF_END | BPF_FROM_LE:
 424        case BPF_ALU | BPF_END | BPF_FROM_BE:
 425#ifdef CONFIG_CPU_BIG_ENDIAN
 426                if (BPF_SRC(code) == BPF_FROM_BE)
 427                        goto emit_bswap_uxt;
 428#else /* !CONFIG_CPU_BIG_ENDIAN */
 429                if (BPF_SRC(code) == BPF_FROM_LE)
 430                        goto emit_bswap_uxt;
 431#endif
 432                switch (imm) {
 433                case 16:
 434                        emit(A64_REV16(is64, dst, dst), ctx);
 435                        /* zero-extend 16 bits into 64 bits */
 436                        emit(A64_UXTH(is64, dst, dst), ctx);
 437                        break;
 438                case 32:
 439                        emit(A64_REV32(is64, dst, dst), ctx);
 440                        /* upper 32 bits already cleared */
 441                        break;
 442                case 64:
 443                        emit(A64_REV64(dst, dst), ctx);
 444                        break;
 445                }
 446                break;
 447emit_bswap_uxt:
 448                switch (imm) {
 449                case 16:
 450                        /* zero-extend 16 bits into 64 bits */
 451                        emit(A64_UXTH(is64, dst, dst), ctx);
 452                        break;
 453                case 32:
 454                        /* zero-extend 32 bits into 64 bits */
 455                        emit(A64_UXTW(is64, dst, dst), ctx);
 456                        break;
 457                case 64:
 458                        /* nop */
 459                        break;
 460                }
 461                break;
 462        /* dst = imm */
 463        case BPF_ALU | BPF_MOV | BPF_K:
 464        case BPF_ALU64 | BPF_MOV | BPF_K:
 465                emit_a64_mov_i(is64, dst, imm, ctx);
 466                break;
 467        /* dst = dst OP imm */
 468        case BPF_ALU | BPF_ADD | BPF_K:
 469        case BPF_ALU64 | BPF_ADD | BPF_K:
 470                emit_a64_mov_i(is64, tmp, imm, ctx);
 471                emit(A64_ADD(is64, dst, dst, tmp), ctx);
 472                break;
 473        case BPF_ALU | BPF_SUB | BPF_K:
 474        case BPF_ALU64 | BPF_SUB | BPF_K:
 475                emit_a64_mov_i(is64, tmp, imm, ctx);
 476                emit(A64_SUB(is64, dst, dst, tmp), ctx);
 477                break;
 478        case BPF_ALU | BPF_AND | BPF_K:
 479        case BPF_ALU64 | BPF_AND | BPF_K:
 480                emit_a64_mov_i(is64, tmp, imm, ctx);
 481                emit(A64_AND(is64, dst, dst, tmp), ctx);
 482                break;
 483        case BPF_ALU | BPF_OR | BPF_K:
 484        case BPF_ALU64 | BPF_OR | BPF_K:
 485                emit_a64_mov_i(is64, tmp, imm, ctx);
 486                emit(A64_ORR(is64, dst, dst, tmp), ctx);
 487                break;
 488        case BPF_ALU | BPF_XOR | BPF_K:
 489        case BPF_ALU64 | BPF_XOR | BPF_K:
 490                emit_a64_mov_i(is64, tmp, imm, ctx);
 491                emit(A64_EOR(is64, dst, dst, tmp), ctx);
 492                break;
 493        case BPF_ALU | BPF_MUL | BPF_K:
 494        case BPF_ALU64 | BPF_MUL | BPF_K:
 495                emit_a64_mov_i(is64, tmp, imm, ctx);
 496                emit(A64_MUL(is64, dst, dst, tmp), ctx);
 497                break;
 498        case BPF_ALU | BPF_DIV | BPF_K:
 499        case BPF_ALU64 | BPF_DIV | BPF_K:
 500                emit_a64_mov_i(is64, tmp, imm, ctx);
 501                emit(A64_UDIV(is64, dst, dst, tmp), ctx);
 502                break;
 503        case BPF_ALU | BPF_MOD | BPF_K:
 504        case BPF_ALU64 | BPF_MOD | BPF_K:
 505                emit_a64_mov_i(is64, tmp2, imm, ctx);
 506                emit(A64_UDIV(is64, tmp, dst, tmp2), ctx);
 507                emit(A64_MUL(is64, tmp, tmp, tmp2), ctx);
 508                emit(A64_SUB(is64, dst, dst, tmp), ctx);
 509                break;
 510        case BPF_ALU | BPF_LSH | BPF_K:
 511        case BPF_ALU64 | BPF_LSH | BPF_K:
 512                emit(A64_LSL(is64, dst, dst, imm), ctx);
 513                break;
 514        case BPF_ALU | BPF_RSH | BPF_K:
 515        case BPF_ALU64 | BPF_RSH | BPF_K:
 516                emit(A64_LSR(is64, dst, dst, imm), ctx);
 517                break;
 518        case BPF_ALU | BPF_ARSH | BPF_K:
 519        case BPF_ALU64 | BPF_ARSH | BPF_K:
 520                emit(A64_ASR(is64, dst, dst, imm), ctx);
 521                break;
 522
 523        /* JUMP off */
 524        case BPF_JMP | BPF_JA:
 525                jmp_offset = bpf2a64_offset(i + off, i, ctx);
 526                check_imm26(jmp_offset);
 527                emit(A64_B(jmp_offset), ctx);
 528                break;
 529        /* IF (dst COND src) JUMP off */
 530        case BPF_JMP | BPF_JEQ | BPF_X:
 531        case BPF_JMP | BPF_JGT | BPF_X:
 532        case BPF_JMP | BPF_JLT | BPF_X:
 533        case BPF_JMP | BPF_JGE | BPF_X:
 534        case BPF_JMP | BPF_JLE | BPF_X:
 535        case BPF_JMP | BPF_JNE | BPF_X:
 536        case BPF_JMP | BPF_JSGT | BPF_X:
 537        case BPF_JMP | BPF_JSLT | BPF_X:
 538        case BPF_JMP | BPF_JSGE | BPF_X:
 539        case BPF_JMP | BPF_JSLE | BPF_X:
 540                emit(A64_CMP(1, dst, src), ctx);
 541emit_cond_jmp:
 542                jmp_offset = bpf2a64_offset(i + off, i, ctx);
 543                check_imm19(jmp_offset);
 544                switch (BPF_OP(code)) {
 545                case BPF_JEQ:
 546                        jmp_cond = A64_COND_EQ;
 547                        break;
 548                case BPF_JGT:
 549                        jmp_cond = A64_COND_HI;
 550                        break;
 551                case BPF_JLT:
 552                        jmp_cond = A64_COND_CC;
 553                        break;
 554                case BPF_JGE:
 555                        jmp_cond = A64_COND_CS;
 556                        break;
 557                case BPF_JLE:
 558                        jmp_cond = A64_COND_LS;
 559                        break;
 560                case BPF_JSET:
 561                case BPF_JNE:
 562                        jmp_cond = A64_COND_NE;
 563                        break;
 564                case BPF_JSGT:
 565                        jmp_cond = A64_COND_GT;
 566                        break;
 567                case BPF_JSLT:
 568                        jmp_cond = A64_COND_LT;
 569                        break;
 570                case BPF_JSGE:
 571                        jmp_cond = A64_COND_GE;
 572                        break;
 573                case BPF_JSLE:
 574                        jmp_cond = A64_COND_LE;
 575                        break;
 576                default:
 577                        return -EFAULT;
 578                }
 579                emit(A64_B_(jmp_cond, jmp_offset), ctx);
 580                break;
 581        case BPF_JMP | BPF_JSET | BPF_X:
 582                emit(A64_TST(1, dst, src), ctx);
 583                goto emit_cond_jmp;
 584        /* IF (dst COND imm) JUMP off */
 585        case BPF_JMP | BPF_JEQ | BPF_K:
 586        case BPF_JMP | BPF_JGT | BPF_K:
 587        case BPF_JMP | BPF_JLT | BPF_K:
 588        case BPF_JMP | BPF_JGE | BPF_K:
 589        case BPF_JMP | BPF_JLE | BPF_K:
 590        case BPF_JMP | BPF_JNE | BPF_K:
 591        case BPF_JMP | BPF_JSGT | BPF_K:
 592        case BPF_JMP | BPF_JSLT | BPF_K:
 593        case BPF_JMP | BPF_JSGE | BPF_K:
 594        case BPF_JMP | BPF_JSLE | BPF_K:
 595                emit_a64_mov_i(1, tmp, imm, ctx);
 596                emit(A64_CMP(1, dst, tmp), ctx);
 597                goto emit_cond_jmp;
 598        case BPF_JMP | BPF_JSET | BPF_K:
 599                emit_a64_mov_i(1, tmp, imm, ctx);
 600                emit(A64_TST(1, dst, tmp), ctx);
 601                goto emit_cond_jmp;
 602        /* function call */
 603        case BPF_JMP | BPF_CALL:
 604        {
 605                const u8 r0 = bpf2a64[BPF_REG_0];
 606                const u64 func = (u64)__bpf_call_base + imm;
 607
 608                if (ctx->prog->is_func)
 609                        emit_addr_mov_i64(tmp, func, ctx);
 610                else
 611                        emit_a64_mov_i64(tmp, func, ctx);
 612                emit(A64_BLR(tmp), ctx);
 613                emit(A64_MOV(1, r0, A64_R(0)), ctx);
 614                break;
 615        }
 616        /* tail call */
 617        case BPF_JMP | BPF_TAIL_CALL:
 618                if (emit_bpf_tail_call(ctx))
 619                        return -EFAULT;
 620                break;
 621        /* function return */
 622        case BPF_JMP | BPF_EXIT:
 623                /* Optimization: when the last instruction is EXIT,
 624                 * simply fall through to the epilogue. */
 625                if (i == ctx->prog->len - 1)
 626                        break;
 627                jmp_offset = epilogue_offset(ctx);
 628                check_imm26(jmp_offset);
 629                emit(A64_B(jmp_offset), ctx);
 630                break;
 631
 632        /* dst = imm64 */
 633        case BPF_LD | BPF_IMM | BPF_DW:
 634        {
 635                const struct bpf_insn insn1 = insn[1];
 636                u64 imm64;
 637
 638                imm64 = (u64)insn1.imm << 32 | (u32)imm;
 639                emit_a64_mov_i64(dst, imm64, ctx);
 640
 641                return 1;
 642        }
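        /*
         * BPF_LD | BPF_IMM | BPF_DW above is a 16-byte instruction: the
         * second 8-byte slot only carries the upper 32 bits of the
         * immediate, and returning 1 makes build_body() step over it.
         */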
 643
 644        /* LDX: dst = *(size *)(src + off) */
 645        case BPF_LDX | BPF_MEM | BPF_W:
 646        case BPF_LDX | BPF_MEM | BPF_H:
 647        case BPF_LDX | BPF_MEM | BPF_B:
 648        case BPF_LDX | BPF_MEM | BPF_DW:
 649                emit_a64_mov_i(1, tmp, off, ctx);
 650                switch (BPF_SIZE(code)) {
 651                case BPF_W:
 652                        emit(A64_LDR32(dst, src, tmp), ctx);
 653                        break;
 654                case BPF_H:
 655                        emit(A64_LDRH(dst, src, tmp), ctx);
 656                        break;
 657                case BPF_B:
 658                        emit(A64_LDRB(dst, src, tmp), ctx);
 659                        break;
 660                case BPF_DW:
 661                        emit(A64_LDR64(dst, src, tmp), ctx);
 662                        break;
 663                }
 664                break;
 665
 666        /* ST: *(size *)(dst + off) = imm */
 667        case BPF_ST | BPF_MEM | BPF_W:
 668        case BPF_ST | BPF_MEM | BPF_H:
 669        case BPF_ST | BPF_MEM | BPF_B:
 670        case BPF_ST | BPF_MEM | BPF_DW:
 671                /* Load imm to a register then store it */
 672                emit_a64_mov_i(1, tmp2, off, ctx);
 673                emit_a64_mov_i(1, tmp, imm, ctx);
 674                switch (BPF_SIZE(code)) {
 675                case BPF_W:
 676                        emit(A64_STR32(tmp, dst, tmp2), ctx);
 677                        break;
 678                case BPF_H:
 679                        emit(A64_STRH(tmp, dst, tmp2), ctx);
 680                        break;
 681                case BPF_B:
 682                        emit(A64_STRB(tmp, dst, tmp2), ctx);
 683                        break;
 684                case BPF_DW:
 685                        emit(A64_STR64(tmp, dst, tmp2), ctx);
 686                        break;
 687                }
 688                break;
 689
 690        /* STX: *(size *)(dst + off) = src */
 691        case BPF_STX | BPF_MEM | BPF_W:
 692        case BPF_STX | BPF_MEM | BPF_H:
 693        case BPF_STX | BPF_MEM | BPF_B:
 694        case BPF_STX | BPF_MEM | BPF_DW:
 695                emit_a64_mov_i(1, tmp, off, ctx);
 696                switch (BPF_SIZE(code)) {
 697                case BPF_W:
 698                        emit(A64_STR32(src, dst, tmp), ctx);
 699                        break;
 700                case BPF_H:
 701                        emit(A64_STRH(src, dst, tmp), ctx);
 702                        break;
 703                case BPF_B:
 704                        emit(A64_STRB(src, dst, tmp), ctx);
 705                        break;
 706                case BPF_DW:
 707                        emit(A64_STR64(src, dst, tmp), ctx);
 708                        break;
 709                }
 710                break;
 711        /* STX XADD: lock *(u32 *)(dst + off) += src */
 712        case BPF_STX | BPF_XADD | BPF_W:
 713        /* STX XADD: lock *(u64 *)(dst + off) += src */
 714        case BPF_STX | BPF_XADD | BPF_DW:
 715                emit_a64_mov_i(1, tmp, off, ctx);
 716                emit(A64_ADD(1, tmp, tmp, dst), ctx);
 717                emit(A64_PRFM(tmp, PST, L1, STRM), ctx);
 718                emit(A64_LDXR(isdw, tmp2, tmp), ctx);
 719                emit(A64_ADD(isdw, tmp2, tmp2, src), ctx);
 720                emit(A64_STXR(isdw, tmp2, tmp, tmp3), ctx);
 721                jmp_offset = -3;
 722                check_imm19(jmp_offset);
 723                emit(A64_CBNZ(0, tmp3, jmp_offset), ctx);
 724                break;
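        /*
         * The XADD sequence above is a standard LL/SC retry loop: LDXR
         * loads the old value exclusively, ADD computes the new one and
         * STXR attempts the store, writing its status into tmp3.  The
         * CBNZ with jmp_offset = -3 branches back three instructions,
         * i.e. to the LDXR, whenever the exclusive store fails.
         */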
 725
 726        /* R0 = ntohx(*(size *)(((struct sk_buff *)R6)->data + imm)) */
 727        case BPF_LD | BPF_ABS | BPF_W:
 728        case BPF_LD | BPF_ABS | BPF_H:
 729        case BPF_LD | BPF_ABS | BPF_B:
 730        /* R0 = ntohx(*(size *)(((struct sk_buff *)R6)->data + src + imm)) */
 731        case BPF_LD | BPF_IND | BPF_W:
 732        case BPF_LD | BPF_IND | BPF_H:
 733        case BPF_LD | BPF_IND | BPF_B:
 734        {
 735                const u8 r0 = bpf2a64[BPF_REG_0]; /* r0 = return value */
 736                const u8 r6 = bpf2a64[BPF_REG_6]; /* r6 = pointer to sk_buff */
 737                const u8 fp = bpf2a64[BPF_REG_FP];
 738                const u8 r1 = bpf2a64[BPF_REG_1]; /* r1: struct sk_buff *skb */
 739                const u8 r2 = bpf2a64[BPF_REG_2]; /* r2: int k */
 740                const u8 r3 = bpf2a64[BPF_REG_3]; /* r3: unsigned int size */
 741                const u8 r4 = bpf2a64[BPF_REG_4]; /* r4: void *buffer */
 742                const u8 r5 = bpf2a64[BPF_REG_5]; /* r5: void *(*func)(...) */
 743                int size;
 744
 745                emit(A64_MOV(1, r1, r6), ctx);
 746                emit_a64_mov_i(0, r2, imm, ctx);
 747                if (BPF_MODE(code) == BPF_IND)
 748                        emit(A64_ADD(0, r2, r2, src), ctx);
 749                switch (BPF_SIZE(code)) {
 750                case BPF_W:
 751                        size = 4;
 752                        break;
 753                case BPF_H:
 754                        size = 2;
 755                        break;
 756                case BPF_B:
 757                        size = 1;
 758                        break;
 759                default:
 760                        return -EINVAL;
 761                }
 762                emit_a64_mov_i64(r3, size, ctx);
 763                emit(A64_SUB_I(1, r4, fp, ctx->stack_size), ctx);
 764                emit_a64_mov_i64(r5, (unsigned long)bpf_load_pointer, ctx);
 765                emit(A64_BLR(r5), ctx);
 766                emit(A64_MOV(1, r0, A64_R(0)), ctx);
 767
 768                jmp_offset = epilogue_offset(ctx);
 769                check_imm19(jmp_offset);
 770                emit(A64_CBZ(1, r0, jmp_offset), ctx);
 771                emit(A64_MOV(1, r5, r0), ctx);
 772                switch (BPF_SIZE(code)) {
 773                case BPF_W:
 774                        emit(A64_LDR32(r0, r5, A64_ZR), ctx);
 775#ifndef CONFIG_CPU_BIG_ENDIAN
 776                        emit(A64_REV32(0, r0, r0), ctx);
 777#endif
 778                        break;
 779                case BPF_H:
 780                        emit(A64_LDRH(r0, r5, A64_ZR), ctx);
 781#ifndef CONFIG_CPU_BIG_ENDIAN
 782                        emit(A64_REV16(0, r0, r0), ctx);
 783#endif
 784                        break;
 785                case BPF_B:
 786                        emit(A64_LDRB(r0, r5, A64_ZR), ctx);
 787                        break;
 788                }
 789                break;
 790        }
 791        default:
 792                pr_err_once("unknown opcode %02x\n", code);
 793                return -EINVAL;
 794        }
 795
 796        return 0;
 797}
 798
 799static int build_body(struct jit_ctx *ctx)
 800{
 801        const struct bpf_prog *prog = ctx->prog;
 802        int i;
 803
 804        for (i = 0; i < prog->len; i++) {
 805                const struct bpf_insn *insn = &prog->insnsi[i];
 806                int ret;
 807
 808                ret = build_insn(insn, ctx);
 809                if (ret > 0) {
 810                        i++;
 811                        if (ctx->image == NULL)
 812                                ctx->offset[i] = ctx->idx;
 813                        continue;
 814                }
 815                if (ctx->image == NULL)
 816                        ctx->offset[i] = ctx->idx;
 817                if (ret)
 818                        return ret;
 819        }
 820
 821        return 0;
 822}
 823
 824static int validate_code(struct jit_ctx *ctx)
 825{
 826        int i;
 827
 828        for (i = 0; i < ctx->idx; i++) {
 829                u32 a64_insn = le32_to_cpu(ctx->image[i]);
 830
 831                if (a64_insn == AARCH64_BREAK_FAULT)
 832                        return -1;
 833        }
 834
 835        return 0;
 836}
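/*
 * The image is pre-filled with AARCH64_BREAK_FAULT by jit_fill_hole(), so
 * any break instruction still present after code generation means either
 * a slot that was never written or an instruction the encoding helpers
 * rejected; in both cases the JIT is abandoned.
 */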
 837
 838static inline void bpf_flush_icache(void *start, void *end)
 839{
 840        flush_icache_range((unsigned long)start, (unsigned long)end);
 841}
 842
 843struct arm64_jit_data {
 844        struct bpf_binary_header *header;
 845        u8 *image;
 846        struct jit_ctx ctx;
 847};
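/*
 * For programs with BPF-to-BPF calls (prog->is_func), the context, image
 * and header from the first invocation are stashed here via
 * prog->aux->jit_data, so that the extra pass can regenerate the code
 * with the final callee addresses into the same image.
 */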
 848
 849struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 850{
 851        struct bpf_prog *tmp, *orig_prog = prog;
 852        struct bpf_binary_header *header;
 853        struct arm64_jit_data *jit_data;
 854        bool tmp_blinded = false;
 855        bool extra_pass = false;
 856        struct jit_ctx ctx;
 857        int image_size;
 858        u8 *image_ptr;
 859
 860        if (!prog->jit_requested)
 861                return orig_prog;
 862
 863        tmp = bpf_jit_blind_constants(prog);
 864        /* If blinding was requested and we failed during blinding,
 865         * we must fall back to the interpreter.
 866         */
 867        if (IS_ERR(tmp))
 868                return orig_prog;
 869        if (tmp != prog) {
 870                tmp_blinded = true;
 871                prog = tmp;
 872        }
 873
 874        jit_data = prog->aux->jit_data;
 875        if (!jit_data) {
 876                jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
 877                if (!jit_data) {
 878                        prog = orig_prog;
 879                        goto out;
 880                }
 881                prog->aux->jit_data = jit_data;
 882        }
 883        if (jit_data->ctx.offset) {
 884                ctx = jit_data->ctx;
 885                image_ptr = jit_data->image;
 886                header = jit_data->header;
 887                extra_pass = true;
 888                image_size = sizeof(u32) * ctx.idx;
 889                goto skip_init_ctx;
 890        }
 891        memset(&ctx, 0, sizeof(ctx));
 892        ctx.prog = prog;
 893
 894        ctx.offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
 895        if (ctx.offset == NULL) {
 896                prog = orig_prog;
 897                goto out_off;
 898        }
 899
 900        /* 1. Initial fake pass to compute ctx->idx and fill in
 901         *    ctx->offset[].
 902         */
 903        if (build_body(&ctx)) {
 904                prog = orig_prog;
 905                goto out_off;
 906        }
 907
 908        if (build_prologue(&ctx)) {
 909                prog = orig_prog;
 910                goto out_off;
 911        }
 912
 913        ctx.epilogue_offset = ctx.idx;
 914        build_epilogue(&ctx);
 915
 916        /* Now we know the actual image size. */
 917        image_size = sizeof(u32) * ctx.idx;
 918        header = bpf_jit_binary_alloc(image_size, &image_ptr,
 919                                      sizeof(u32), jit_fill_hole);
 920        if (header == NULL) {
 921                prog = orig_prog;
 922                goto out_off;
 923        }
 924
 925        /* 2. Now, the actual pass. */
 926
 927        ctx.image = (__le32 *)image_ptr;
 928skip_init_ctx:
 929        ctx.idx = 0;
 930
 931        build_prologue(&ctx);
 932
 933        if (build_body(&ctx)) {
 934                bpf_jit_binary_free(header);
 935                prog = orig_prog;
 936                goto out_off;
 937        }
 938
 939        build_epilogue(&ctx);
 940
 941        /* 3. Extra pass to validate JITed code. */
 942        if (validate_code(&ctx)) {
 943                bpf_jit_binary_free(header);
 944                prog = orig_prog;
 945                goto out_off;
 946        }
 947
 948        /* And we're done. */
 949        if (bpf_jit_enable > 1)
 950                bpf_jit_dump(prog->len, image_size, 2, ctx.image);
 951
 952        bpf_flush_icache(header, ctx.image + ctx.idx);
 953
 954        if (!prog->is_func || extra_pass) {
 955                if (extra_pass && ctx.idx != jit_data->ctx.idx) {
 956                        pr_err_once("multi-func JIT bug %d != %d\n",
 957                                    ctx.idx, jit_data->ctx.idx);
 958                        bpf_jit_binary_free(header);
 959                        prog->bpf_func = NULL;
 960                        prog->jited = 0;
 961                        goto out_off;
 962                }
 963                bpf_jit_binary_lock_ro(header);
 964        } else {
 965                jit_data->ctx = ctx;
 966                jit_data->image = image_ptr;
 967                jit_data->header = header;
 968        }
 969        prog->bpf_func = (void *)ctx.image;
 970        prog->jited = 1;
 971        prog->jited_len = image_size;
 972
 973        if (!prog->is_func || extra_pass) {
 974out_off:
 975                kfree(ctx.offset);
 976                kfree(jit_data);
 977                prog->aux->jit_data = NULL;
 978        }
 979out:
 980        if (tmp_blinded)
 981                bpf_jit_prog_release_other(prog, prog == orig_prog ?
 982                                           tmp : orig_prog);
 983        return prog;
 984}
 985