linux/arch/x86/net/bpf_jit_comp.c
   1/* bpf_jit_comp.c : BPF JIT compiler
   2 *
   3 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
   4 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
   5 *
   6 * This program is free software; you can redistribute it and/or
   7 * modify it under the terms of the GNU General Public License
   8 * as published by the Free Software Foundation; version 2
   9 * of the License.
  10 */
  11#include <linux/netdevice.h>
  12#include <linux/filter.h>
  13#include <linux/if_vlan.h>
  14#include <asm/cacheflush.h>
  15#include <asm/set_memory.h>
  16#include <asm/nospec-branch.h>
  17#include <linux/bpf.h>
  18
  19/*
  20 * assembly code in arch/x86/net/bpf_jit.S
  21 */
  22extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
  23extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
  24extern u8 sk_load_byte_positive_offset[];
  25extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
  26extern u8 sk_load_byte_negative_offset[];
  27
  28static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
  29{
  30        if (len == 1)
  31                *ptr = bytes;
  32        else if (len == 2)
  33                *(u16 *)ptr = bytes;
  34        else {
  35                *(u32 *)ptr = bytes;
  36                barrier();
  37        }
  38        return ptr + len;
  39}
  40
  41#define EMIT(bytes, len) \
  42        do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
  43
  44#define EMIT1(b1)               EMIT(b1, 1)
  45#define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
  46#define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
  47#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
  48#define EMIT1_off32(b1, off) \
  49        do {EMIT1(b1); EMIT(off, 4); } while (0)
  50#define EMIT2_off32(b1, b2, off) \
  51        do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
  52#define EMIT3_off32(b1, b2, b3, off) \
  53        do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
  54#define EMIT4_off32(b1, b2, b3, b4, off) \
  55        do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
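/* The EMITn() macros pack their opcode bytes little-endian into a u32, so the
 * first argument is the first byte written to the image: e.g. the
 * EMIT3(0x48, 0x89, 0xE5) in emit_prologue() below produces the byte
 * sequence 48 89 e5, i.e. 'mov rbp,rsp'.
 */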
  56
  57static bool is_imm8(int value)
  58{
  59        return value <= 127 && value >= -128;
  60}
  61
  62static bool is_simm32(s64 value)
  63{
  64        return value == (s64) (s32) value;
  65}
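/* is_imm8()/is_simm32() check whether a value fits in a sign-extended 8-bit or
 * 32-bit field; they are used below to pick the shortest x86 encoding
 * (imm8 vs imm32, disp8 vs disp32, rel8 vs rel32).
 */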
  66
  67/* mov dst, src */
  68#define EMIT_mov(DST, SRC) \
  69        do {if (DST != SRC) \
  70                EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
  71        } while (0)
  72
  73static int bpf_size_to_x86_bytes(int bpf_size)
  74{
  75        if (bpf_size == BPF_W)
  76                return 4;
  77        else if (bpf_size == BPF_H)
  78                return 2;
  79        else if (bpf_size == BPF_B)
  80                return 1;
  81        else if (bpf_size == BPF_DW)
  82                return 4; /* imm32 */
  83        else
  84                return 0;
  85}
  86
   87/* list of x86 conditional jump opcodes (. + s8)
   88 * Add 0x10 (and an extra 0x0f prefix byte) to generate the rel32 form (. + s32)
   89 */
  90#define X86_JB  0x72
  91#define X86_JAE 0x73
  92#define X86_JE  0x74
  93#define X86_JNE 0x75
  94#define X86_JBE 0x76
  95#define X86_JA  0x77
  96#define X86_JL  0x7C
  97#define X86_JGE 0x7D
  98#define X86_JLE 0x7E
  99#define X86_JG  0x7F
 100
 101static void bpf_flush_icache(void *start, void *end)
 102{
 103        mm_segment_t old_fs = get_fs();
 104
 105        set_fs(KERNEL_DS);
 106        smp_wmb();
 107        flush_icache_range((unsigned long)start, (unsigned long)end);
 108        set_fs(old_fs);
 109}
 110
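/* Classic BPF encodes loads from the special skb areas as negative constants
 * (SKF_LL_OFF and friends): a K in [SKF_LL_OFF, 0) selects the *_negative_offset
 * helper, anything below SKF_LL_OFF falls back to the generic helper, and
 * K >= 0 takes the *_positive_offset fast path.
 */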
 111#define CHOOSE_LOAD_FUNC(K, func) \
 112        ((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)
 113
 114/* pick a register outside of BPF range for JIT internal work */
 115#define AUX_REG (MAX_BPF_JIT_REG + 1)
 116
  117/* The following table maps BPF registers to x64 registers.
  118 *
  119 * x64 register r12 is unused: when used as a base address register
  120 * in load/store instructions it always needs an extra byte of
  121 * encoding, and it is callee saved.
  122 *
  123 *  r9 caches skb->len - skb->data_len
  124 * r10 caches skb->data, and is used for blinding (if enabled)
  125 */
 126static const int reg2hex[] = {
 127        [BPF_REG_0] = 0,  /* rax */
 128        [BPF_REG_1] = 7,  /* rdi */
 129        [BPF_REG_2] = 6,  /* rsi */
 130        [BPF_REG_3] = 2,  /* rdx */
 131        [BPF_REG_4] = 1,  /* rcx */
 132        [BPF_REG_5] = 0,  /* r8 */
 133        [BPF_REG_6] = 3,  /* rbx callee saved */
 134        [BPF_REG_7] = 5,  /* r13 callee saved */
 135        [BPF_REG_8] = 6,  /* r14 callee saved */
 136        [BPF_REG_9] = 7,  /* r15 callee saved */
 137        [BPF_REG_FP] = 5, /* rbp readonly */
 138        [BPF_REG_AX] = 2, /* r10 temp register */
 139        [AUX_REG] = 3,    /* r11 temp register */
 140};
 141
  142/* is_ereg() == true if BPF register 'reg' maps to x64 r8..r15,
  143 * which need an extra byte of encoding.
  144 * rax, rcx, ..., rbp have a simpler encoding.
  145 */
 146static bool is_ereg(u32 reg)
 147{
 148        return (1 << reg) & (BIT(BPF_REG_5) |
 149                             BIT(AUX_REG) |
 150                             BIT(BPF_REG_7) |
 151                             BIT(BPF_REG_8) |
 152                             BIT(BPF_REG_9) |
 153                             BIT(BPF_REG_AX));
 154}
 155
 156static bool is_axreg(u32 reg)
 157{
 158        return reg == BPF_REG_0;
 159}
 160
 161/* add modifiers if 'reg' maps to x64 registers r8..r15 */
 162static u8 add_1mod(u8 byte, u32 reg)
 163{
 164        if (is_ereg(reg))
 165                byte |= 1;
 166        return byte;
 167}
 168
 169static u8 add_2mod(u8 byte, u32 r1, u32 r2)
 170{
 171        if (is_ereg(r1))
 172                byte |= 1;
 173        if (is_ereg(r2))
 174                byte |= 4;
 175        return byte;
 176}
 177
 178/* encode 'dst_reg' register into x64 opcode 'byte' */
 179static u8 add_1reg(u8 byte, u32 dst_reg)
 180{
 181        return byte + reg2hex[dst_reg];
 182}
 183
 184/* encode 'dst_reg' and 'src_reg' registers into x64 opcode 'byte' */
 185static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
 186{
 187        return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
 188}
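/* Example: EMIT_mov(BPF_REG_1, BPF_REG_3) ('mov rdi, rdx') expands to
 * add_2mod(0x48, R1, R3) = 0x48 (REX.W, no extended registers), opcode 0x89,
 * and add_2reg(0xC0, R1, R3) = 0xC0 + 7 + (2 << 3) = 0xD7 (ModRM with mod=11,
 * reg=rdx, rm=rdi), i.e. the byte sequence 48 89 d7.
 */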
 189
 190static void jit_fill_hole(void *area, unsigned int size)
 191{
 192        /* fill whole space with int3 instructions */
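        /* 0xcc is the x86 int3 (breakpoint) opcode, so stray jumps into the
         * unused area trap instead of executing leftover bytes.
         */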
 193        memset(area, 0xcc, size);
 194}
 195
 196struct jit_context {
 197        int cleanup_addr; /* epilogue code offset */
 198        bool seen_ld_abs;
 199        bool seen_ax_reg;
 200};
 201
 202/* maximum number of bytes emitted while JITing one eBPF insn */
 203#define BPF_MAX_INSN_SIZE       128
 204#define BPF_INSN_SAFETY         64
 205
 206#define AUX_STACK_SPACE \
 207        (32 /* space for rbx, r13, r14, r15 */ + \
 208         8 /* space for skb_copy_bits() buffer */)
 209
 210#define PROLOGUE_SIZE 37
 211
  212/* emit x64 prologue code for the BPF program and check its size.
  213 * The bpf_tail_call() helper will skip it when jumping into another program.
  214 */
 215static void emit_prologue(u8 **pprog, u32 stack_depth)
 216{
 217        u8 *prog = *pprog;
 218        int cnt = 0;
 219
 220        EMIT1(0x55); /* push rbp */
 221        EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */
 222
 223        /* sub rsp, rounded_stack_depth + AUX_STACK_SPACE */
 224        EMIT3_off32(0x48, 0x81, 0xEC,
 225                    round_up(stack_depth, 8) + AUX_STACK_SPACE);
 226
 227        /* sub rbp, AUX_STACK_SPACE */
 228        EMIT4(0x48, 0x83, 0xED, AUX_STACK_SPACE);
 229
  230        /* all classic BPF filters use R6(rbx), so save it */
 231
 232        /* mov qword ptr [rbp+0],rbx */
 233        EMIT4(0x48, 0x89, 0x5D, 0);
 234
  235        /* bpf_convert_filter() maps classic BPF register X to R7 and uses R8
  236         * as a temporary, so all tcpdump filters need to spill/fill R7(r13) and
  237         * R8(r14). The R9(r15) spill could be made conditional, but there is only
  238         * one 'bpf_error' return path out of the helper functions in bpf_jit.S.
  239         * The overhead of the extra spill is negligible for any filter other
  240         * than synthetic ones, so it is not worth the added complexity.
  241         */
 242
 243        /* mov qword ptr [rbp+8],r13 */
 244        EMIT4(0x4C, 0x89, 0x6D, 8);
 245        /* mov qword ptr [rbp+16],r14 */
 246        EMIT4(0x4C, 0x89, 0x75, 16);
 247        /* mov qword ptr [rbp+24],r15 */
 248        EMIT4(0x4C, 0x89, 0x7D, 24);
 249
  250        /* Clear the tail call counter (tail_call_cnt): for eBPF tail calls
  251         * we need to reset the counter to 0. It's done in two instructions:
  252         * resetting rax to 0 (xor on eax zero-extends into rax), and
  253         * moving it to the counter location.
  254         */
 255
 256        /* xor eax, eax */
 257        EMIT2(0x31, 0xc0);
 258        /* mov qword ptr [rbp+32], rax */
 259        EMIT4(0x48, 0x89, 0x45, 32);
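
        /* At this point [rbp+0], [rbp+8], [rbp+16] and [rbp+24] hold
         * rbx/r13/r14/r15 and [rbp+32] holds the zeroed tail call counter,
         * while the BPF program's own stack lives at negative offsets from
         * rbp (BPF_REG_FP).
         */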
 260
 261        BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
 262        *pprog = prog;
 263}
 264
 265/* generate the following code:
 266 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 267 *   if (index >= array->map.max_entries)
 268 *     goto out;
 269 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 270 *     goto out;
 271 *   prog = array->ptrs[index];
 272 *   if (prog == NULL)
 273 *     goto out;
 274 *   goto *(prog->bpf_func + prologue_size);
 275 * out:
 276 */
 277static void emit_bpf_tail_call(u8 **pprog)
 278{
 279        u8 *prog = *pprog;
 280        int label1, label2, label3;
 281        int cnt = 0;
 282
 283        /* rdi - pointer to ctx
 284         * rsi - pointer to bpf_array
 285         * rdx - index in bpf_array
 286         */
 287
 288        /* if (index >= array->map.max_entries)
 289         *   goto out;
 290         */
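        /* The index arrives as a 64-bit value in rdx; 'mov edx, edx' keeps only
         * its low 32 bits, matching the unsigned compare against the u32
         * max_entries below.
         */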
 291        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
 292        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
 293              offsetof(struct bpf_array, map.max_entries));
 294#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* number of bytes to jump */
 295        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
 296        label1 = cnt;
 297
 298        /* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 299         *   goto out;
 300         */
 301        EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
 302        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
 303#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
 304        EMIT2(X86_JA, OFFSET2);                   /* ja out */
 305        label2 = cnt;
 306        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
 307        EMIT2_off32(0x89, 0x85, 36);              /* mov dword ptr [rbp + 36], eax */
 308
 309        /* prog = array->ptrs[index]; */
 310        EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
 311                    offsetof(struct bpf_array, ptrs));
 312
 313        /* if (prog == NULL)
 314         *   goto out;
 315         */
 316        EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
 317#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
 318        EMIT2(X86_JE, OFFSET3);                   /* je out */
 319        label3 = cnt;
 320
 321        /* goto *(prog->bpf_func + prologue_size); */
 322        EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
 323              offsetof(struct bpf_prog, bpf_func));
 324        EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */
 325
  326        /* now we're ready to jump into the next BPF program:
  327         * rdi == ctx (1st arg)
  328         * rax == prog->bpf_func + prologue_size
  329         */
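        /* Emit the indirect jump through rax; with CONFIG_RETPOLINE this
         * expands to a retpoline sequence (see asm/nospec-branch.h) rather
         * than a plain 'jmp *%rax', to avoid Spectre v2 branch target
         * injection.
         */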
 330        RETPOLINE_RAX_BPF_JIT();
 331
 332        /* out: */
 333        BUILD_BUG_ON(cnt - label1 != OFFSET1);
 334        BUILD_BUG_ON(cnt - label2 != OFFSET2);
 335        BUILD_BUG_ON(cnt - label3 != OFFSET3);
 336        *pprog = prog;
 337}
 338
 339
 340static void emit_load_skb_data_hlen(u8 **pprog)
 341{
 342        u8 *prog = *pprog;
 343        int cnt = 0;
 344
 345        /* r9d = skb->len - skb->data_len (headlen)
 346         * r10 = skb->data
 347         */
 348        /* mov %r9d, off32(%rdi) */
 349        EMIT3_off32(0x44, 0x8b, 0x8f, offsetof(struct sk_buff, len));
 350
 351        /* sub %r9d, off32(%rdi) */
 352        EMIT3_off32(0x44, 0x2b, 0x8f, offsetof(struct sk_buff, data_len));
 353
 354        /* mov %r10, off32(%rdi) */
 355        EMIT3_off32(0x4c, 0x8b, 0x97, offsetof(struct sk_buff, data));
 356        *pprog = prog;
 357}
 358
 359static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 360                  int oldproglen, struct jit_context *ctx)
 361{
 362        struct bpf_insn *insn = bpf_prog->insnsi;
 363        int insn_cnt = bpf_prog->len;
 364        bool seen_ld_abs = ctx->seen_ld_abs | (oldproglen == 0);
 365        bool seen_ax_reg = ctx->seen_ax_reg | (oldproglen == 0);
 366        bool seen_exit = false;
 367        u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
 368        int i, cnt = 0;
 369        int proglen = 0;
 370        u8 *prog = temp;
 371
 372        emit_prologue(&prog, bpf_prog->aux->stack_depth);
 373
 374        if (seen_ld_abs)
 375                emit_load_skb_data_hlen(&prog);
 376
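        /* addrs[i] is the offset of the end of BPF insn i within the image as
         * computed by the previous pass; jump offsets are derived from it
         * (addrs[i + insn->off] - addrs[i]) and tighten as the passes shrink
         * the image.
         */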
 377        for (i = 0; i < insn_cnt; i++, insn++) {
 378                const s32 imm32 = insn->imm;
 379                u32 dst_reg = insn->dst_reg;
 380                u32 src_reg = insn->src_reg;
 381                u8 b1 = 0, b2 = 0, b3 = 0;
 382                s64 jmp_offset;
 383                u8 jmp_cond;
 384                bool reload_skb_data;
 385                int ilen;
 386                u8 *func;
 387
 388                if (dst_reg == BPF_REG_AX || src_reg == BPF_REG_AX)
 389                        ctx->seen_ax_reg = seen_ax_reg = true;
 390
 391                switch (insn->code) {
 392                        /* ALU */
 393                case BPF_ALU | BPF_ADD | BPF_X:
 394                case BPF_ALU | BPF_SUB | BPF_X:
 395                case BPF_ALU | BPF_AND | BPF_X:
 396                case BPF_ALU | BPF_OR | BPF_X:
 397                case BPF_ALU | BPF_XOR | BPF_X:
 398                case BPF_ALU64 | BPF_ADD | BPF_X:
 399                case BPF_ALU64 | BPF_SUB | BPF_X:
 400                case BPF_ALU64 | BPF_AND | BPF_X:
 401                case BPF_ALU64 | BPF_OR | BPF_X:
 402                case BPF_ALU64 | BPF_XOR | BPF_X:
 403                        switch (BPF_OP(insn->code)) {
 404                        case BPF_ADD: b2 = 0x01; break;
 405                        case BPF_SUB: b2 = 0x29; break;
 406                        case BPF_AND: b2 = 0x21; break;
 407                        case BPF_OR: b2 = 0x09; break;
 408                        case BPF_XOR: b2 = 0x31; break;
 409                        }
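                        /* Example: BPF_ALU64 | BPF_ADD | BPF_X with dst_reg=R1,
                         * src_reg=R2 emits 48 01 f7, i.e. 'add rdi, rsi'.
                         */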
 410                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 411                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
 412                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
 413                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
 414                        EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
 415                        break;
 416
 417                        /* mov dst, src */
 418                case BPF_ALU64 | BPF_MOV | BPF_X:
 419                        EMIT_mov(dst_reg, src_reg);
 420                        break;
 421
 422                        /* mov32 dst, src */
 423                case BPF_ALU | BPF_MOV | BPF_X:
 424                        if (is_ereg(dst_reg) || is_ereg(src_reg))
 425                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
 426                        EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
 427                        break;
 428
 429                        /* neg dst */
 430                case BPF_ALU | BPF_NEG:
 431                case BPF_ALU64 | BPF_NEG:
 432                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 433                                EMIT1(add_1mod(0x48, dst_reg));
 434                        else if (is_ereg(dst_reg))
 435                                EMIT1(add_1mod(0x40, dst_reg));
 436                        EMIT2(0xF7, add_1reg(0xD8, dst_reg));
 437                        break;
 438
 439                case BPF_ALU | BPF_ADD | BPF_K:
 440                case BPF_ALU | BPF_SUB | BPF_K:
 441                case BPF_ALU | BPF_AND | BPF_K:
 442                case BPF_ALU | BPF_OR | BPF_K:
 443                case BPF_ALU | BPF_XOR | BPF_K:
 444                case BPF_ALU64 | BPF_ADD | BPF_K:
 445                case BPF_ALU64 | BPF_SUB | BPF_K:
 446                case BPF_ALU64 | BPF_AND | BPF_K:
 447                case BPF_ALU64 | BPF_OR | BPF_K:
 448                case BPF_ALU64 | BPF_XOR | BPF_K:
 449                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 450                                EMIT1(add_1mod(0x48, dst_reg));
 451                        else if (is_ereg(dst_reg))
 452                                EMIT1(add_1mod(0x40, dst_reg));
 453
  454                        /* b3 holds the 'normal' opcode; b2 is the short form,
  455                         * valid only when dst is eax/rax.
  456                         */
 457                        switch (BPF_OP(insn->code)) {
 458                        case BPF_ADD:
 459                                b3 = 0xC0;
 460                                b2 = 0x05;
 461                                break;
 462                        case BPF_SUB:
 463                                b3 = 0xE8;
 464                                b2 = 0x2D;
 465                                break;
 466                        case BPF_AND:
 467                                b3 = 0xE0;
 468                                b2 = 0x25;
 469                                break;
 470                        case BPF_OR:
 471                                b3 = 0xC8;
 472                                b2 = 0x0D;
 473                                break;
 474                        case BPF_XOR:
 475                                b3 = 0xF0;
 476                                b2 = 0x35;
 477                                break;
 478                        }
 479
 480                        if (is_imm8(imm32))
 481                                EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
 482                        else if (is_axreg(dst_reg))
 483                                EMIT1_off32(b2, imm32);
 484                        else
 485                                EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
 486                        break;
 487
 488                case BPF_ALU64 | BPF_MOV | BPF_K:
 489                        /* optimization: if imm32 is positive,
 490                         * use 'mov eax, imm32' (which zero-extends imm32)
 491                         * to save 2 bytes
 492                         */
 493                        if (imm32 < 0) {
 494                                /* 'mov rax, imm32' sign extends imm32 */
 495                                b1 = add_1mod(0x48, dst_reg);
 496                                b2 = 0xC7;
 497                                b3 = 0xC0;
 498                                EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
 499                                break;
 500                        }
 501
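                        /* Fall through: a non-negative imm32 is handled by the
                         * 32-bit mov below, which zero-extends into the full
                         * 64-bit register.
                         */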
 502                case BPF_ALU | BPF_MOV | BPF_K:
 503                        /* optimization: if imm32 is zero, use 'xor <dst>,<dst>'
 504                         * to save 3 bytes.
 505                         */
 506                        if (imm32 == 0) {
 507                                if (is_ereg(dst_reg))
 508                                        EMIT1(add_2mod(0x40, dst_reg, dst_reg));
 509                                b2 = 0x31; /* xor */
 510                                b3 = 0xC0;
 511                                EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
 512                                break;
 513                        }
 514
 515                        /* mov %eax, imm32 */
 516                        if (is_ereg(dst_reg))
 517                                EMIT1(add_1mod(0x40, dst_reg));
 518                        EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
 519                        break;
 520
 521                case BPF_LD | BPF_IMM | BPF_DW:
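                        /* A 64-bit immediate load spans two BPF insns:
                         * insn[0].imm holds the low 32 bits and insn[1].imm the
                         * high 32 bits, hence the extra insn++/i++ in this case.
                         */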
 522                        /* optimization: if imm64 is zero, use 'xor <dst>,<dst>'
 523                         * to save 7 bytes.
 524                         */
 525                        if (insn[0].imm == 0 && insn[1].imm == 0) {
 526                                b1 = add_2mod(0x48, dst_reg, dst_reg);
 527                                b2 = 0x31; /* xor */
 528                                b3 = 0xC0;
 529                                EMIT3(b1, b2, add_2reg(b3, dst_reg, dst_reg));
 530
 531                                insn++;
 532                                i++;
 533                                break;
 534                        }
 535
 536                        /* movabsq %rax, imm64 */
 537                        EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
 538                        EMIT(insn[0].imm, 4);
 539                        EMIT(insn[1].imm, 4);
 540
 541                        insn++;
 542                        i++;
 543                        break;
 544
 545                        /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
 546                case BPF_ALU | BPF_MOD | BPF_X:
 547                case BPF_ALU | BPF_DIV | BPF_X:
 548                case BPF_ALU | BPF_MOD | BPF_K:
 549                case BPF_ALU | BPF_DIV | BPF_K:
 550                case BPF_ALU64 | BPF_MOD | BPF_X:
 551                case BPF_ALU64 | BPF_DIV | BPF_X:
 552                case BPF_ALU64 | BPF_MOD | BPF_K:
 553                case BPF_ALU64 | BPF_DIV | BPF_K:
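                        /* x86 div takes its dividend in rdx:rax and clobbers
                         * both, so save rax/rdx around the operation, put the
                         * divisor in r11 and move the result back at the end.
                         */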
 554                        EMIT1(0x50); /* push rax */
 555                        EMIT1(0x52); /* push rdx */
 556
 557                        if (BPF_SRC(insn->code) == BPF_X)
 558                                /* mov r11, src_reg */
 559                                EMIT_mov(AUX_REG, src_reg);
 560                        else
 561                                /* mov r11, imm32 */
 562                                EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
 563
 564                        /* mov rax, dst_reg */
 565                        EMIT_mov(BPF_REG_0, dst_reg);
 566
 567                        /* xor edx, edx
 568                         * equivalent to 'xor rdx, rdx', but one byte less
 569                         */
 570                        EMIT2(0x31, 0xd2);
 571
 572                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 573                                /* div r11 */
 574                                EMIT3(0x49, 0xF7, 0xF3);
 575                        else
 576                                /* div r11d */
 577                                EMIT3(0x41, 0xF7, 0xF3);
 578
 579                        if (BPF_OP(insn->code) == BPF_MOD)
 580                                /* mov r11, rdx */
 581                                EMIT3(0x49, 0x89, 0xD3);
 582                        else
 583                                /* mov r11, rax */
 584                                EMIT3(0x49, 0x89, 0xC3);
 585
 586                        EMIT1(0x5A); /* pop rdx */
 587                        EMIT1(0x58); /* pop rax */
 588
 589                        /* mov dst_reg, r11 */
 590                        EMIT_mov(dst_reg, AUX_REG);
 591                        break;
 592
 593                case BPF_ALU | BPF_MUL | BPF_K:
 594                case BPF_ALU | BPF_MUL | BPF_X:
 595                case BPF_ALU64 | BPF_MUL | BPF_K:
 596                case BPF_ALU64 | BPF_MUL | BPF_X:
 597                        EMIT1(0x50); /* push rax */
 598                        EMIT1(0x52); /* push rdx */
 599
 600                        /* mov r11, dst_reg */
 601                        EMIT_mov(AUX_REG, dst_reg);
 602
 603                        if (BPF_SRC(insn->code) == BPF_X)
 604                                /* mov rax, src_reg */
 605                                EMIT_mov(BPF_REG_0, src_reg);
 606                        else
 607                                /* mov rax, imm32 */
 608                                EMIT3_off32(0x48, 0xC7, 0xC0, imm32);
 609
 610                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 611                                EMIT1(add_1mod(0x48, AUX_REG));
 612                        else if (is_ereg(AUX_REG))
 613                                EMIT1(add_1mod(0x40, AUX_REG));
 614                        /* mul(q) r11 */
 615                        EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
 616
 617                        /* mov r11, rax */
 618                        EMIT_mov(AUX_REG, BPF_REG_0);
 619
 620                        EMIT1(0x5A); /* pop rdx */
 621                        EMIT1(0x58); /* pop rax */
 622
 623                        /* mov dst_reg, r11 */
 624                        EMIT_mov(dst_reg, AUX_REG);
 625                        break;
 626
 627                        /* shifts */
 628                case BPF_ALU | BPF_LSH | BPF_K:
 629                case BPF_ALU | BPF_RSH | BPF_K:
 630                case BPF_ALU | BPF_ARSH | BPF_K:
 631                case BPF_ALU64 | BPF_LSH | BPF_K:
 632                case BPF_ALU64 | BPF_RSH | BPF_K:
 633                case BPF_ALU64 | BPF_ARSH | BPF_K:
 634                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 635                                EMIT1(add_1mod(0x48, dst_reg));
 636                        else if (is_ereg(dst_reg))
 637                                EMIT1(add_1mod(0x40, dst_reg));
 638
 639                        switch (BPF_OP(insn->code)) {
 640                        case BPF_LSH: b3 = 0xE0; break;
 641                        case BPF_RSH: b3 = 0xE8; break;
 642                        case BPF_ARSH: b3 = 0xF8; break;
 643                        }
 644                        EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
 645                        break;
 646
 647                case BPF_ALU | BPF_LSH | BPF_X:
 648                case BPF_ALU | BPF_RSH | BPF_X:
 649                case BPF_ALU | BPF_ARSH | BPF_X:
 650                case BPF_ALU64 | BPF_LSH | BPF_X:
 651                case BPF_ALU64 | BPF_RSH | BPF_X:
 652                case BPF_ALU64 | BPF_ARSH | BPF_X:
 653
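                        /* x86 variable shifts take the count only in %cl, so
                         * src_reg is moved into rcx (preserving the old rcx) and
                         * dst_reg is parked in r11 if it happens to be rcx itself.
                         */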
 654                        /* check for bad case when dst_reg == rcx */
 655                        if (dst_reg == BPF_REG_4) {
 656                                /* mov r11, dst_reg */
 657                                EMIT_mov(AUX_REG, dst_reg);
 658                                dst_reg = AUX_REG;
 659                        }
 660
 661                        if (src_reg != BPF_REG_4) { /* common case */
 662                                EMIT1(0x51); /* push rcx */
 663
 664                                /* mov rcx, src_reg */
 665                                EMIT_mov(BPF_REG_4, src_reg);
 666                        }
 667
 668                        /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
 669                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 670                                EMIT1(add_1mod(0x48, dst_reg));
 671                        else if (is_ereg(dst_reg))
 672                                EMIT1(add_1mod(0x40, dst_reg));
 673
 674                        switch (BPF_OP(insn->code)) {
 675                        case BPF_LSH: b3 = 0xE0; break;
 676                        case BPF_RSH: b3 = 0xE8; break;
 677                        case BPF_ARSH: b3 = 0xF8; break;
 678                        }
 679                        EMIT2(0xD3, add_1reg(b3, dst_reg));
 680
 681                        if (src_reg != BPF_REG_4)
 682                                EMIT1(0x59); /* pop rcx */
 683
 684                        if (insn->dst_reg == BPF_REG_4)
 685                                /* mov dst_reg, r11 */
 686                                EMIT_mov(insn->dst_reg, AUX_REG);
 687                        break;
 688
 689                case BPF_ALU | BPF_END | BPF_FROM_BE:
 690                        switch (imm32) {
 691                        case 16:
 692                                /* emit 'ror %ax, 8' to swap lower 2 bytes */
 693                                EMIT1(0x66);
 694                                if (is_ereg(dst_reg))
 695                                        EMIT1(0x41);
 696                                EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
 697
 698                                /* emit 'movzwl eax, ax' */
 699                                if (is_ereg(dst_reg))
 700                                        EMIT3(0x45, 0x0F, 0xB7);
 701                                else
 702                                        EMIT2(0x0F, 0xB7);
 703                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
 704                                break;
 705                        case 32:
 706                                /* emit 'bswap eax' to swap lower 4 bytes */
 707                                if (is_ereg(dst_reg))
 708                                        EMIT2(0x41, 0x0F);
 709                                else
 710                                        EMIT1(0x0F);
 711                                EMIT1(add_1reg(0xC8, dst_reg));
 712                                break;
 713                        case 64:
 714                                /* emit 'bswap rax' to swap 8 bytes */
 715                                EMIT3(add_1mod(0x48, dst_reg), 0x0F,
 716                                      add_1reg(0xC8, dst_reg));
 717                                break;
 718                        }
 719                        break;
 720
 721                case BPF_ALU | BPF_END | BPF_FROM_LE:
 722                        switch (imm32) {
 723                        case 16:
 724                                /* emit 'movzwl eax, ax' to zero extend 16-bit
 725                                 * into 64 bit
 726                                 */
 727                                if (is_ereg(dst_reg))
 728                                        EMIT3(0x45, 0x0F, 0xB7);
 729                                else
 730                                        EMIT2(0x0F, 0xB7);
 731                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
 732                                break;
 733                        case 32:
 734                                /* emit 'mov eax, eax' to clear upper 32-bits */
 735                                if (is_ereg(dst_reg))
 736                                        EMIT1(0x45);
 737                                EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
 738                                break;
 739                        case 64:
 740                                /* nop */
 741                                break;
 742                        }
 743                        break;
 744
 745                        /* ST: *(u8*)(dst_reg + off) = imm */
 746                case BPF_ST | BPF_MEM | BPF_B:
 747                        if (is_ereg(dst_reg))
 748                                EMIT2(0x41, 0xC6);
 749                        else
 750                                EMIT1(0xC6);
 751                        goto st;
 752                case BPF_ST | BPF_MEM | BPF_H:
 753                        if (is_ereg(dst_reg))
 754                                EMIT3(0x66, 0x41, 0xC7);
 755                        else
 756                                EMIT2(0x66, 0xC7);
 757                        goto st;
 758                case BPF_ST | BPF_MEM | BPF_W:
 759                        if (is_ereg(dst_reg))
 760                                EMIT2(0x41, 0xC7);
 761                        else
 762                                EMIT1(0xC7);
 763                        goto st;
 764                case BPF_ST | BPF_MEM | BPF_DW:
 765                        EMIT2(add_1mod(0x48, dst_reg), 0xC7);
 766
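                        /* A ModRM base of 0x40 selects mod=01 (8-bit displacement),
                         * 0x80 selects mod=10 (32-bit displacement); the same
                         * pattern is used at the stx:, ldx: and xadd: labels below.
                         */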
 767st:                     if (is_imm8(insn->off))
 768                                EMIT2(add_1reg(0x40, dst_reg), insn->off);
 769                        else
 770                                EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
 771
 772                        EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
 773                        break;
 774
 775                        /* STX: *(u8*)(dst_reg + off) = src_reg */
 776                case BPF_STX | BPF_MEM | BPF_B:
 777                        /* emit 'mov byte ptr [rax + off], al' */
 778                        if (is_ereg(dst_reg) || is_ereg(src_reg) ||
 779                            /* have to add extra byte for x86 SIL, DIL regs */
 780                            src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
 781                                EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
 782                        else
 783                                EMIT1(0x88);
 784                        goto stx;
 785                case BPF_STX | BPF_MEM | BPF_H:
 786                        if (is_ereg(dst_reg) || is_ereg(src_reg))
 787                                EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
 788                        else
 789                                EMIT2(0x66, 0x89);
 790                        goto stx;
 791                case BPF_STX | BPF_MEM | BPF_W:
 792                        if (is_ereg(dst_reg) || is_ereg(src_reg))
 793                                EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
 794                        else
 795                                EMIT1(0x89);
 796                        goto stx;
 797                case BPF_STX | BPF_MEM | BPF_DW:
 798                        EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
 799stx:                    if (is_imm8(insn->off))
 800                                EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
 801                        else
 802                                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
 803                                            insn->off);
 804                        break;
 805
 806                        /* LDX: dst_reg = *(u8*)(src_reg + off) */
 807                case BPF_LDX | BPF_MEM | BPF_B:
 808                        /* emit 'movzx rax, byte ptr [rax + off]' */
 809                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
 810                        goto ldx;
 811                case BPF_LDX | BPF_MEM | BPF_H:
 812                        /* emit 'movzx rax, word ptr [rax + off]' */
 813                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
 814                        goto ldx;
 815                case BPF_LDX | BPF_MEM | BPF_W:
 816                        /* emit 'mov eax, dword ptr [rax+0x14]' */
 817                        if (is_ereg(dst_reg) || is_ereg(src_reg))
 818                                EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
 819                        else
 820                                EMIT1(0x8B);
 821                        goto ldx;
 822                case BPF_LDX | BPF_MEM | BPF_DW:
 823                        /* emit 'mov rax, qword ptr [rax+0x14]' */
 824                        EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
  825ldx:                    /* if insn->off == 0 we could save one byte, but the
  826                         * special case of x86 r13, which always needs an offset,
  827                         * is not worth the hassle
  828                         */
 829                        if (is_imm8(insn->off))
 830                                EMIT2(add_2reg(0x40, src_reg, dst_reg), insn->off);
 831                        else
 832                                EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
 833                                            insn->off);
 834                        break;
 835
 836                        /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
 837                case BPF_STX | BPF_XADD | BPF_W:
 838                        /* emit 'lock add dword ptr [rax + off], eax' */
 839                        if (is_ereg(dst_reg) || is_ereg(src_reg))
 840                                EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
 841                        else
 842                                EMIT2(0xF0, 0x01);
 843                        goto xadd;
 844                case BPF_STX | BPF_XADD | BPF_DW:
 845                        EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
 846xadd:                   if (is_imm8(insn->off))
 847                                EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
 848                        else
 849                                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
 850                                            insn->off);
 851                        break;
 852
 853                        /* call */
 854                case BPF_JMP | BPF_CALL:
 855                        func = (u8 *) __bpf_call_base + imm32;
 856                        jmp_offset = func - (image + addrs[i]);
 857                        if (seen_ld_abs) {
 858                                reload_skb_data = bpf_helper_changes_pkt_data(func);
 859                                if (reload_skb_data) {
 860                                        EMIT1(0x57); /* push %rdi */
 861                                        jmp_offset += 22; /* pop, mov, sub, mov */
 862                                } else {
 863                                        EMIT2(0x41, 0x52); /* push %r10 */
 864                                        EMIT2(0x41, 0x51); /* push %r9 */
 865                                        /* need to adjust jmp offset, since
 866                                         * pop %r9, pop %r10 take 4 bytes after call insn
 867                                         */
 868                                        jmp_offset += 4;
 869                                }
 870                        }
 871                        if (!imm32 || !is_simm32(jmp_offset)) {
 872                                pr_err("unsupported bpf func %d addr %p image %p\n",
 873                                       imm32, func, image);
 874                                return -EINVAL;
 875                        }
 876                        EMIT1_off32(0xE8, jmp_offset);
 877                        if (seen_ld_abs) {
 878                                if (reload_skb_data) {
 879                                        EMIT1(0x5F); /* pop %rdi */
 880                                        emit_load_skb_data_hlen(&prog);
 881                                } else {
 882                                        EMIT2(0x41, 0x59); /* pop %r9 */
 883                                        EMIT2(0x41, 0x5A); /* pop %r10 */
 884                                }
 885                        }
 886                        break;
 887
 888                case BPF_JMP | BPF_TAIL_CALL:
 889                        emit_bpf_tail_call(&prog);
 890                        break;
 891
 892                        /* cond jump */
 893                case BPF_JMP | BPF_JEQ | BPF_X:
 894                case BPF_JMP | BPF_JNE | BPF_X:
 895                case BPF_JMP | BPF_JGT | BPF_X:
 896                case BPF_JMP | BPF_JLT | BPF_X:
 897                case BPF_JMP | BPF_JGE | BPF_X:
 898                case BPF_JMP | BPF_JLE | BPF_X:
 899                case BPF_JMP | BPF_JSGT | BPF_X:
 900                case BPF_JMP | BPF_JSLT | BPF_X:
 901                case BPF_JMP | BPF_JSGE | BPF_X:
 902                case BPF_JMP | BPF_JSLE | BPF_X:
 903                        /* cmp dst_reg, src_reg */
 904                        EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x39,
 905                              add_2reg(0xC0, dst_reg, src_reg));
 906                        goto emit_cond_jmp;
 907
 908                case BPF_JMP | BPF_JSET | BPF_X:
 909                        /* test dst_reg, src_reg */
 910                        EMIT3(add_2mod(0x48, dst_reg, src_reg), 0x85,
 911                              add_2reg(0xC0, dst_reg, src_reg));
 912                        goto emit_cond_jmp;
 913
 914                case BPF_JMP | BPF_JSET | BPF_K:
 915                        /* test dst_reg, imm32 */
 916                        EMIT1(add_1mod(0x48, dst_reg));
 917                        EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
 918                        goto emit_cond_jmp;
 919
 920                case BPF_JMP | BPF_JEQ | BPF_K:
 921                case BPF_JMP | BPF_JNE | BPF_K:
 922                case BPF_JMP | BPF_JGT | BPF_K:
 923                case BPF_JMP | BPF_JLT | BPF_K:
 924                case BPF_JMP | BPF_JGE | BPF_K:
 925                case BPF_JMP | BPF_JLE | BPF_K:
 926                case BPF_JMP | BPF_JSGT | BPF_K:
 927                case BPF_JMP | BPF_JSLT | BPF_K:
 928                case BPF_JMP | BPF_JSGE | BPF_K:
 929                case BPF_JMP | BPF_JSLE | BPF_K:
 930                        /* cmp dst_reg, imm8/32 */
 931                        EMIT1(add_1mod(0x48, dst_reg));
 932
 933                        if (is_imm8(imm32))
 934                                EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
 935                        else
 936                                EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
 937
 938emit_cond_jmp:          /* convert BPF opcode to x86 */
 939                        switch (BPF_OP(insn->code)) {
 940                        case BPF_JEQ:
 941                                jmp_cond = X86_JE;
 942                                break;
 943                        case BPF_JSET:
 944                        case BPF_JNE:
 945                                jmp_cond = X86_JNE;
 946                                break;
 947                        case BPF_JGT:
 948                                /* GT is unsigned '>', JA in x86 */
 949                                jmp_cond = X86_JA;
 950                                break;
 951                        case BPF_JLT:
 952                                /* LT is unsigned '<', JB in x86 */
 953                                jmp_cond = X86_JB;
 954                                break;
 955                        case BPF_JGE:
 956                                /* GE is unsigned '>=', JAE in x86 */
 957                                jmp_cond = X86_JAE;
 958                                break;
 959                        case BPF_JLE:
 960                                /* LE is unsigned '<=', JBE in x86 */
 961                                jmp_cond = X86_JBE;
 962                                break;
 963                        case BPF_JSGT:
 964                                /* signed '>', GT in x86 */
 965                                jmp_cond = X86_JG;
 966                                break;
 967                        case BPF_JSLT:
 968                                /* signed '<', LT in x86 */
 969                                jmp_cond = X86_JL;
 970                                break;
 971                        case BPF_JSGE:
 972                                /* signed '>=', GE in x86 */
 973                                jmp_cond = X86_JGE;
 974                                break;
 975                        case BPF_JSLE:
 976                                /* signed '<=', LE in x86 */
 977                                jmp_cond = X86_JLE;
 978                                break;
 979                        default: /* to silence gcc warning */
 980                                return -EFAULT;
 981                        }
 982                        jmp_offset = addrs[i + insn->off] - addrs[i];
 983                        if (is_imm8(jmp_offset)) {
 984                                EMIT2(jmp_cond, jmp_offset);
 985                        } else if (is_simm32(jmp_offset)) {
 986                                EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
 987                        } else {
 988                                pr_err("cond_jmp gen bug %llx\n", jmp_offset);
 989                                return -EFAULT;
 990                        }
 991
 992                        break;
 993
 994                case BPF_JMP | BPF_JA:
 995                        jmp_offset = addrs[i + insn->off] - addrs[i];
 996                        if (!jmp_offset)
 997                                /* optimize out nop jumps */
 998                                break;
 999emit_jmp:
1000                        if (is_imm8(jmp_offset)) {
1001                                EMIT2(0xEB, jmp_offset);
1002                        } else if (is_simm32(jmp_offset)) {
1003                                EMIT1_off32(0xE9, jmp_offset);
1004                        } else {
1005                                pr_err("jmp gen bug %llx\n", jmp_offset);
1006                                return -EFAULT;
1007                        }
1008                        break;
1009
1010                case BPF_LD | BPF_IND | BPF_W:
1011                        func = sk_load_word;
1012                        goto common_load;
1013                case BPF_LD | BPF_ABS | BPF_W:
1014                        func = CHOOSE_LOAD_FUNC(imm32, sk_load_word);
1015common_load:
1016                        ctx->seen_ld_abs = seen_ld_abs = true;
1017                        jmp_offset = func - (image + addrs[i]);
1018                        if (!func || !is_simm32(jmp_offset)) {
1019                                pr_err("unsupported bpf func %d addr %p image %p\n",
1020                                       imm32, func, image);
1021                                return -EINVAL;
1022                        }
1023                        if (BPF_MODE(insn->code) == BPF_ABS) {
1024                                /* mov %esi, imm32 */
1025                                EMIT1_off32(0xBE, imm32);
1026                        } else {
1027                                /* mov %rsi, src_reg */
1028                                EMIT_mov(BPF_REG_2, src_reg);
1029                                if (imm32) {
1030                                        if (is_imm8(imm32))
1031                                                /* add %esi, imm8 */
1032                                                EMIT3(0x83, 0xC6, imm32);
1033                                        else
1034                                                /* add %esi, imm32 */
1035                                                EMIT2_off32(0x81, 0xC6, imm32);
1036                                }
1037                        }
 1038                        /* skb pointer is in R6 (%rbx); it will be copied into
 1039                         * %rdi if a skb_copy_bits() call is necessary.
 1040                         * sk_load_* helpers also use %r10 and %r9d.
 1041                         * See bpf_jit.S
 1042                         */
1043                        if (seen_ax_reg)
1044                                /* r10 = skb->data, mov %r10, off32(%rbx) */
1045                                EMIT3_off32(0x4c, 0x8b, 0x93,
1046                                            offsetof(struct sk_buff, data));
1047                        EMIT1_off32(0xE8, jmp_offset); /* call */
1048                        break;
1049
1050                case BPF_LD | BPF_IND | BPF_H:
1051                        func = sk_load_half;
1052                        goto common_load;
1053                case BPF_LD | BPF_ABS | BPF_H:
1054                        func = CHOOSE_LOAD_FUNC(imm32, sk_load_half);
1055                        goto common_load;
1056                case BPF_LD | BPF_IND | BPF_B:
1057                        func = sk_load_byte;
1058                        goto common_load;
1059                case BPF_LD | BPF_ABS | BPF_B:
1060                        func = CHOOSE_LOAD_FUNC(imm32, sk_load_byte);
1061                        goto common_load;
1062
1063                case BPF_JMP | BPF_EXIT:
1064                        if (seen_exit) {
1065                                jmp_offset = ctx->cleanup_addr - addrs[i];
1066                                goto emit_jmp;
1067                        }
1068                        seen_exit = true;
1069                        /* update cleanup_addr */
1070                        ctx->cleanup_addr = proglen;
1071                        /* mov rbx, qword ptr [rbp+0] */
1072                        EMIT4(0x48, 0x8B, 0x5D, 0);
1073                        /* mov r13, qword ptr [rbp+8] */
1074                        EMIT4(0x4C, 0x8B, 0x6D, 8);
1075                        /* mov r14, qword ptr [rbp+16] */
1076                        EMIT4(0x4C, 0x8B, 0x75, 16);
1077                        /* mov r15, qword ptr [rbp+24] */
1078                        EMIT4(0x4C, 0x8B, 0x7D, 24);
1079
1080                        /* add rbp, AUX_STACK_SPACE */
1081                        EMIT4(0x48, 0x83, 0xC5, AUX_STACK_SPACE);
1082                        EMIT1(0xC9); /* leave */
1083                        EMIT1(0xC3); /* ret */
1084                        break;
1085
1086                default:
 1087                        /* By design the x64 JIT should support all BPF instructions.
 1088                         * This error will be seen if a new instruction was added
 1089                         * to the interpreter but not to the JIT,
 1090                         * or if there is junk in bpf_prog.
 1091                         */
1092                        pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
1093                        return -EINVAL;
1094                }
1095
1096                ilen = prog - temp;
1097                if (ilen > BPF_MAX_INSN_SIZE) {
1098                        pr_err("bpf_jit: fatal insn size error\n");
1099                        return -EFAULT;
1100                }
1101
1102                if (image) {
1103                        if (unlikely(proglen + ilen > oldproglen)) {
1104                                pr_err("bpf_jit: fatal error\n");
1105                                return -EFAULT;
1106                        }
1107                        memcpy(image + proglen, temp, ilen);
1108                }
1109                proglen += ilen;
1110                addrs[i] = proglen;
1111                prog = temp;
1112        }
1113        return proglen;
1114}
1115
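/* JIT state kept across invocations of bpf_int_jit_compile(), so that programs
 * with bpf-to-bpf calls (prog->is_func) can be given a final extra pass once
 * the addresses of all their functions are known.
 */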
1116struct x64_jit_data {
1117        struct bpf_binary_header *header;
1118        int *addrs;
1119        u8 *image;
1120        int proglen;
1121        struct jit_context ctx;
1122};
1123
1124struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1125{
1126        struct bpf_binary_header *header = NULL;
1127        struct bpf_prog *tmp, *orig_prog = prog;
1128        struct x64_jit_data *jit_data;
1129        int proglen, oldproglen = 0;
1130        struct jit_context ctx = {};
1131        bool tmp_blinded = false;
1132        bool extra_pass = false;
1133        u8 *image = NULL;
1134        int *addrs;
1135        int pass;
1136        int i;
1137
1138        if (!prog->jit_requested)
1139                return orig_prog;
1140
1141        tmp = bpf_jit_blind_constants(prog);
1142        /* If blinding was requested and we failed during blinding,
1143         * we must fall back to the interpreter.
1144         */
1145        if (IS_ERR(tmp))
1146                return orig_prog;
1147        if (tmp != prog) {
1148                tmp_blinded = true;
1149                prog = tmp;
1150        }
1151
1152        jit_data = prog->aux->jit_data;
1153        if (!jit_data) {
1154                jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1155                if (!jit_data) {
1156                        prog = orig_prog;
1157                        goto out;
1158                }
1159                prog->aux->jit_data = jit_data;
1160        }
1161        addrs = jit_data->addrs;
1162        if (addrs) {
1163                ctx = jit_data->ctx;
1164                oldproglen = jit_data->proglen;
1165                image = jit_data->image;
1166                header = jit_data->header;
1167                extra_pass = true;
1168                goto skip_init_addrs;
1169        }
1170        addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
1171        if (!addrs) {
1172                prog = orig_prog;
1173                goto out_addrs;
1174        }
1175
 1176        /* Before the first pass, make a rough estimate of addrs[]:
 1177         * each BPF instruction is translated to less than 64 bytes.
 1178         */
1179        for (proglen = 0, i = 0; i < prog->len; i++) {
1180                proglen += 64;
1181                addrs[i] = proglen;
1182        }
1183        ctx.cleanup_addr = proglen;
1184skip_init_addrs:
1185
 1186        /* JITed image shrinks with every pass and the loop iterates
 1187         * until the image stops shrinking. Very large BPF programs
 1188         * may only converge on the last pass; in such a case do one more
 1189         * pass to emit the final image.
 1190         */
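        /* While image is NULL the passes only measure sizes; once two
         * consecutive passes report the same proglen, the binary area is
         * allocated and one more pass writes the final instruction bytes.
         */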
1191        for (pass = 0; pass < 20 || image; pass++) {
1192                proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
1193                if (proglen <= 0) {
1194                        image = NULL;
1195                        if (header)
1196                                bpf_jit_binary_free(header);
1197                        prog = orig_prog;
1198                        goto out_addrs;
1199                }
1200                if (image) {
1201                        if (proglen != oldproglen) {
1202                                pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
1203                                       proglen, oldproglen);
1204                                prog = orig_prog;
1205                                goto out_addrs;
1206                        }
1207                        break;
1208                }
1209                if (proglen == oldproglen) {
1210                        header = bpf_jit_binary_alloc(proglen, &image,
1211                                                      1, jit_fill_hole);
1212                        if (!header) {
1213                                prog = orig_prog;
1214                                goto out_addrs;
1215                        }
1216                }
1217                oldproglen = proglen;
1218                cond_resched();
1219        }
1220
1221        if (bpf_jit_enable > 1)
1222                bpf_jit_dump(prog->len, proglen, pass + 1, image);
1223
1224        if (image) {
1225                bpf_flush_icache(header, image + proglen);
1226                if (!prog->is_func || extra_pass) {
1227                        bpf_jit_binary_lock_ro(header);
1228                } else {
1229                        jit_data->addrs = addrs;
1230                        jit_data->ctx = ctx;
1231                        jit_data->proglen = proglen;
1232                        jit_data->image = image;
1233                        jit_data->header = header;
1234                }
1235                prog->bpf_func = (void *)image;
1236                prog->jited = 1;
1237                prog->jited_len = proglen;
1238        } else {
1239                prog = orig_prog;
1240        }
1241
1242        if (!prog->is_func || extra_pass) {
1243out_addrs:
1244                kfree(addrs);
1245                kfree(jit_data);
1246                prog->aux->jit_data = NULL;
1247        }
1248out:
1249        if (tmp_blinded)
1250                bpf_jit_prog_release_other(prog, prog == orig_prog ?
1251                                           tmp : orig_prog);
1252        return prog;
1253}
1254