/* bpf_jit_comp.c : BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; version 2
 * of the License.
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>

#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

/*
 * assembly code in arch/x86/net/bpf_jit.S
 */
extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
extern u8 sk_load_byte_positive_offset[];
extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
extern u8 sk_load_byte_negative_offset[];

static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
        if (len == 1)
                *ptr = bytes;
        else if (len == 2)
                *(u16 *)ptr = bytes;
        else {
                *(u32 *)ptr = bytes;
                barrier();
        }
        return ptr + len;
}

#define EMIT(bytes, len) \
        do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)               EMIT(b1, 1)
#define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
#define EMIT1_off32(b1, off) \
        do {EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
        do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
        do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
        do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
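
/* The multi-byte EMITs pack their arguments little-endian into the 'bytes'
 * word, so the first argument is the first opcode byte stored: e.g.
 * EMIT2(0x31, 0xC0) writes 0x31 followed by 0xC0 (xor eax,eax). emit_code()
 * relies on x86 being little-endian when it stores the packed word with a
 * single write.
 */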

static bool is_imm8(int value)
{
        return value <= 127 && value >= -128;
}

static bool is_simm32(s64 value)
{
        return value == (s64)(s32)value;
}

static bool is_uimm32(u64 value)
{
        return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC) \
        do {if (DST != SRC) \
                EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
        } while (0)

static int bpf_size_to_x86_bytes(int bpf_size)
{
        if (bpf_size == BPF_W)
                return 4;
        else if (bpf_size == BPF_H)
                return 2;
        else if (bpf_size == BPF_B)
                return 1;
        else if (bpf_size == BPF_DW)
                return 4; /* imm32 */
        else
                return 0;
}

/* list of x86 conditional jump opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F

#define CHOOSE_LOAD_FUNC(K, func) \
        ((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)

/* pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)

/* The following table maps BPF registers to x64 registers.
 *
 * x64 register r12 is unused, since if used as base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 *  r9 caches skb->len - skb->data_len
 * r10 caches skb->data, and used for blinding (if enabled)
 */
static const int reg2hex[] = {
        [BPF_REG_0] = 0,  /* rax */
        [BPF_REG_1] = 7,  /* rdi */
        [BPF_REG_2] = 6,  /* rsi */
        [BPF_REG_3] = 2,  /* rdx */
        [BPF_REG_4] = 1,  /* rcx */
        [BPF_REG_5] = 0,  /* r8 */
        [BPF_REG_6] = 3,  /* rbx callee saved */
        [BPF_REG_7] = 5,  /* r13 callee saved */
        [BPF_REG_8] = 6,  /* r14 callee saved */
        [BPF_REG_9] = 7,  /* r15 callee saved */
        [BPF_REG_FP] = 5, /* rbp readonly */
        [BPF_REG_AX] = 2, /* r10 temp register */
        [AUX_REG] = 3,    /* r11 temp register */
};

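/* reg2hex[] gives only the low 3 bits of the hardware register number; for
 * r8..r15 the missing fourth bit is supplied by a REX prefix, see is_ereg()
 * and add_1mod()/add_2mod() below.
 */
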
/* is_ereg() == true if BPF register 'reg' maps to x64 r8..r15
 * which need extra byte of encoding.
 * rax,rcx,...,rbp have simpler encoding
 */
static bool is_ereg(u32 reg)
{
        return (1 << reg) & (BIT(BPF_REG_5) |
                             BIT(AUX_REG) |
                             BIT(BPF_REG_7) |
                             BIT(BPF_REG_8) |
                             BIT(BPF_REG_9) |
                             BIT(BPF_REG_AX));
}

static bool is_axreg(u32 reg)
{
        return reg == BPF_REG_0;
}

/* add modifiers if 'reg' maps to x64 registers r8..r15 */
static u8 add_1mod(u8 byte, u32 reg)
{
        if (is_ereg(reg))
                byte |= 1;
        return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
        if (is_ereg(r1))
                byte |= 1;
        if (is_ereg(r2))
                byte |= 4;
        return byte;
}

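/* The base REX byte passed to the helpers above is 0x48 (REX.W, 64-bit
 * operand size) or 0x40 (plain REX prefix). add_1mod()/add_2mod() OR in
 * REX.B (bit 0) for the register encoded in ModRM.rm and REX.R (bit 2) for
 * the register encoded in ModRM.reg whenever the mapped x64 register is
 * r8..r15.
 */
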
/* encode 'dst_reg' register into x64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
        return byte + reg2hex[dst_reg];
}

/* encode 'dst_reg' and 'src_reg' registers into x64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
        return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}

static void jit_fill_hole(void *area, unsigned int size)
{
        /* fill whole space with int3 instructions */
        memset(area, 0xcc, size);
}

struct jit_context {
        int cleanup_addr; /* epilogue code offset */
        bool seen_ld_abs;
        bool seen_ax_reg;
};

/* maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE       128
#define BPF_INSN_SAFETY         64

#define AUX_STACK_SPACE \
        (32 /* space for rbx, r13, r14, r15 */ + \
         8 /* space for skb_copy_bits() buffer */)

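/* PROLOGUE_SIZE is the byte length of the prologue emitted below (verified by
 * the BUILD_BUG_ON() in emit_prologue()); emit_bpf_tail_call() adds it to
 * prog->bpf_func so that a tail call jumps past the prologue of the next
 * program.
 */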
#define PROLOGUE_SIZE 37

/* emit x64 prologue code for BPF program and check its size.
 * bpf_tail_call helper will skip it while jumping into another program
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
        u8 *prog = *pprog;
        int cnt = 0;

        EMIT1(0x55); /* push rbp */
        EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */

        /* sub rsp, rounded_stack_depth + AUX_STACK_SPACE */
        EMIT3_off32(0x48, 0x81, 0xEC,
                    round_up(stack_depth, 8) + AUX_STACK_SPACE);

        /* sub rbp, AUX_STACK_SPACE */
        EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE);

        /* all classic BPF filters use R6(rbx) save it */

        /* mov qword ptr [rbp+0],rbx */
        EMIT4(0x48, 0x89, 0x5D, 0);

        /* bpf_convert_filter() maps classic BPF register X to R7 and uses R8
         * as temporary, so all tcpdump filters need to spill/fill R7(r13) and
         * R8(r14). R9(r15) spill could be made conditional, but there is only
         * one 'bpf_error' return path out of helper functions inside bpf_jit.S
         * The overhead of extra spill is negligible for any filter other
         * than synthetic ones. Therefore not worth adding complexity.
         */

        /* mov qword ptr [rbp+8],r13 */
        EMIT4(0x4C, 0x89, 0x6D, 8);
        /* mov qword ptr [rbp+16],r14 */
        EMIT4(0x4C, 0x89, 0x75, 16);
        /* mov qword ptr [rbp+24],r15 */
        EMIT4(0x4C, 0x89, 0x7D, 24);

        if (!ebpf_from_cbpf) {
                /* Clear the tail call counter (tail_call_cnt): for eBPF tail
                 * calls we need to reset the counter to 0. It's done in two
                 * instructions, resetting rax register to 0, and moving it
                 * to the counter location.
                 */

                /* xor eax, eax */
                EMIT2(0x31, 0xc0);
                /* mov qword ptr [rbp+32], rax */
                EMIT4(0x48, 0x89, 0x45, 32);

                BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
        }

        *pprog = prog;
}

/* generate the following code:
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
        u8 *prog = *pprog;
        int label1, label2, label3;
        int cnt = 0;

        /* rdi - pointer to ctx
         * rsi - pointer to bpf_array
         * rdx - index in bpf_array
         */

        /* if (index >= array->map.max_entries)
         *   goto out;
         */
        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
              offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
        label1 = cnt;

        /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
         *   goto out;
         */
        EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
        EMIT2(X86_JA, OFFSET2);                   /* ja out */
        label2 = cnt;
        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
        EMIT2_off32(0x89, 0x85, 36);              /* mov dword ptr [rbp + 36], eax */

        /* prog = array->ptrs[index]; */
        EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
                    offsetof(struct bpf_array, ptrs));

        /* if (prog == NULL)
         *   goto out;
         */
        EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
        EMIT2(X86_JE, OFFSET3);                   /* je out */
        label3 = cnt;

        /* goto *(prog->bpf_func + prologue_size); */
        EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
              offsetof(struct bpf_prog, bpf_func));
        EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

        /* now we're ready to jump into next BPF program
         * rdi == ctx (1st arg)
         * rax == prog->bpf_func + prologue_size
         */
        RETPOLINE_RAX_BPF_JIT();

        /* out: */
        BUILD_BUG_ON(cnt - label1 != OFFSET1);
        BUILD_BUG_ON(cnt - label2 != OFFSET2);
        BUILD_BUG_ON(cnt - label3 != OFFSET3);
        *pprog = prog;
}


static void emit_load_skb_data_hlen(u8 **pprog)
{
        u8 *prog = *pprog;
        int cnt = 0;

        /* r9d = skb->len - skb->data_len (headlen)
         * r10 = skb->data
         */
        /* mov %r9d, off32(%rdi) */
        EMIT3_off32(0x44, 0x8b, 0x8f, offsetof(struct sk_buff, len));

        /* sub %r9d, off32(%rdi) */
        EMIT3_off32(0x44, 0x2b, 0x8f, offsetof(struct sk_buff, data_len));

        /* mov %r10, off32(%rdi) */
        EMIT3_off32(0x4c, 0x8b, 0x97, offsetof(struct sk_buff, data));
        *pprog = prog;
}

static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
                           u32 dst_reg, const u32 imm32)
{
        u8 *prog = *pprog;
        u8 b1, b2, b3;
        int cnt = 0;

        /* optimization: if imm32 is positive, use 'mov %eax, imm32'
         * (which zero-extends imm32) to save 2 bytes.
         */
        if (sign_propagate && (s32)imm32 < 0) {
                /* 'mov %rax, imm32' sign extends imm32 */
                b1 = add_1mod(0x48, dst_reg);
                b2 = 0xC7;
                b3 = 0xC0;
                EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
                goto done;
        }

        /* optimization: if imm32 is zero, use 'xor %eax, %eax'
         * to save 3 bytes.
         */
        if (imm32 == 0) {
                if (is_ereg(dst_reg))
                        EMIT1(add_2mod(0x40, dst_reg, dst_reg));
                b2 = 0x31; /* xor */
                b3 = 0xC0;
                EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
                goto done;
        }

        /* mov %eax, imm32 */
        if (is_ereg(dst_reg))
                EMIT1(add_1mod(0x40, dst_reg));
        EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
        *pprog = prog;
}

static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
                           const u32 imm32_hi, const u32 imm32_lo)
{
        u8 *prog = *pprog;
        int cnt = 0;

        if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
                /* For emitting a plain u32 where the sign bit must not be
                 * propagated, LLVM tends to load an imm64 instead of a mov32
                 * directly, so save a couple of bytes by just doing
                 * 'mov %eax, imm32' instead.
                 */
                emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
        } else {
                /* movabsq %rax, imm64 */
                EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
                EMIT(imm32_lo, 4);
                EMIT(imm32_hi, 4);
        }

        *pprog = prog;
}

static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
        u8 *prog = *pprog;
        int cnt = 0;

        if (is64) {
                /* mov dst, src */
                EMIT_mov(dst_reg, src_reg);
        } else {
                /* mov32 dst, src */
                if (is_ereg(dst_reg) || is_ereg(src_reg))
                        EMIT1(add_2mod(0x40, dst_reg, src_reg));
                EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
        }

        *pprog = prog;
}

static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
{
        struct bpf_insn *insn = bpf_prog->insnsi;
        int insn_cnt = bpf_prog->len;
        bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
        bool seen_ax_reg = ctx->seen_ax_reg | (oldproglen == 0);
        bool seen_exit = false;
        u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
        int i, cnt = 0;
        int proglen = 0;
        u8 *prog = temp;

        emit_prologue(&prog, bpf_prog->aux->stack_depth,
                      bpf_prog_was_classic(bpf_prog));

        if (seen_ld_abs)
                emit_load_skb_data_hlen(&prog);

        for (i = 0; i < insn_cnt; i++, insn++) {
                const s32 imm32 = insn->imm;
                u32 dst_reg = insn->dst_reg;
                u32 src_reg = insn->src_reg;
                u8 b2 = 0, b3 = 0;
                s64 jmp_offset;
                u8 jmp_cond;
                bool reload_skb_data;
                int ilen;
                u8 *func;

                if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
                        ctx->seen_ax_reg = seen_ax_reg = true;

                switch (insn->code) {
                        /* ALU */
                case BPF_ALU | BPF_ADD | BPF_X:
                case BPF_ALU | BPF_SUB | BPF_X:
                case BPF_ALU | BPF_AND | BPF_X:
                case BPF_ALU | BPF_OR | BPF_X:
                case BPF_ALU | BPF_XOR | BPF_X:
                case BPF_ALU64 | BPF_ADD | BPF_X:
                case BPF_ALU64 | BPF_SUB | BPF_X:
                case BPF_ALU64 | BPF_AND | BPF_X:
                case BPF_ALU64 | BPF_OR | BPF_X:
                case BPF_ALU64 | BPF_XOR | BPF_X:
                        switch (BPF_OP(insn->code)) {
                        case BPF_ADD: b2 = 0x01; break;
                        case BPF_SUB: b2 = 0x29; break;
                        case BPF_AND: b2 = 0x21; break;
                        case BPF_OR: b2 = 0x09; break;
                        case BPF_XOR: b2 = 0x31; break;
                        }
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
                        EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
                        break;

                case BPF_ALU64 | BPF_MOV | BPF_X:
                case BPF_ALU | BPF_MOV | BPF_X:
                        emit_mov_reg(&prog,
                                     BPF_CLASS(insn->code) == BPF_ALU64,
                                     dst_reg, src_reg);
                        break;

                        /* neg dst */
                case BPF_ALU | BPF_NEG:
                case BPF_ALU64 | BPF_NEG:
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));
                        EMIT2(0xF7, add_1reg(0xD8, dst_reg));
                        break;

                case BPF_ALU | BPF_ADD | BPF_K:
                case BPF_ALU | BPF_SUB | BPF_K:
                case BPF_ALU | BPF_AND | BPF_K:
                case BPF_ALU | BPF_OR | BPF_K:
                case BPF_ALU | BPF_XOR | BPF_K:
                case BPF_ALU64 | BPF_ADD | BPF_K:
                case BPF_ALU64 | BPF_SUB | BPF_K:
                case BPF_ALU64 | BPF_AND | BPF_K:
                case BPF_ALU64 | BPF_OR | BPF_K:
                case BPF_ALU64 | BPF_XOR | BPF_K:
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        /* b3 holds 'normal' opcode, b2 short form only valid
                         * in case dst is eax/rax.
                         */
                        switch (BPF_OP(insn->code)) {
                        case BPF_ADD:
                                b3 = 0xC0;
                                b2 = 0x05;
                                break;
                        case BPF_SUB:
                                b3 = 0xE8;
                                b2 = 0x2D;
                                break;
                        case BPF_AND:
                                b3 = 0xE0;
                                b2 = 0x25;
                                break;
                        case BPF_OR:
                                b3 = 0xC8;
                                b2 = 0x0D;
                                break;
                        case BPF_XOR:
                                b3 = 0xF0;
                                b2 = 0x35;
                                break;
                        }

                        if (is_imm8(imm32))
                                EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
                        else if (is_axreg(dst_reg))
                                EMIT1_off32(b2, imm32);
                        else
                                EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
                        break;

                case BPF_ALU64 | BPF_MOV | BPF_K:
                case BPF_ALU | BPF_MOV | BPF_K:
                        emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
                                       dst_reg, imm32);
                        break;

                case BPF_LD | BPF_IMM | BPF_DW:
                        emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
                        insn++;
                        i++;
                        break;

                        /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
                case BPF_ALU | BPF_MOD | BPF_X:
                case BPF_ALU | BPF_DIV | BPF_X:
                case BPF_ALU | BPF_MOD | BPF_K:
                case BPF_ALU | BPF_DIV | BPF_K:
                case BPF_ALU64 | BPF_MOD | BPF_X:
                case BPF_ALU64 | BPF_DIV | BPF_X:
                case BPF_ALU64 | BPF_MOD | BPF_K:
                case BPF_ALU64 | BPF_DIV | BPF_K:
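                        /* The x86 div instruction implicitly divides rdx:rax
                         * (or edx:eax) by its operand, leaving the quotient in
                         * rax and the remainder in rdx. rax/rdx are therefore
                         * saved around the operation, the divisor is staged in
                         * r11, and the wanted half of the result is copied
                         * back into dst_reg afterwards.
                         */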
                        EMIT1(0x50); /* push rax */
                        EMIT1(0x52); /* push rdx */

                        if (BPF_SRC(insn->code) == BPF_X)
                                /* mov r11, src_reg */
                                EMIT_mov(AUX_REG, src_reg);
                        else
                                /* mov r11, imm32 */
                                EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

                        /* mov rax, dst_reg */
                        EMIT_mov(BPF_REG_0, dst_reg);

                        /* xor edx, edx
                         * equivalent to 'xor rdx, rdx', but one byte less
                         */
                        EMIT2(0x31, 0xd2);

                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                /* div r11 */
                                EMIT3(0x49, 0xF7, 0xF3);
                        else
                                /* div r11d */
                                EMIT3(0x41, 0xF7, 0xF3);

                        if (BPF_OP(insn->code) == BPF_MOD)
                                /* mov r11, rdx */
                                EMIT3(0x49, 0x89, 0xD3);
                        else
                                /* mov r11, rax */
                                EMIT3(0x49, 0x89, 0xC3);

                        EMIT1(0x5A); /* pop rdx */
                        EMIT1(0x58); /* pop rax */

                        /* mov dst_reg, r11 */
                        EMIT_mov(dst_reg, AUX_REG);
                        break;

                case BPF_ALU | BPF_MUL | BPF_K:
                case BPF_ALU | BPF_MUL | BPF_X:
                case BPF_ALU64 | BPF_MUL | BPF_K:
                case BPF_ALU64 | BPF_MUL | BPF_X:
                {
                        bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

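                        /* One-operand mul multiplies rax by its operand and
                         * leaves the low half of the result in rax, clobbering
                         * rdx with the high half. rax/rdx are therefore saved
                         * unless dst_reg already maps to them; dst_reg is
                         * staged in r11 and the other operand in rax before
                         * the mul.
                         */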
                        if (dst_reg != BPF_REG_0)
                                EMIT1(0x50); /* push rax */
                        if (dst_reg != BPF_REG_3)
                                EMIT1(0x52); /* push rdx */

                        /* mov r11, dst_reg */
                        EMIT_mov(AUX_REG, dst_reg);

                        if (BPF_SRC(insn->code) == BPF_X)
                                emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
                        else
                                emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

                        if (is64)
                                EMIT1(add_1mod(0x48, AUX_REG));
                        else if (is_ereg(AUX_REG))
                                EMIT1(add_1mod(0x40, AUX_REG));
                        /* mul(q) r11 */
                        EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

                        if (dst_reg != BPF_REG_3)
                                EMIT1(0x5A); /* pop rdx */
                        if (dst_reg != BPF_REG_0) {
                                /* mov dst_reg, rax */
                                EMIT_mov(dst_reg, BPF_REG_0);
                                EMIT1(0x58); /* pop rax */
                        }
                        break;
                }
                        /* shifts */
                case BPF_ALU | BPF_LSH | BPF_K:
                case BPF_ALU | BPF_RSH | BPF_K:
                case BPF_ALU | BPF_ARSH | BPF_K:
                case BPF_ALU64 | BPF_LSH | BPF_K:
                case BPF_ALU64 | BPF_RSH | BPF_K:
                case BPF_ALU64 | BPF_ARSH | BPF_K:
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        switch (BPF_OP(insn->code)) {
                        case BPF_LSH: b3 = 0xE0; break;
                        case BPF_RSH: b3 = 0xE8; break;
                        case BPF_ARSH: b3 = 0xF8; break;
                        }

                        if (imm32 == 1)
                                EMIT2(0xD1, add_1reg(b3, dst_reg));
                        else
                                EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
                        break;

                case BPF_ALU | BPF_LSH | BPF_X:
                case BPF_ALU | BPF_RSH | BPF_X:
                case BPF_ALU | BPF_ARSH | BPF_X:
                case BPF_ALU64 | BPF_LSH | BPF_X:
                case BPF_ALU64 | BPF_RSH | BPF_X:
                case BPF_ALU64 | BPF_ARSH | BPF_X:

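                        /* Variable-count shifts on x86 take the count only in
                         * %cl, so the source register is moved into rcx below
                         * (unless it already is rcx); a dst_reg that itself
                         * lives in rcx (BPF_REG_4) is shuffled through r11.
                         */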
                        /* check for bad case when dst_reg == rcx */
                        if (dst_reg == BPF_REG_4) {
                                /* mov r11, dst_reg */
                                EMIT_mov(AUX_REG, dst_reg);
                                dst_reg = AUX_REG;
                        }

                        if (src_reg != BPF_REG_4) { /* common case */
                                EMIT1(0x51); /* push rcx */

                                /* mov rcx, src_reg */
                                EMIT_mov(BPF_REG_4, src_reg);
                        }

                        /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        switch (BPF_OP(insn->code)) {
                        case BPF_LSH: b3 = 0xE0; break;
                        case BPF_RSH: b3 = 0xE8; break;
                        case BPF_ARSH: b3 = 0xF8; break;
                        }
                        EMIT2(0xD3, add_1reg(b3, dst_reg));

                        if (src_reg != BPF_REG_4)
                                EMIT1(0x59); /* pop rcx */

                        if (insn->dst_reg == BPF_REG_4)
                                /* mov dst_reg, r11 */
                                EMIT_mov(insn->dst_reg, AUX_REG);
                        break;

                case BPF_ALU | BPF_END | BPF_FROM_BE:
                        switch (imm32) {
                        case 16:
                                /* emit 'ror %ax, 8' to swap lower 2 bytes */
                                EMIT1(0x66);
                                if (is_ereg(dst_reg))
                                        EMIT1(0x41);
                                EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

                                /* emit 'movzwl eax, ax' */
                                if (is_ereg(dst_reg))
                                        EMIT3(0x45, 0x0F, 0xB7);
                                else
                                        EMIT2(0x0F, 0xB7);
                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 32:
                                /* emit 'bswap eax' to swap lower 4 bytes */
                                if (is_ereg(dst_reg))
                                        EMIT2(0x41, 0x0F);
                                else
                                        EMIT1(0x0F);
                                EMIT1(add_1reg(0xC8, dst_reg));
                                break;
                        case 64:
                                /* emit 'bswap rax' to swap 8 bytes */
                                EMIT3(add_1mod(0x48, dst_reg), 0x0F,
                                      add_1reg(0xC8, dst_reg));
                                break;
                        }
                        break;

                case BPF_ALU | BPF_END | BPF_FROM_LE:
                        switch (imm32) {
                        case 16:
                                /* emit 'movzwl eax, ax' to zero extend 16-bit
                                 * into 64 bit
                                 */
                                if (is_ereg(dst_reg))
                                        EMIT3(0x45, 0x0F, 0xB7);
                                else
                                        EMIT2(0x0F, 0xB7);
                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 32:
                                /* emit 'mov eax, eax' to clear upper 32-bits */
                                if (is_ereg(dst_reg))
                                        EMIT1(0x45);
                                EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 64:
                                /* nop */
                                break;
                        }
                        break;

                        /* ST: *(u8*)(dst_reg + off) = imm */
                case BPF_ST | BPF_MEM | BPF_B:
                        if (is_ereg(dst_reg))
                                EMIT2(0x41, 0xC6);
                        else
                                EMIT1(0xC6);
                        goto st;
                case BPF_ST | BPF_MEM | BPF_H:
                        if (is_ereg(dst_reg))
                                EMIT3(0x66, 0x41, 0xC7);
                        else
                                EMIT2(0x66, 0xC7);
                        goto st;
                case BPF_ST | BPF_MEM | BPF_W:
                        if (is_ereg(dst_reg))
                                EMIT2(0x41, 0xC7);
                        else
                                EMIT1(0xC7);
                        goto st;
                case BPF_ST | BPF_MEM | BPF_DW:
                        EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:                     if (is_imm8(insn->off))
                                EMIT2(add_1reg(0x40, dst_reg), insn->off);
                        else
                                EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

                        EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
                        break;

                        /* STX: *(u8*)(dst_reg + off) = src_reg */
                case BPF_STX | BPF_MEM | BPF_B:
                        /* emit 'mov byte ptr [rax + off], al' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg) ||
                            /* have to add extra byte for x86 SIL, DIL regs */
                            src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
                                EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
                        else
                                EMIT1(0x88);
                        goto stx;
                case BPF_STX | BPF_MEM | BPF_H:
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
                        else
                                EMIT2(0x66, 0x89);
                        goto stx;
                case BPF_STX | BPF_MEM | BPF_W:
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
                        else
                                EMIT1(0x89);
                        goto stx;
                case BPF_STX | BPF_MEM | BPF_DW:
                        EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
stx:                    if (is_imm8(insn->off))
                                EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
                        else
                                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
                                            insn->off);
                        break;

                        /* LDX: dst_reg = *(u8*)(src_reg + off) */
                case BPF_LDX | BPF_MEM | BPF_B:
                        /* emit 'movzx rax, byte ptr [rax + off]' */
                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_H:
                        /* emit 'movzx rax, word ptr [rax + off]' */
                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_W:
                        /* emit 'mov eax, dword ptr [rax+0x14]' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
                        else
                                EMIT1(0x8B);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_DW:
                        /* emit 'mov rax, qword ptr [rax+0x14]' */
                        EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
ldx:                    /* if insn->off == 0 we can save one extra byte, but
                         * special case of x86 r13 which always needs an offset
                         * is not worth the hassle
                         */
                        if (is_imm8(insn->off))
                                EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
                        else
                                EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
                                            insn->off);
                        break;

                        /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
                case BPF_STX | BPF_XADD | BPF_W:
                        /* emit 'lock add dword ptr [rax + off], eax' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
                        else
                                EMIT2(0xF0, 0x01);
                        goto xadd;
                case BPF_STX | BPF_XADD | BPF_DW:
                        EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:                   if (is_imm8(insn->off))
                                EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
                        else
                                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
                                            insn->off);
                        break;

                        /* call */
                case BPF_JMP | BPF_CALL:
                        func = (u8 *) __bpf_call_base + imm32;
                        jmp_offset = func - (image + addrs[i]);
                        if (seen_ld_abs) {
                                reload_skb_data = bpf_helper_changes_pkt_data(func);
                                if (reload_skb_data) {
                                        EMIT1(0x57); /* push %rdi */
                                        jmp_offset += 22; /* pop, mov, sub, mov */
                                } else {
                                        EMIT2(0x41, 0x52); /* push %r10 */
                                        EMIT2(0x41, 0x51); /* push %r9 */
                                        /* need to adjust jmp offset, since
                                         * pop %r9, pop %r10 take 4 bytes after call insn
                                         */
                                        jmp_offset += 4;
                                }
                        }
                        if (!imm32 || !is_simm32(jmp_offset)) {
                                pr_err("unsupported bpf func %d addr %p image %p\n",
                                       imm32, func, image);
                                return -EINVAL;
                        }
                        EMIT1_off32(0xE8, jmp_offset);
                        if (seen_ld_abs) {
                                if (reload_skb_data) {
                                        EMIT1(0x5F); /* pop %rdi */
                                        emit_load_skb_data_hlen(&prog);
                                } else {
                                        EMIT2(0x41, 0x59); /* pop %r9 */
                                        EMIT2(0x41, 0x5A); /* pop %r10 */
                                }
                        }
                        break;

                case BPF_JMP | BPF_TAIL_CALL:
                        emit_bpf_tail_call(&prog);
                        break;

                        /* cond jump */
                case BPF_JMP | BPF_JEQ | BPF_X:
                case BPF_JMP | BPF_JNE | BPF_X:
                case BPF_JMP | BPF_JGT | BPF_X:
                case BPF_JMP | BPF_JLT | BPF_X:
                case BPF_JMP | BPF_JGE | BPF_X:
                case BPF_JMP | BPF_JLE | BPF_X:
                case BPF_JMP | BPF_JSGT | BPF_X:
                case BPF_JMP | BPF_JSLT | BPF_X:
                case BPF_JMP | BPF_JSGE | BPF_X:
                case BPF_JMP | BPF_JSLE | BPF_X:
                        /* cmp dst_reg, src_reg */
                        EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x39,
                              add_2reg(0xC0, dst_reg, src_reg));
                        goto emit_cond_jmp;

                case BPF_JMP | BPF_JSET | BPF_X:
                        /* test dst_reg, src_reg */
                        EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x85,
                              add_2reg(0xC0, dst_reg, src_reg));
                        goto emit_cond_jmp;

                case BPF_JMP | BPF_JSET | BPF_K:
                        /* test dst_reg, imm32 */
                        EMIT1(add_1mod(0x48, dst_reg));
                        EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
                        goto emit_cond_jmp;

                case BPF_JMP | BPF_JEQ | BPF_K:
                case BPF_JMP | BPF_JNE | BPF_K:
                case BPF_JMP | BPF_JGT | BPF_K:
                case BPF_JMP | BPF_JLT | BPF_K:
                case BPF_JMP | BPF_JGE | BPF_K:
                case BPF_JMP | BPF_JLE | BPF_K:
                case BPF_JMP | BPF_JSGT | BPF_K:
                case BPF_JMP | BPF_JSLT | BPF_K:
                case BPF_JMP | BPF_JSGE | BPF_K:
                case BPF_JMP | BPF_JSLE | BPF_K:
                        /* cmp dst_reg, imm8/32 */
                        EMIT1(add_1mod(0x48, dst_reg));

                        if (is_imm8(imm32))
                                EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
                        else
                                EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:          /* convert BPF opcode to x86 */
                        switch (BPF_OP(insn->code)) {
                        case BPF_JEQ:
                                jmp_cond = X86_JE;
                                break;
                        case BPF_JSET:
                        case BPF_JNE:
                                jmp_cond = X86_JNE;
                                break;
                        case BPF_JGT:
                                /* GT is unsigned '>', JA in x86 */
                                jmp_cond = X86_JA;
                                break;
                        case BPF_JLT:
                                /* LT is unsigned '<', JB in x86 */
                                jmp_cond = X86_JB;
                                break;
                        case BPF_JGE:
                                /* GE is unsigned '>=', JAE in x86 */
                                jmp_cond = X86_JAE;
                                break;
                        case BPF_JLE:
                                /* LE is unsigned '<=', JBE in x86 */
                                jmp_cond = X86_JBE;
                                break;
                        case BPF_JSGT:
                                /* signed '>', GT in x86 */
                                jmp_cond = X86_JG;
                                break;
                        case BPF_JSLT:
                                /* signed '<', LT in x86 */
                                jmp_cond = X86_JL;
                                break;
                        case BPF_JSGE:
                                /* signed '>=', GE in x86 */
                                jmp_cond = X86_JGE;
                                break;
                        case BPF_JSLE:
                                /* signed '<=', LE in x86 */
                                jmp_cond = X86_JLE;
                                break;
                        default: /* to silence gcc warning */
                                return -EFAULT;
                        }
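                        /* addrs[i] is the offset of the end of the code
                         * emitted for BPF insn i, so the difference below is
                         * the displacement relative to the end of the jump
                         * instruction just emitted, which is what x86
                         * relative jumps expect.
                         */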
                        jmp_offset = addrs[i + insn->off] - addrs[i];
                        if (is_imm8(jmp_offset)) {
                                EMIT2(jmp_cond, jmp_offset);
                        } else if (is_simm32(jmp_offset)) {
                                EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
                        } else {
                                pr_err("cond_jmp gen bug %llx\n", jmp_offset);
                                return -EFAULT;
                        }

                        break;

                case BPF_JMP | BPF_JA:
                        if (insn->off == -1)
                                /* -1 jmp instructions will always jump
                                 * backwards two bytes. Explicitly handling
                                 * this case avoids wasting too many passes
                                 * when there are long sequences of replaced
                                 * dead code.
                                 */
                                jmp_offset = -2;
                        else
                                jmp_offset = addrs[i + insn->off] - addrs[i];

                        if (!jmp_offset)
                                /* optimize out nop jumps */
                                break;
emit_jmp:
                        if (is_imm8(jmp_offset)) {
                                EMIT2(0xEB, jmp_offset);
                        } else if (is_simm32(jmp_offset)) {
                                EMIT1_off32(0xE9, jmp_offset);
                        } else {
                                pr_err("jmp gen bug %llx\n", jmp_offset);
                                return -EFAULT;
                        }
                        break;

                case BPF_LD | BPF_IND | BPF_W:
                        func = sk_load_word;
                        goto common_load;
                case BPF_LD | BPF_ABS | BPF_W:
                        func = CHOOSE_LOAD_FUNC(imm32, sk_load_word);
common_load:
                        ctx->seen_ld_abs = seen_ld_abs = true;
                        jmp_offset = func - (image + addrs[i]);
                        if (!func || !is_simm32(jmp_offset)) {
                                pr_err("unsupported bpf func %d addr %p image %p\n",
                                       imm32, func, image);
                                return -EINVAL;
                        }
                        if (BPF_MODE(insn->code) == BPF_ABS) {
                                /* mov %esi, imm32 */
                                EMIT1_off32(0xBE, imm32);
                        } else {
                                /* mov %rsi, src_reg */
                                EMIT_mov(BPF_REG_2, src_reg);
                                if (imm32) {
                                        if (is_imm8(imm32))
                                                /* add %esi, imm8 */
                                                EMIT3(0x83, 0xC6, imm32);
                                        else
                                                /* add %esi, imm32 */
                                                EMIT2_off32(0x81, 0xC6, imm32);
                                }
                        }
                        /* skb pointer is in R6 (%rbx), it will be copied into
                         * %rdi if skb_copy_bits() call is necessary.
                         * sk_load_* helpers also use %r10 and %r9d.
                         * See bpf_jit.S
                         */
                        if (seen_ax_reg)
                                /* r10 = skb->data, mov %r10, off32(%rbx) */
                                EMIT3_off32(0x4c, 0x8b, 0x93,
                                            offsetof(struct sk_buff, data));
                        EMIT1_off32(0xE8, jmp_offset); /* call */
                        break;

                case BPF_LD | BPF_IND | BPF_H:
                        func = sk_load_half;
                        goto common_load;
                case BPF_LD | BPF_ABS | BPF_H:
                        func = CHOOSE_LOAD_FUNC(imm32, sk_load_half);
                        goto common_load;
                case BPF_LD | BPF_IND | BPF_B:
                        func = sk_load_byte;
                        goto common_load;
                case BPF_LD | BPF_ABS | BPF_B:
                        func = CHOOSE_LOAD_FUNC(imm32, sk_load_byte);
                        goto common_load;

                case BPF_JMP | BPF_EXIT:
                        if (seen_exit) {
                                jmp_offset = ctx->cleanup_addr - addrs[i];
                                goto emit_jmp;
                        }
                        seen_exit = true;
                        /* update cleanup_addr */
                        ctx->cleanup_addr = proglen;
                        /* mov rbx, qword ptr [rbp+0] */
                        EMIT4(0x48, 0x8B, 0x5D, 0);
                        /* mov r13, qword ptr [rbp+8] */
                        EMIT4(0x4C, 0x8B, 0x6D, 8);
                        /* mov r14, qword ptr [rbp+16] */
                        EMIT4(0x4C, 0x8B, 0x75, 16);
                        /* mov r15, qword ptr [rbp+24] */
                        EMIT4(0x4C, 0x8B, 0x7D, 24);

                        /* add rbp, AUX_STACK_SPACE */
                        EMIT4(0x48, 0x83, 0xC5, AUX_STACK_SPACE);
                        EMIT1(0xC9); /* leave */
                        EMIT1(0xC3); /* ret */
                        break;

                default:
                        /* By design the x64 JIT should support all BPF
                         * instructions. This error will be seen if a new
                         * instruction was added to the interpreter but not to
                         * the JIT, or if there is junk in bpf_prog.
                         */
                        pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
                        return -EINVAL;
                }

                ilen = prog - temp;
                if (ilen > BPF_MAX_INSN_SIZE) {
                        pr_err("bpf_jit: fatal insn size error\n");
                        return -EFAULT;
                }

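                /* On the sizing passes image is NULL and nothing is copied;
                 * once the layout has converged and the binary image has been
                 * allocated, the bytes staged in temp[] are copied to their
                 * final location.
                 */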
                if (image) {
                        if (unlikely(proglen + ilen > oldproglen)) {
                                pr_err("bpf_jit: fatal error\n");
                                return -EFAULT;
                        }
                        memcpy(image + proglen, temp, ilen);
                }
                proglen += ilen;
                addrs[i] = proglen;
                prog = temp;
        }
        return proglen;
}

struct x64_jit_data {
        struct bpf_binary_header *header;
        int *addrs;
        u8 *image;
        int proglen;
        struct jit_context ctx;
};

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
        struct bpf_binary_header *header = NULL;
        struct bpf_prog *tmp, *orig_prog = prog;
        struct x64_jit_data *jit_data;
        int proglen, oldproglen = 0;
        struct jit_context ctx = {};
        bool tmp_blinded = false;
        bool extra_pass = false;
        u8 *image = NULL;
        int *addrs;
        int pass;
        int i;

        if (!prog->jit_requested)
                return orig_prog;

        tmp = bpf_jit_blind_constants(prog);
        /* If blinding was requested and we failed during blinding,
         * we must fall back to the interpreter.
         */
        if (IS_ERR(tmp))
                return orig_prog;
        if (tmp != prog) {
                tmp_blinded = true;
                prog = tmp;
        }

        jit_data = prog->aux->jit_data;
        if (!jit_data) {
                jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
                if (!jit_data) {
                        prog = orig_prog;
                        goto out;
                }
                prog->aux->jit_data = jit_data;
        }
        addrs = jit_data->addrs;
        if (addrs) {
                ctx = jit_data->ctx;
                oldproglen = jit_data->proglen;
                image = jit_data->image;
                header = jit_data->header;
                extra_pass = true;
                goto skip_init_addrs;
        }
        addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
        if (!addrs) {
                prog = orig_prog;
                goto out_addrs;
        }

        /* Before first pass, make a rough estimation of addrs[]:
         * each BPF instruction is translated to less than 64 bytes
         */
        for (proglen = 0, i = 0; i < prog->len; i++) {
                proglen += 64;
                addrs[i] = proglen;
        }
        ctx.cleanup_addr = proglen;
skip_init_addrs:

        /* JITed image shrinks with every pass and the loop iterates
         * until the image stops shrinking. Very large bpf programs
         * may converge on the last pass. In such case do one more
         * pass to emit the final image
         */
        for (pass = 0; pass < 20 || image; pass++) {
                proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
                if (proglen <= 0) {
out_image:
                        image = NULL;
                        if (header)
                                bpf_jit_binary_free(header);
                        prog = orig_prog;
                        goto out_addrs;
                }
                if (image) {
                        if (proglen != oldproglen) {
                                pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
                                       proglen, oldproglen);
                                goto out_image;
                        }
                        break;
                }
                if (proglen == oldproglen) {
                        header = bpf_jit_binary_alloc(proglen, &image,
                                                      1, jit_fill_hole);
                        if (!header) {
                                prog = orig_prog;
                                goto out_addrs;
                        }
                }
                oldproglen = proglen;
                cond_resched();
        }

        if (bpf_jit_enable > 1)
                bpf_jit_dump(prog->len, proglen, pass + 1, image);

        if (image) {
                if (!prog->is_func || extra_pass) {
                        bpf_jit_binary_lock_ro(header);
                } else {
                        jit_data->addrs = addrs;
                        jit_data->ctx = ctx;
                        jit_data->proglen = proglen;
                        jit_data->image = image;
                        jit_data->header = header;
                }
                prog->bpf_func = (void *)image;
                prog->jited = 1;
                prog->jited_len = proglen;
        } else {
                prog = orig_prog;
        }

        if (!image || !prog->is_func || extra_pass) {
out_addrs:
                kfree(addrs);
                kfree(jit_data);
                prog->aux->jit_data = NULL;
        }
out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?
                                           tmp : orig_prog);
        return prog;
}