linux/arch/x86/net/bpf_jit_comp.c
   1/*
   2 * bpf_jit_comp.c: BPF JIT compiler
   3 *
   4 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
   5 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
   6 *
   7 * This program is free software; you can redistribute it and/or
   8 * modify it under the terms of the GNU General Public License
   9 * as published by the Free Software Foundation; version 2
  10 * of the License.
  11 */
  12#include <linux/netdevice.h>
  13#include <linux/filter.h>
  14#include <linux/if_vlan.h>
  15#include <linux/bpf.h>
  16#include <linux/memory.h>
  17#include <linux/sort.h>
  18#include <asm/extable.h>
  19#include <asm/set_memory.h>
  20#include <asm/nospec-branch.h>
  21#include <asm/text-patching.h>
  22#include <asm/asm-prototypes.h>
  23
  24static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
  25{
  26        if (len == 1)
  27                *ptr = bytes;
  28        else if (len == 2)
  29                *(u16 *)ptr = bytes;
  30        else {
  31                *(u32 *)ptr = bytes;
  32                barrier();
  33        }
  34        return ptr + len;
  35}
  36
  37#define EMIT(bytes, len) \
  38        do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)
  39
  40#define EMIT1(b1)               EMIT(b1, 1)
  41#define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
  42#define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
  43#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
  44
  45#define EMIT1_off32(b1, off) \
  46        do { EMIT1(b1); EMIT(off, 4); } while (0)
  47#define EMIT2_off32(b1, b2, off) \
  48        do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
  49#define EMIT3_off32(b1, b2, b3, off) \
  50        do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
  51#define EMIT4_off32(b1, b2, b3, b4, off) \
  52        do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
  53
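    /*
     * For example, EMIT3(0x48, 0x89, 0xE5) packs the bytes little-endian
     * into one u32 and hands them to emit_code(), which stores the byte
     * sequence 48 89 e5 ("mov rbp, rsp") and advances the output pointer
     * by exactly three bytes even though it may write a full u32.
     */
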
  54static bool is_imm8(int value)
  55{
  56        return value <= 127 && value >= -128;
  57}
  58
  59static bool is_simm32(s64 value)
  60{
  61        return value == (s64)(s32)value;
  62}
  63
  64static bool is_uimm32(u64 value)
  65{
  66        return value == (u64)(u32)value;
  67}
  68
  69/* mov dst, src */
  70#define EMIT_mov(DST, SRC)                                                               \
  71        do {                                                                             \
  72                if (DST != SRC)                                                          \
  73                        EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
  74        } while (0)
  75
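    /*
     * For example, EMIT_mov(BPF_REG_0, BPF_REG_1) expands to
     * EMIT3(0x48, 0x89, 0xF8), i.e. "mov rax, rdi": REX.W prefix, opcode
     * 0x89 (mov r/m64, r64) and a register-direct ModRM byte built by
     * add_2reg() below. Nothing is emitted when DST == SRC.
     */
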
  76static int bpf_size_to_x86_bytes(int bpf_size)
  77{
  78        if (bpf_size == BPF_W)
  79                return 4;
  80        else if (bpf_size == BPF_H)
  81                return 2;
  82        else if (bpf_size == BPF_B)
  83                return 1;
  84        else if (bpf_size == BPF_DW)
  85                return 4; /* imm32 */
  86        else
  87                return 0;
  88}
  89
  90/*
  91 * List of x86 cond jumps opcodes (. + s8)
  92 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
  93 */
  94#define X86_JB  0x72
  95#define X86_JAE 0x73
  96#define X86_JE  0x74
  97#define X86_JNE 0x75
  98#define X86_JBE 0x76
  99#define X86_JA  0x77
 100#define X86_JL  0x7C
 101#define X86_JGE 0x7D
 102#define X86_JLE 0x7E
 103#define X86_JG  0x7F
 104
 105/* Pick a register outside of BPF range for JIT internal work */
 106#define AUX_REG (MAX_BPF_JIT_REG + 1)
 107#define X86_REG_R9 (MAX_BPF_JIT_REG + 2)
 108
 109/*
 110 * The following table maps BPF registers to x86-64 registers.
 111 *
 112 * x86-64 register R12 is unused, since if used as base address
 113 * register in load/store instructions, it always needs an
 114 * extra byte of encoding and is callee saved.
 115 *
 116 * x86-64 register R9 is not used by BPF programs, but can be used by BPF
 117 * trampoline. x86-64 register R10 is used for blinding (if enabled).
 118 */
 119static const int reg2hex[] = {
 120        [BPF_REG_0] = 0,  /* RAX */
 121        [BPF_REG_1] = 7,  /* RDI */
 122        [BPF_REG_2] = 6,  /* RSI */
 123        [BPF_REG_3] = 2,  /* RDX */
 124        [BPF_REG_4] = 1,  /* RCX */
 125        [BPF_REG_5] = 0,  /* R8  */
 126        [BPF_REG_6] = 3,  /* RBX callee saved */
 127        [BPF_REG_7] = 5,  /* R13 callee saved */
 128        [BPF_REG_8] = 6,  /* R14 callee saved */
 129        [BPF_REG_9] = 7,  /* R15 callee saved */
 130        [BPF_REG_FP] = 5, /* RBP readonly */
 131        [BPF_REG_AX] = 2, /* R10 temp register */
 132        [AUX_REG] = 3,    /* R11 temp register */
 133        [X86_REG_R9] = 1, /* R9 register, 6th function argument */
 134};
 135
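    /*
     * Only the low three bits of the hardware register number are kept
     * here; the fourth bit needed for r8..r15 is supplied via a REX
     * prefix by is_ereg()/add_*mod() below. E.g. BPF_REG_8 maps to R14:
     * reg2hex gives 6 and the REX.R/REX.B bit contributes the missing 8.
     */
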
 136static const int reg2pt_regs[] = {
 137        [BPF_REG_0] = offsetof(struct pt_regs, ax),
 138        [BPF_REG_1] = offsetof(struct pt_regs, di),
 139        [BPF_REG_2] = offsetof(struct pt_regs, si),
 140        [BPF_REG_3] = offsetof(struct pt_regs, dx),
 141        [BPF_REG_4] = offsetof(struct pt_regs, cx),
 142        [BPF_REG_5] = offsetof(struct pt_regs, r8),
 143        [BPF_REG_6] = offsetof(struct pt_regs, bx),
 144        [BPF_REG_7] = offsetof(struct pt_regs, r13),
 145        [BPF_REG_8] = offsetof(struct pt_regs, r14),
 146        [BPF_REG_9] = offsetof(struct pt_regs, r15),
 147};
 148
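    /*
     * reg2pt_regs[] is used by the BPF_PROBE_MEM handling below:
     * ex->fixup records the pt_regs offset of the destination register
     * so that ex_handler_bpf() knows which saved register to zero when
     * a probed load faults.
     */
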
 149/*
 150 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
 151 * which need extra byte of encoding.
 152 * rax,rcx,...,rbp have simpler encoding
 153 */
 154static bool is_ereg(u32 reg)
 155{
 156        return (1 << reg) & (BIT(BPF_REG_5) |
 157                             BIT(AUX_REG) |
 158                             BIT(BPF_REG_7) |
 159                             BIT(BPF_REG_8) |
 160                             BIT(BPF_REG_9) |
 161                             BIT(X86_REG_R9) |
 162                             BIT(BPF_REG_AX));
 163}
 164
 165/*
 166 * is_ereg_8l() == true if BPF register 'reg' is mapped to access x86-64
 167 * lower 8-bit registers dil,sil,bpl,spl,r8b..r15b, which need extra byte
 168 * of encoding. al,cl,dl,bl have simpler encoding.
 169 */
 170static bool is_ereg_8l(u32 reg)
 171{
 172        return is_ereg(reg) ||
 173            (1 << reg) & (BIT(BPF_REG_1) |
 174                          BIT(BPF_REG_2) |
 175                          BIT(BPF_REG_FP));
 176}
 177
 178static bool is_axreg(u32 reg)
 179{
 180        return reg == BPF_REG_0;
 181}
 182
 183/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
 184static u8 add_1mod(u8 byte, u32 reg)
 185{
 186        if (is_ereg(reg))
 187                byte |= 1;
 188        return byte;
 189}
 190
 191static u8 add_2mod(u8 byte, u32 r1, u32 r2)
 192{
 193        if (is_ereg(r1))
 194                byte |= 1;
 195        if (is_ereg(r2))
 196                byte |= 4;
 197        return byte;
 198}
 199
 200/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
 201static u8 add_1reg(u8 byte, u32 dst_reg)
 202{
 203        return byte + reg2hex[dst_reg];
 204}
 205
 206/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
 207static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
 208{
 209        return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
 210}
 211
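    /*
     * 'byte' is the ModRM template: mod(2) | reg(3) | rm(3). 0xC0 selects
     * register-direct addressing, while the 0x40/0x80 templates used by
     * the load/store emitters below select the disp8/disp32 forms. The
     * first register argument lands in the rm field, the second in reg.
     */
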
 212static void jit_fill_hole(void *area, unsigned int size)
 213{
 214        /* Fill whole space with INT3 instructions */
 215        memset(area, 0xcc, size);
 216}
 217
 218struct jit_context {
 219        int cleanup_addr; /* Epilogue code offset */
 220};
 221
 222/* Maximum number of bytes emitted while JITing one eBPF insn */
 223#define BPF_MAX_INSN_SIZE       128
 224#define BPF_INSN_SAFETY         64
 225
 226/* Number of bytes emit_patch() needs to generate instructions */
 227#define X86_PATCH_SIZE          5
 228
 229#define PROLOGUE_SIZE           25
 230
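    /*
     * Byte accounting for PROLOGUE_SIZE: 5 bytes of NOP for the patch
     * site, push rbp (1), mov rbp,rsp (3), sub rsp,imm32 (7), push rbx
     * (1), push r13/r14/r15 (3 x 2) and "push 0" for tail_call_cnt (2)
     * = 25 bytes. The BUILD_BUG_ON() in emit_prologue() keeps this in
     * sync with the emitted code.
     */
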
 231/*
 232 * Emit x86-64 prologue code for BPF program and check its size.
 233 * bpf_tail_call helper will skip it while jumping into another program
 234 */
 235static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
 236{
 237        u8 *prog = *pprog;
 238        int cnt = X86_PATCH_SIZE;
 239
 240        /* BPF trampoline can be made to work without these nops,
 241         * but let's waste 5 bytes for now and optimize later
 242         */
 243        memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt);
 244        prog += cnt;
 245        EMIT1(0x55);             /* push rbp */
 246        EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
 247        /* sub rsp, rounded_stack_depth */
 248        EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
 249        EMIT1(0x53);             /* push rbx */
 250        EMIT2(0x41, 0x55);       /* push r13 */
 251        EMIT2(0x41, 0x56);       /* push r14 */
 252        EMIT2(0x41, 0x57);       /* push r15 */
 253        if (!ebpf_from_cbpf) {
 254                /* zero init tail_call_cnt */
 255                EMIT2(0x6a, 0x00);
 256                BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
 257        }
 258        *pprog = prog;
 259}
 260
 261static int emit_patch(u8 **pprog, void *func, void *ip, u8 opcode)
 262{
 263        u8 *prog = *pprog;
 264        int cnt = 0;
 265        s64 offset;
 266
 267        offset = func - (ip + X86_PATCH_SIZE);
 268        if (!is_simm32(offset)) {
 269                pr_err("Target call %p is out of range\n", func);
 270                return -ERANGE;
 271        }
 272        EMIT1_off32(opcode, offset);
 273        *pprog = prog;
 274        return 0;
 275}
 276
 277static int emit_call(u8 **pprog, void *func, void *ip)
 278{
 279        return emit_patch(pprog, func, ip, 0xE8);
 280}
 281
 282static int emit_jump(u8 **pprog, void *func, void *ip)
 283{
 284        return emit_patch(pprog, func, ip, 0xE9);
 285}
 286
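    /*
     * Both patchable forms are exactly X86_PATCH_SIZE (5) bytes: opcode
     * 0xE8 is "call rel32" and 0xE9 is "jmp rel32", one opcode byte plus
     * a 32-bit displacement relative to the next instruction.
     */
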
 287static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
 288                                void *old_addr, void *new_addr,
 289                                const bool text_live)
 290{
 291        const u8 *nop_insn = ideal_nops[NOP_ATOMIC5];
 292        u8 old_insn[X86_PATCH_SIZE];
 293        u8 new_insn[X86_PATCH_SIZE];
 294        u8 *prog;
 295        int ret;
 296
 297        memcpy(old_insn, nop_insn, X86_PATCH_SIZE);
 298        if (old_addr) {
 299                prog = old_insn;
 300                ret = t == BPF_MOD_CALL ?
 301                      emit_call(&prog, old_addr, ip) :
 302                      emit_jump(&prog, old_addr, ip);
 303                if (ret)
 304                        return ret;
 305        }
 306
 307        memcpy(new_insn, nop_insn, X86_PATCH_SIZE);
 308        if (new_addr) {
 309                prog = new_insn;
 310                ret = t == BPF_MOD_CALL ?
 311                      emit_call(&prog, new_addr, ip) :
 312                      emit_jump(&prog, new_addr, ip);
 313                if (ret)
 314                        return ret;
 315        }
 316
 317        ret = -EBUSY;
 318        mutex_lock(&text_mutex);
 319        if (memcmp(ip, old_insn, X86_PATCH_SIZE))
 320                goto out;
 321        if (memcmp(ip, new_insn, X86_PATCH_SIZE)) {
 322                if (text_live)
 323                        text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL);
 324                else
 325                        memcpy(ip, new_insn, X86_PATCH_SIZE);
 326        }
 327        ret = 0;
 328out:
 329        mutex_unlock(&text_mutex);
 330        return ret;
 331}
 332
 333int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
 334                       void *old_addr, void *new_addr)
 335{
 336        if (!is_kernel_text((long)ip) &&
 337            !is_bpf_text_address((long)ip))
 338                /* BPF poking in modules is not supported */
 339                return -EINVAL;
 340
 341        return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
 342}
 343
 344/*
 345 * Generate the following code:
 346 *
 347 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 348 *   if (index >= array->map.max_entries)
 349 *     goto out;
 350 *   if (++tail_call_cnt > MAX_TAIL_CALL_CNT)
 351 *     goto out;
 352 *   prog = array->ptrs[index];
 353 *   if (prog == NULL)
 354 *     goto out;
 355 *   goto *(prog->bpf_func + prologue_size);
 356 * out:
 357 */
 358static void emit_bpf_tail_call_indirect(u8 **pprog)
 359{
 360        u8 *prog = *pprog;
 361        int label1, label2, label3;
 362        int cnt = 0;
 363
 364        /*
 365         * rdi - pointer to ctx
 366         * rsi - pointer to bpf_array
 367         * rdx - index in bpf_array
 368         */
 369
 370        /*
 371         * if (index >= array->map.max_entries)
 372         *      goto out;
 373         */
 374        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
 375        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
 376              offsetof(struct bpf_array, map.max_entries));
 377#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
 378        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
 379        label1 = cnt;
 380
 381        /*
 382         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 383         *      goto out;
 384         */
 385        EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
 386        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
 387#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
 388        EMIT2(X86_JA, OFFSET2);                   /* ja out */
 389        label2 = cnt;
 390        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
 391        EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
 392
 393        /* prog = array->ptrs[index]; */
 394        EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
 395                    offsetof(struct bpf_array, ptrs));
 396
 397        /*
 398         * if (prog == NULL)
 399         *      goto out;
 400         */
 401        EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
 402#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
 403        EMIT2(X86_JE, OFFSET3);                   /* je out */
 404        label3 = cnt;
 405
 406        /* goto *(prog->bpf_func + prologue_size); */
 407        EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
 408              offsetof(struct bpf_prog, bpf_func));
 409        EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */
 410
 411        /*
 412         * Wow we're ready to jump into next BPF program
 413         * rdi == ctx (1st arg)
 414         * rax == prog->bpf_func + prologue_size
 415         */
 416        RETPOLINE_RAX_BPF_JIT();
 417
 418        /* out: */
 419        BUILD_BUG_ON(cnt - label1 != OFFSET1);
 420        BUILD_BUG_ON(cnt - label2 != OFFSET2);
 421        BUILD_BUG_ON(cnt - label3 != OFFSET3);
 422        *pprog = prog;
 423}
 424
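    /*
     * OFFSET1/2/3 above are hand-counted byte distances from each
     * conditional jump to the common "out:" label, including the size of
     * the retpoline sequence emitted by RETPOLINE_RAX_BPF_JIT(). The
     * BUILD_BUG_ON() checks at the end of the function catch any drift
     * when the emitted sequence changes.
     */
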
 425static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
 426                                      u8 **pprog, int addr, u8 *image)
 427{
 428        u8 *prog = *pprog;
 429        int cnt = 0;
 430
 431        /*
 432         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 433         *      goto out;
 434         */
 435        EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
 436        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);         /* cmp eax, MAX_TAIL_CALL_CNT */
 437        EMIT2(X86_JA, 14);                            /* ja out */
 438        EMIT3(0x83, 0xC0, 0x01);                      /* add eax, 1 */
 439        EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */
 440
 441        poke->ip = image + (addr - X86_PATCH_SIZE);
 442        poke->adj_off = PROLOGUE_SIZE;
 443
 444        memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE);
 445        prog += X86_PATCH_SIZE;
 446        /* out: */
 447
 448        *pprog = prog;
 449}
 450
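    /*
     * The 5-byte NOP emitted above is the poke target recorded in
     * poke->ip. bpf_tail_call_direct_fixup() below (and later map updates
     * going through bpf_arch_text_poke()) rewrite it into a direct jmp to
     * target->bpf_func + PROLOGUE_SIZE, skipping the target's prologue.
     */
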
 451static void bpf_tail_call_direct_fixup(struct bpf_prog *prog)
 452{
 453        struct bpf_jit_poke_descriptor *poke;
 454        struct bpf_array *array;
 455        struct bpf_prog *target;
 456        int i, ret;
 457
 458        for (i = 0; i < prog->aux->size_poke_tab; i++) {
 459                poke = &prog->aux->poke_tab[i];
 460                WARN_ON_ONCE(READ_ONCE(poke->ip_stable));
 461
 462                if (poke->reason != BPF_POKE_REASON_TAIL_CALL)
 463                        continue;
 464
 465                array = container_of(poke->tail_call.map, struct bpf_array, map);
 466                mutex_lock(&array->aux->poke_mutex);
 467                target = array->ptrs[poke->tail_call.key];
 468                if (target) {
 469                        /* A plain memcpy is used while the image is not live
 470                         * yet and not yet locked as read-only. Once the poke
 471                         * location is active (poke->ip_stable), a parallel
 472                         * bpf_arch_text_poke() may still hit the read-write
 473                         * image until we finally lock it as read-only. Both
 474                         * kinds of modification on a given image are serialized
 475                         * by text_mutex to avoid interference.
 476                         */
 477                        ret = __bpf_arch_text_poke(poke->ip, BPF_MOD_JUMP, NULL,
 478                                                   (u8 *)target->bpf_func +
 479                                                   poke->adj_off, false);
 480                        BUG_ON(ret < 0);
 481                }
 482                WRITE_ONCE(poke->ip_stable, true);
 483                mutex_unlock(&array->aux->poke_mutex);
 484        }
 485}
 486
 487static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
 488                           u32 dst_reg, const u32 imm32)
 489{
 490        u8 *prog = *pprog;
 491        u8 b1, b2, b3;
 492        int cnt = 0;
 493
 494        /*
 495         * Optimization: if imm32 is positive, use 'mov %eax, imm32'
 496         * (which zero-extends imm32) to save 2 bytes.
 497         */
 498        if (sign_propagate && (s32)imm32 < 0) {
 499                /* 'mov %rax, imm32' sign extends imm32 */
 500                b1 = add_1mod(0x48, dst_reg);
 501                b2 = 0xC7;
 502                b3 = 0xC0;
 503                EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
 504                goto done;
 505        }
 506
 507        /*
 508         * Optimization: if imm32 is zero, use 'xor %eax, %eax'
 509         * to save 3 bytes.
 510         */
 511        if (imm32 == 0) {
 512                if (is_ereg(dst_reg))
 513                        EMIT1(add_2mod(0x40, dst_reg, dst_reg));
 514                b2 = 0x31; /* xor */
 515                b3 = 0xC0;
 516                EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
 517                goto done;
 518        }
 519
 520        /* mov %eax, imm32 */
 521        if (is_ereg(dst_reg))
 522                EMIT1(add_1mod(0x40, dst_reg));
 523        EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
 524done:
 525        *pprog = prog;
 526}
 527
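    /*
     * Example encodings: emit_mov_imm32(&prog, false, BPF_REG_0, 1)
     * emits b8 01 00 00 00 ("mov eax, 1"); with sign_propagate set and
     * imm32 == -1 it emits 48 c7 c0 ff ff ff ff ("mov rax, -1"); and
     * imm32 == 0 becomes the two-byte 31 c0 ("xor eax, eax").
     */
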
 528static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
 529                           const u32 imm32_hi, const u32 imm32_lo)
 530{
 531        u8 *prog = *pprog;
 532        int cnt = 0;
 533
 534        if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
 535                /*
 536                 * For emitting a plain u32, where the sign bit must not be
 537                 * propagated, LLVM tends to load an imm64 rather than a
 538                 * mov32 directly, so save a couple of bytes by just doing
 539                 * 'mov %eax, imm32' instead.
 540                 */
 541                emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
 542        } else {
 543                /* movabsq %rax, imm64 */
 544                EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
 545                EMIT(imm32_lo, 4);
 546                EMIT(imm32_hi, 4);
 547        }
 548
 549        *pprog = prog;
 550}
 551
 552static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
 553{
 554        u8 *prog = *pprog;
 555        int cnt = 0;
 556
 557        if (is64) {
 558                /* mov dst, src */
 559                EMIT_mov(dst_reg, src_reg);
 560        } else {
 561                /* mov32 dst, src */
 562                if (is_ereg(dst_reg) || is_ereg(src_reg))
 563                        EMIT1(add_2mod(0x40, dst_reg, src_reg));
 564                EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
 565        }
 566
 567        *pprog = prog;
 568}
 569
 570/* LDX: dst_reg = *(u8*)(src_reg + off) */
 571static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
 572{
 573        u8 *prog = *pprog;
 574        int cnt = 0;
 575
 576        switch (size) {
 577        case BPF_B:
 578                /* Emit 'movzx rax, byte ptr [rax + off]' */
 579                EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
 580                break;
 581        case BPF_H:
 582                /* Emit 'movzx rax, word ptr [rax + off]' */
 583                EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
 584                break;
 585        case BPF_W:
 586                /* Emit 'mov eax, dword ptr [rax+0x14]' */
 587                if (is_ereg(dst_reg) || is_ereg(src_reg))
 588                        EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
 589                else
 590                        EMIT1(0x8B);
 591                break;
 592        case BPF_DW:
 593                /* Emit 'mov rax, qword ptr [rax+0x14]' */
 594                EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
 595                break;
 596        }
 597        /*
 598         * If insn->off == 0 we could save one extra byte, but the
 599         * special case of x86 R13, which always needs a displacement,
 600         * is not worth the hassle
 601         */
 602        if (is_imm8(off))
 603                EMIT2(add_2reg(0x40, src_reg, dst_reg), off);
 604        else
 605                EMIT1_off32(add_2reg(0x80, src_reg, dst_reg), off);
 606        *pprog = prog;
 607}
 608
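    /*
     * Example: emit_ldx(&prog, BPF_W, BPF_REG_0, BPF_REG_1, 20) emits
     * 8b 47 14, i.e. "mov eax, dword ptr [rdi+0x14]": opcode 0x8B plus a
     * ModRM byte with mod=01 (disp8), reg=eax, rm=rdi.
     */
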
 609/* STX: *(u8*)(dst_reg + off) = src_reg */
 610static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
 611{
 612        u8 *prog = *pprog;
 613        int cnt = 0;
 614
 615        switch (size) {
 616        case BPF_B:
 617                /* Emit 'mov byte ptr [rax + off], al' */
 618                if (is_ereg(dst_reg) || is_ereg_8l(src_reg))
 619                        /* Add extra byte for eregs or SIL,DIL,BPL in src_reg */
 620                        EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
 621                else
 622                        EMIT1(0x88);
 623                break;
 624        case BPF_H:
 625                if (is_ereg(dst_reg) || is_ereg(src_reg))
 626                        EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
 627                else
 628                        EMIT2(0x66, 0x89);
 629                break;
 630        case BPF_W:
 631                if (is_ereg(dst_reg) || is_ereg(src_reg))
 632                        EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
 633                else
 634                        EMIT1(0x89);
 635                break;
 636        case BPF_DW:
 637                EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
 638                break;
 639        }
 640        if (is_imm8(off))
 641                EMIT2(add_2reg(0x40, dst_reg, src_reg), off);
 642        else
 643                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg), off);
 644        *pprog = prog;
 645}
 646
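    /*
     * Example: emit_stx(&prog, BPF_H, BPF_REG_1, BPF_REG_0, 0) emits
     * 66 89 47 00, i.e. "mov word ptr [rdi+0x0], ax": the 0x66 operand
     * size prefix selects the 16-bit store, and a disp8 of zero is still
     * emitted (see the insn->off comment in emit_ldx() above).
     */
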
 647static bool ex_handler_bpf(const struct exception_table_entry *x,
 648                           struct pt_regs *regs, int trapnr,
 649                           unsigned long error_code, unsigned long fault_addr)
 650{
 651        u32 reg = x->fixup >> 8;
 652
 653        /* jump over faulting load and clear dest register */
 654        *(unsigned long *)((void *)regs + reg) = 0;
 655        regs->ip += x->fixup & 0xff;
 656        return true;
 657}
 658
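    /*
     * x->fixup is packed by do_jit() below as
     * (x86 insn length) | (offsetof(struct pt_regs, <dst>) << 8), so the
     * handler zeroes the faulting load's destination register and skips
     * the instruction, letting the BPF program continue with a zero value.
     */
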
 659static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
 660                  int oldproglen, struct jit_context *ctx)
 661{
 662        struct bpf_insn *insn = bpf_prog->insnsi;
 663        int insn_cnt = bpf_prog->len;
 664        bool seen_exit = false;
 665        u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
 666        int i, cnt = 0, excnt = 0;
 667        int proglen = 0;
 668        u8 *prog = temp;
 669
 670        emit_prologue(&prog, bpf_prog->aux->stack_depth,
 671                      bpf_prog_was_classic(bpf_prog));
 672        addrs[0] = prog - temp;
 673
 674        for (i = 1; i <= insn_cnt; i++, insn++) {
 675                const s32 imm32 = insn->imm;
 676                u32 dst_reg = insn->dst_reg;
 677                u32 src_reg = insn->src_reg;
 678                u8 b2 = 0, b3 = 0;
 679                s64 jmp_offset;
 680                u8 jmp_cond;
 681                int ilen;
 682                u8 *func;
 683
 684                switch (insn->code) {
 685                        /* ALU */
 686                case BPF_ALU | BPF_ADD | BPF_X:
 687                case BPF_ALU | BPF_SUB | BPF_X:
 688                case BPF_ALU | BPF_AND | BPF_X:
 689                case BPF_ALU | BPF_OR | BPF_X:
 690                case BPF_ALU | BPF_XOR | BPF_X:
 691                case BPF_ALU64 | BPF_ADD | BPF_X:
 692                case BPF_ALU64 | BPF_SUB | BPF_X:
 693                case BPF_ALU64 | BPF_AND | BPF_X:
 694                case BPF_ALU64 | BPF_OR | BPF_X:
 695                case BPF_ALU64 | BPF_XOR | BPF_X:
 696                        switch (BPF_OP(insn->code)) {
 697                        case BPF_ADD: b2 = 0x01; break;
 698                        case BPF_SUB: b2 = 0x29; break;
 699                        case BPF_AND: b2 = 0x21; break;
 700                        case BPF_OR: b2 = 0x09; break;
 701                        case BPF_XOR: b2 = 0x31; break;
 702                        }
 703                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 704                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
 705                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
 706                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
 707                        EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
 708                        break;
 709
 710                case BPF_ALU64 | BPF_MOV | BPF_X:
 711                case BPF_ALU | BPF_MOV | BPF_X:
 712                        emit_mov_reg(&prog,
 713                                     BPF_CLASS(insn->code) == BPF_ALU64,
 714                                     dst_reg, src_reg);
 715                        break;
 716
 717                        /* neg dst */
 718                case BPF_ALU | BPF_NEG:
 719                case BPF_ALU64 | BPF_NEG:
 720                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 721                                EMIT1(add_1mod(0x48, dst_reg));
 722                        else if (is_ereg(dst_reg))
 723                                EMIT1(add_1mod(0x40, dst_reg));
 724                        EMIT2(0xF7, add_1reg(0xD8, dst_reg));
 725                        break;
 726
 727                case BPF_ALU | BPF_ADD | BPF_K:
 728                case BPF_ALU | BPF_SUB | BPF_K:
 729                case BPF_ALU | BPF_AND | BPF_K:
 730                case BPF_ALU | BPF_OR | BPF_K:
 731                case BPF_ALU | BPF_XOR | BPF_K:
 732                case BPF_ALU64 | BPF_ADD | BPF_K:
 733                case BPF_ALU64 | BPF_SUB | BPF_K:
 734                case BPF_ALU64 | BPF_AND | BPF_K:
 735                case BPF_ALU64 | BPF_OR | BPF_K:
 736                case BPF_ALU64 | BPF_XOR | BPF_K:
 737                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 738                                EMIT1(add_1mod(0x48, dst_reg));
 739                        else if (is_ereg(dst_reg))
 740                                EMIT1(add_1mod(0x40, dst_reg));
 741
 742                        /*
 743                         * b3 holds 'normal' opcode, b2 short form only valid
 744                         * in case dst is eax/rax.
 745                         */
 746                        switch (BPF_OP(insn->code)) {
 747                        case BPF_ADD:
 748                                b3 = 0xC0;
 749                                b2 = 0x05;
 750                                break;
 751                        case BPF_SUB:
 752                                b3 = 0xE8;
 753                                b2 = 0x2D;
 754                                break;
 755                        case BPF_AND:
 756                                b3 = 0xE0;
 757                                b2 = 0x25;
 758                                break;
 759                        case BPF_OR:
 760                                b3 = 0xC8;
 761                                b2 = 0x0D;
 762                                break;
 763                        case BPF_XOR:
 764                                b3 = 0xF0;
 765                                b2 = 0x35;
 766                                break;
 767                        }
 768
 769                        if (is_imm8(imm32))
 770                                EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
 771                        else if (is_axreg(dst_reg))
 772                                EMIT1_off32(b2, imm32);
 773                        else
 774                                EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
 775                        break;
 776
 777                case BPF_ALU64 | BPF_MOV | BPF_K:
 778                case BPF_ALU | BPF_MOV | BPF_K:
 779                        emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
 780                                       dst_reg, imm32);
 781                        break;
 782
 783                case BPF_LD | BPF_IMM | BPF_DW:
 784                        emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
 785                        insn++;
 786                        i++;
 787                        break;
 788
 789                        /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
 790                case BPF_ALU | BPF_MOD | BPF_X:
 791                case BPF_ALU | BPF_DIV | BPF_X:
 792                case BPF_ALU | BPF_MOD | BPF_K:
 793                case BPF_ALU | BPF_DIV | BPF_K:
 794                case BPF_ALU64 | BPF_MOD | BPF_X:
 795                case BPF_ALU64 | BPF_DIV | BPF_X:
 796                case BPF_ALU64 | BPF_MOD | BPF_K:
 797                case BPF_ALU64 | BPF_DIV | BPF_K:
 798                        EMIT1(0x50); /* push rax */
 799                        EMIT1(0x52); /* push rdx */
 800
 801                        if (BPF_SRC(insn->code) == BPF_X)
 802                                /* mov r11, src_reg */
 803                                EMIT_mov(AUX_REG, src_reg);
 804                        else
 805                                /* mov r11, imm32 */
 806                                EMIT3_off32(0x49, 0xC7, 0xC3, imm32);
 807
 808                        /* mov rax, dst_reg */
 809                        EMIT_mov(BPF_REG_0, dst_reg);
 810
 811                        /*
 812                         * xor edx, edx
 813                         * equivalent to 'xor rdx, rdx', but one byte less
 814                         */
 815                        EMIT2(0x31, 0xd2);
 816
 817                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 818                                /* div r11 */
 819                                EMIT3(0x49, 0xF7, 0xF3);
 820                        else
 821                                /* div r11d */
 822                                EMIT3(0x41, 0xF7, 0xF3);
 823
 824                        if (BPF_OP(insn->code) == BPF_MOD)
 825                                /* mov r11, rdx */
 826                                EMIT3(0x49, 0x89, 0xD3);
 827                        else
 828                                /* mov r11, rax */
 829                                EMIT3(0x49, 0x89, 0xC3);
 830
 831                        EMIT1(0x5A); /* pop rdx */
 832                        EMIT1(0x58); /* pop rax */
 833
 834                        /* mov dst_reg, r11 */
 835                        EMIT_mov(dst_reg, AUX_REG);
 836                        break;
 837
 838                case BPF_ALU | BPF_MUL | BPF_K:
 839                case BPF_ALU | BPF_MUL | BPF_X:
 840                case BPF_ALU64 | BPF_MUL | BPF_K:
 841                case BPF_ALU64 | BPF_MUL | BPF_X:
 842                {
 843                        bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;
 844
 845                        if (dst_reg != BPF_REG_0)
 846                                EMIT1(0x50); /* push rax */
 847                        if (dst_reg != BPF_REG_3)
 848                                EMIT1(0x52); /* push rdx */
 849
 850                        /* mov r11, dst_reg */
 851                        EMIT_mov(AUX_REG, dst_reg);
 852
 853                        if (BPF_SRC(insn->code) == BPF_X)
 854                                emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
 855                        else
 856                                emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);
 857
 858                        if (is64)
 859                                EMIT1(add_1mod(0x48, AUX_REG));
 860                        else if (is_ereg(AUX_REG))
 861                                EMIT1(add_1mod(0x40, AUX_REG));
 862                        /* mul(q) r11 */
 863                        EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
 864
 865                        if (dst_reg != BPF_REG_3)
 866                                EMIT1(0x5A); /* pop rdx */
 867                        if (dst_reg != BPF_REG_0) {
 868                                /* mov dst_reg, rax */
 869                                EMIT_mov(dst_reg, BPF_REG_0);
 870                                EMIT1(0x58); /* pop rax */
 871                        }
 872                        break;
 873                }
 874                        /* Shifts */
 875                case BPF_ALU | BPF_LSH | BPF_K:
 876                case BPF_ALU | BPF_RSH | BPF_K:
 877                case BPF_ALU | BPF_ARSH | BPF_K:
 878                case BPF_ALU64 | BPF_LSH | BPF_K:
 879                case BPF_ALU64 | BPF_RSH | BPF_K:
 880                case BPF_ALU64 | BPF_ARSH | BPF_K:
 881                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 882                                EMIT1(add_1mod(0x48, dst_reg));
 883                        else if (is_ereg(dst_reg))
 884                                EMIT1(add_1mod(0x40, dst_reg));
 885
 886                        switch (BPF_OP(insn->code)) {
 887                        case BPF_LSH: b3 = 0xE0; break;
 888                        case BPF_RSH: b3 = 0xE8; break;
 889                        case BPF_ARSH: b3 = 0xF8; break;
 890                        }
 891
 892                        if (imm32 == 1)
 893                                EMIT2(0xD1, add_1reg(b3, dst_reg));
 894                        else
 895                                EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
 896                        break;
 897
 898                case BPF_ALU | BPF_LSH | BPF_X:
 899                case BPF_ALU | BPF_RSH | BPF_X:
 900                case BPF_ALU | BPF_ARSH | BPF_X:
 901                case BPF_ALU64 | BPF_LSH | BPF_X:
 902                case BPF_ALU64 | BPF_RSH | BPF_X:
 903                case BPF_ALU64 | BPF_ARSH | BPF_X:
 904
 905                        /* Check for bad case when dst_reg == rcx */
 906                        if (dst_reg == BPF_REG_4) {
 907                                /* mov r11, dst_reg */
 908                                EMIT_mov(AUX_REG, dst_reg);
 909                                dst_reg = AUX_REG;
 910                        }
 911
 912                        if (src_reg != BPF_REG_4) { /* common case */
 913                                EMIT1(0x51); /* push rcx */
 914
 915                                /* mov rcx, src_reg */
 916                                EMIT_mov(BPF_REG_4, src_reg);
 917                        }
 918
 919                        /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
 920                        if (BPF_CLASS(insn->code) == BPF_ALU64)
 921                                EMIT1(add_1mod(0x48, dst_reg));
 922                        else if (is_ereg(dst_reg))
 923                                EMIT1(add_1mod(0x40, dst_reg));
 924
 925                        switch (BPF_OP(insn->code)) {
 926                        case BPF_LSH: b3 = 0xE0; break;
 927                        case BPF_RSH: b3 = 0xE8; break;
 928                        case BPF_ARSH: b3 = 0xF8; break;
 929                        }
 930                        EMIT2(0xD3, add_1reg(b3, dst_reg));
 931
 932                        if (src_reg != BPF_REG_4)
 933                                EMIT1(0x59); /* pop rcx */
 934
 935                        if (insn->dst_reg == BPF_REG_4)
 936                                /* mov dst_reg, r11 */
 937                                EMIT_mov(insn->dst_reg, AUX_REG);
 938                        break;
 939
 940                case BPF_ALU | BPF_END | BPF_FROM_BE:
 941                        switch (imm32) {
 942                        case 16:
 943                                /* Emit 'ror %ax, 8' to swap lower 2 bytes */
 944                                EMIT1(0x66);
 945                                if (is_ereg(dst_reg))
 946                                        EMIT1(0x41);
 947                                EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);
 948
 949                                /* Emit 'movzwl eax, ax' */
 950                                if (is_ereg(dst_reg))
 951                                        EMIT3(0x45, 0x0F, 0xB7);
 952                                else
 953                                        EMIT2(0x0F, 0xB7);
 954                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
 955                                break;
 956                        case 32:
 957                                /* Emit 'bswap eax' to swap lower 4 bytes */
 958                                if (is_ereg(dst_reg))
 959                                        EMIT2(0x41, 0x0F);
 960                                else
 961                                        EMIT1(0x0F);
 962                                EMIT1(add_1reg(0xC8, dst_reg));
 963                                break;
 964                        case 64:
 965                                /* Emit 'bswap rax' to swap 8 bytes */
 966                                EMIT3(add_1mod(0x48, dst_reg), 0x0F,
 967                                      add_1reg(0xC8, dst_reg));
 968                                break;
 969                        }
 970                        break;
 971
 972                case BPF_ALU | BPF_END | BPF_FROM_LE:
 973                        switch (imm32) {
 974                        case 16:
 975                                /*
 976                                 * Emit 'movzwl eax, ax' to zero extend 16-bit
 977                                 * into 64 bit
 978                                 */
 979                                if (is_ereg(dst_reg))
 980                                        EMIT3(0x45, 0x0F, 0xB7);
 981                                else
 982                                        EMIT2(0x0F, 0xB7);
 983                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
 984                                break;
 985                        case 32:
 986                                /* Emit 'mov eax, eax' to clear upper 32-bits */
 987                                if (is_ereg(dst_reg))
 988                                        EMIT1(0x45);
 989                                EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
 990                                break;
 991                        case 64:
 992                                /* nop */
 993                                break;
 994                        }
 995                        break;
 996
 997                        /* ST: *(u8*)(dst_reg + off) = imm */
 998                case BPF_ST | BPF_MEM | BPF_B:
 999                        if (is_ereg(dst_reg))
1000                                EMIT2(0x41, 0xC6);
1001                        else
1002                                EMIT1(0xC6);
1003                        goto st;
1004                case BPF_ST | BPF_MEM | BPF_H:
1005                        if (is_ereg(dst_reg))
1006                                EMIT3(0x66, 0x41, 0xC7);
1007                        else
1008                                EMIT2(0x66, 0xC7);
1009                        goto st;
1010                case BPF_ST | BPF_MEM | BPF_W:
1011                        if (is_ereg(dst_reg))
1012                                EMIT2(0x41, 0xC7);
1013                        else
1014                                EMIT1(0xC7);
1015                        goto st;
1016                case BPF_ST | BPF_MEM | BPF_DW:
1017                        EMIT2(add_1mod(0x48, dst_reg), 0xC7);
1018
1019st:                     if (is_imm8(insn->off))
1020                                EMIT2(add_1reg(0x40, dst_reg), insn->off);
1021                        else
1022                                EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);
1023
1024                        EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
1025                        break;
1026
1027                        /* STX: *(u8*)(dst_reg + off) = src_reg */
1028                case BPF_STX | BPF_MEM | BPF_B:
1029                case BPF_STX | BPF_MEM | BPF_H:
1030                case BPF_STX | BPF_MEM | BPF_W:
1031                case BPF_STX | BPF_MEM | BPF_DW:
1032                        emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1033                        break;
1034
1035                        /* LDX: dst_reg = *(u8*)(src_reg + off) */
1036                case BPF_LDX | BPF_MEM | BPF_B:
1037                case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1038                case BPF_LDX | BPF_MEM | BPF_H:
1039                case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1040                case BPF_LDX | BPF_MEM | BPF_W:
1041                case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1042                case BPF_LDX | BPF_MEM | BPF_DW:
1043                case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1044                        emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
1045                        if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
1046                                struct exception_table_entry *ex;
1047                                u8 *_insn = image + proglen;
1048                                s64 delta;
1049
1050                                if (!bpf_prog->aux->extable)
1051                                        break;
1052
1053                                if (excnt >= bpf_prog->aux->num_exentries) {
1054                                        pr_err("ex gen bug\n");
1055                                        return -EFAULT;
1056                                }
1057                                ex = &bpf_prog->aux->extable[excnt++];
1058
1059                                delta = _insn - (u8 *)&ex->insn;
1060                                if (!is_simm32(delta)) {
1061                                        pr_err("extable->insn doesn't fit into 32-bit\n");
1062                                        return -EFAULT;
1063                                }
1064                                ex->insn = delta;
1065
1066                                delta = (u8 *)ex_handler_bpf - (u8 *)&ex->handler;
1067                                if (!is_simm32(delta)) {
1068                                        pr_err("extable->handler doesn't fit into 32-bit\n");
1069                                        return -EFAULT;
1070                                }
1071                                ex->handler = delta;
1072
1073                                if (dst_reg > BPF_REG_9) {
1074                                        pr_err("verifier error\n");
1075                                        return -EFAULT;
1076                                }
1077                                /*
1078                                 * Compute size of x86 insn and its target dest x86 register.
1079                                 * ex_handler_bpf() will use lower 8 bits to adjust
1080                                 * pt_regs->ip to jump over this x86 instruction
1081                                 * and upper bits to figure out which pt_regs to zero out.
1082                                 * End result: x86 insn "mov rbx, qword ptr [rax+0x14]"
1083                                 * of 4 bytes will be ignored and rbx will be zero inited.
1084                                 */
1085                                ex->fixup = (prog - temp) | (reg2pt_regs[dst_reg] << 8);
1086                        }
1087                        break;
1088
1089                        /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
1090                case BPF_STX | BPF_XADD | BPF_W:
1091                        /* Emit 'lock add dword ptr [rax + off], eax' */
1092                        if (is_ereg(dst_reg) || is_ereg(src_reg))
1093                                EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
1094                        else
1095                                EMIT2(0xF0, 0x01);
1096                        goto xadd;
1097                case BPF_STX | BPF_XADD | BPF_DW:
1098                        EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
1099xadd:                   if (is_imm8(insn->off))
1100                                EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
1101                        else
1102                                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
1103                                            insn->off);
1104                        break;
1105
1106                        /* call */
1107                case BPF_JMP | BPF_CALL:
1108                        func = (u8 *) __bpf_call_base + imm32;
1109                        if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
1110                                return -EINVAL;
1111                        break;
1112
1113                case BPF_JMP | BPF_TAIL_CALL:
1114                        if (imm32)
1115                                emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
1116                                                          &prog, addrs[i], image);
1117                        else
1118                                emit_bpf_tail_call_indirect(&prog);
1119                        break;
1120
1121                        /* cond jump */
1122                case BPF_JMP | BPF_JEQ | BPF_X:
1123                case BPF_JMP | BPF_JNE | BPF_X:
1124                case BPF_JMP | BPF_JGT | BPF_X:
1125                case BPF_JMP | BPF_JLT | BPF_X:
1126                case BPF_JMP | BPF_JGE | BPF_X:
1127                case BPF_JMP | BPF_JLE | BPF_X:
1128                case BPF_JMP | BPF_JSGT | BPF_X:
1129                case BPF_JMP | BPF_JSLT | BPF_X:
1130                case BPF_JMP | BPF_JSGE | BPF_X:
1131                case BPF_JMP | BPF_JSLE | BPF_X:
1132                case BPF_JMP32 | BPF_JEQ | BPF_X:
1133                case BPF_JMP32 | BPF_JNE | BPF_X:
1134                case BPF_JMP32 | BPF_JGT | BPF_X:
1135                case BPF_JMP32 | BPF_JLT | BPF_X:
1136                case BPF_JMP32 | BPF_JGE | BPF_X:
1137                case BPF_JMP32 | BPF_JLE | BPF_X:
1138                case BPF_JMP32 | BPF_JSGT | BPF_X:
1139                case BPF_JMP32 | BPF_JSLT | BPF_X:
1140                case BPF_JMP32 | BPF_JSGE | BPF_X:
1141                case BPF_JMP32 | BPF_JSLE | BPF_X:
1142                        /* cmp dst_reg, src_reg */
1143                        if (BPF_CLASS(insn->code) == BPF_JMP)
1144                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
1145                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
1146                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
1147                        EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
1148                        goto emit_cond_jmp;
1149
1150                case BPF_JMP | BPF_JSET | BPF_X:
1151                case BPF_JMP32 | BPF_JSET | BPF_X:
1152                        /* test dst_reg, src_reg */
1153                        if (BPF_CLASS(insn->code) == BPF_JMP)
1154                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
1155                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
1156                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
1157                        EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
1158                        goto emit_cond_jmp;
1159
1160                case BPF_JMP | BPF_JSET | BPF_K:
1161                case BPF_JMP32 | BPF_JSET | BPF_K:
1162                        /* test dst_reg, imm32 */
1163                        if (BPF_CLASS(insn->code) == BPF_JMP)
1164                                EMIT1(add_1mod(0x48, dst_reg));
1165                        else if (is_ereg(dst_reg))
1166                                EMIT1(add_1mod(0x40, dst_reg));
1167                        EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
1168                        goto emit_cond_jmp;
1169
1170                case BPF_JMP | BPF_JEQ | BPF_K:
1171                case BPF_JMP | BPF_JNE | BPF_K:
1172                case BPF_JMP | BPF_JGT | BPF_K:
1173                case BPF_JMP | BPF_JLT | BPF_K:
1174                case BPF_JMP | BPF_JGE | BPF_K:
1175                case BPF_JMP | BPF_JLE | BPF_K:
1176                case BPF_JMP | BPF_JSGT | BPF_K:
1177                case BPF_JMP | BPF_JSLT | BPF_K:
1178                case BPF_JMP | BPF_JSGE | BPF_K:
1179                case BPF_JMP | BPF_JSLE | BPF_K:
1180                case BPF_JMP32 | BPF_JEQ | BPF_K:
1181                case BPF_JMP32 | BPF_JNE | BPF_K:
1182                case BPF_JMP32 | BPF_JGT | BPF_K:
1183                case BPF_JMP32 | BPF_JLT | BPF_K:
1184                case BPF_JMP32 | BPF_JGE | BPF_K:
1185                case BPF_JMP32 | BPF_JLE | BPF_K:
1186                case BPF_JMP32 | BPF_JSGT | BPF_K:
1187                case BPF_JMP32 | BPF_JSLT | BPF_K:
1188                case BPF_JMP32 | BPF_JSGE | BPF_K:
1189                case BPF_JMP32 | BPF_JSLE | BPF_K:
1190                        /* test dst_reg, dst_reg to save one extra byte */
1191                        if (imm32 == 0) {
1192                                if (BPF_CLASS(insn->code) == BPF_JMP)
1193                                        EMIT1(add_2mod(0x48, dst_reg, dst_reg));
1194                                else if (is_ereg(dst_reg))
1195                                        EMIT1(add_2mod(0x40, dst_reg, dst_reg));
1196                                EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
1197                                goto emit_cond_jmp;
1198                        }
1199
1200                        /* cmp dst_reg, imm8/32 */
1201                        if (BPF_CLASS(insn->code) == BPF_JMP)
1202                                EMIT1(add_1mod(0x48, dst_reg));
1203                        else if (is_ereg(dst_reg))
1204                                EMIT1(add_1mod(0x40, dst_reg));
1205
1206                        if (is_imm8(imm32))
1207                                EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
1208                        else
1209                                EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);
1210
1211emit_cond_jmp:          /* Convert BPF opcode to x86 */
1212                        switch (BPF_OP(insn->code)) {
1213                        case BPF_JEQ:
1214                                jmp_cond = X86_JE;
1215                                break;
1216                        case BPF_JSET:
1217                        case BPF_JNE:
1218                                jmp_cond = X86_JNE;
1219                                break;
1220                        case BPF_JGT:
1221                                /* GT is unsigned '>', JA in x86 */
1222                                jmp_cond = X86_JA;
1223                                break;
1224                        case BPF_JLT:
1225                                /* LT is unsigned '<', JB in x86 */
1226                                jmp_cond = X86_JB;
1227                                break;
1228                        case BPF_JGE:
1229                                /* GE is unsigned '>=', JAE in x86 */
1230                                jmp_cond = X86_JAE;
1231                                break;
1232                        case BPF_JLE:
1233                                /* LE is unsigned '<=', JBE in x86 */
1234                                jmp_cond = X86_JBE;
1235                                break;
1236                        case BPF_JSGT:
1237                                /* Signed '>', GT in x86 */
1238                                jmp_cond = X86_JG;
1239                                break;
1240                        case BPF_JSLT:
1241                                /* Signed '<', LT in x86 */
1242                                jmp_cond = X86_JL;
1243                                break;
1244                        case BPF_JSGE:
1245                                /* Signed '>=', GE in x86 */
1246                                jmp_cond = X86_JGE;
1247                                break;
1248                        case BPF_JSLE:
1249                                /* Signed '<=', LE in x86 */
1250                                jmp_cond = X86_JLE;
1251                                break;
1252                        default: /* to silence GCC warning */
1253                                return -EFAULT;
1254                        }
1255                        jmp_offset = addrs[i + insn->off] - addrs[i];
1256                        if (is_imm8(jmp_offset)) {
1257                                EMIT2(jmp_cond, jmp_offset);
1258                        } else if (is_simm32(jmp_offset)) {
1259                                EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
1260                        } else {
1261                                pr_err("cond_jmp gen bug %llx\n", jmp_offset);
1262                                return -EFAULT;
1263                        }
1264
1265                        break;
1266
1267                case BPF_JMP | BPF_JA:
1268                        if (insn->off == -1)
1269                                /* -1 jmp instructions will always jump
1270                                 * backwards two bytes. Explicitly handling
1271                                 * this case avoids wasting too many passes
1272                                 * when there are long sequences of replaced
1273                                 * dead code.
1274                                 */
1275                                jmp_offset = -2;
1276                        else
1277                                jmp_offset = addrs[i + insn->off] - addrs[i];
1278
1279                        if (!jmp_offset)
1280                                /* Optimize out nop jumps */
1281                                break;
1282emit_jmp:
1283                        if (is_imm8(jmp_offset)) {
1284                                EMIT2(0xEB, jmp_offset);
1285                        } else if (is_simm32(jmp_offset)) {
1286                                EMIT1_off32(0xE9, jmp_offset);
1287                        } else {
1288                                pr_err("jmp gen bug %llx\n", jmp_offset);
1289                                return -EFAULT;
1290                        }
1291                        break;
1292
1293                case BPF_JMP | BPF_EXIT:
1294                        if (seen_exit) {
1295                                jmp_offset = ctx->cleanup_addr - addrs[i];
1296                                goto emit_jmp;
1297                        }
1298                        seen_exit = true;
1299                        /* Update cleanup_addr */
1300                        ctx->cleanup_addr = proglen;
1301                        if (!bpf_prog_was_classic(bpf_prog))
1302                                EMIT1(0x5B); /* get rid of tail_call_cnt */
1303                        EMIT2(0x41, 0x5F);   /* pop r15 */
1304                        EMIT2(0x41, 0x5E);   /* pop r14 */
1305                        EMIT2(0x41, 0x5D);   /* pop r13 */
1306                        EMIT1(0x5B);         /* pop rbx */
1307                        EMIT1(0xC9);         /* leave */
1308                        EMIT1(0xC3);         /* ret */
1309                        break;
1310
1311                default:
1312                        /*
1313                         * By design, the x86-64 JIT should support all BPF instructions.
1314                         * This error will only be seen if a new instruction was added
1315                         * to the interpreter but not to the JIT, or if there is
1316                         * junk in bpf_prog.
1317                         */
1318                        pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
1319                        return -EINVAL;
1320                }
1321
1322                ilen = prog - temp;
1323                if (ilen > BPF_MAX_INSN_SIZE) {
1324                        pr_err("bpf_jit: fatal insn size error\n");
1325                        return -EFAULT;
1326                }
1327
1328                if (image) {
1329                        if (unlikely(proglen + ilen > oldproglen)) {
1330                                pr_err("bpf_jit: fatal error\n");
1331                                return -EFAULT;
1332                        }
1333                        memcpy(image + proglen, temp, ilen);
1334                }
1335                proglen += ilen;
1336                addrs[i] = proglen;
1337                prog = temp;
1338        }
1339
1340        if (image && excnt != bpf_prog->aux->num_exentries) {
1341                pr_err("extable is not populated\n");
1342                return -EFAULT;
1343        }
1344        return proglen;
1345}
1346
1347static void save_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
1348                      int stack_size)
1349{
1350        int i;
1351        /* Store function arguments to stack.
1352         * For a function that accepts two pointers the sequence will be:
1353         * mov QWORD PTR [rbp-0x10],rdi
1354         * mov QWORD PTR [rbp-0x8],rsi
1355         */
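        /* BPF R1-R5 live in rdi, rsi, rdx, rcx and r8; the sixth x86-64
         * argument register, r9, has no BPF register equivalent, hence the
         * X86_REG_R9 special case for i == 5.
         */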
1356        for (i = 0; i < min(nr_args, 6); i++)
1357                emit_stx(prog, bytes_to_bpf_size(m->arg_size[i]),
1358                         BPF_REG_FP,
1359                         i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
1360                         -(stack_size - i * 8));
1361}
1362
1363static void restore_regs(const struct btf_func_model *m, u8 **prog, int nr_args,
1364                         int stack_size)
1365{
1366        int i;
1367
1368        /* Restore function arguments from stack.
1369         * For a function that accepts two pointers the sequence will be:
1370         * EMIT4(0x48, 0x8B, 0x7D, 0xF0); mov rdi,QWORD PTR [rbp-0x10]
1371         * EMIT4(0x48, 0x8B, 0x75, 0xF8); mov rsi,QWORD PTR [rbp-0x8]
1372         */
1373        for (i = 0; i < min(nr_args, 6); i++)
1374                emit_ldx(prog, bytes_to_bpf_size(m->arg_size[i]),
1375                         i == 5 ? X86_REG_R9 : BPF_REG_1 + i,
1376                         BPF_REG_FP,
1377                         -(stack_size - i * 8));
1378}
1379
1380static int invoke_bpf_prog(const struct btf_func_model *m, u8 **pprog,
1381                           struct bpf_prog *p, int stack_size, bool mod_ret)
1382{
1383        u8 *prog = *pprog;
1384        int cnt = 0;
1385
1386        if (emit_call(&prog, __bpf_prog_enter, prog))
1387                return -EINVAL;
1388        /* remember prog start time returned by __bpf_prog_enter */
1389        emit_mov_reg(&prog, true, BPF_REG_6, BPF_REG_0);
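        /* R6 maps to rbx, which is callee-saved, so the start time survives
         * the call into the program below.
         */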
1390
1391        /* arg1: lea rdi, [rbp - stack_size] */
1392        EMIT4(0x48, 0x8D, 0x7D, -stack_size);
1393        /* arg2: progs[i]->insnsi for interpreter */
1394        if (!p->jited)
1395                emit_mov_imm64(&prog, BPF_REG_2,
1396                               (long) p->insnsi >> 32,
1397                               (u32) (long) p->insnsi);
1398        /* call JITed bpf program or interpreter */
1399        if (emit_call(&prog, p->bpf_func, prog))
1400                return -EINVAL;
1401
1402                        /* BPF_TRAMP_MODIFY_RETURN trampolines can modify the return
1403                         * value of the previous call, which is then passed on the stack
1404                         * to the next BPF program.
1405         */
1406        if (mod_ret)
1407                emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
1408
1409        /* arg1: mov rdi, progs[i] */
1410        emit_mov_imm64(&prog, BPF_REG_1, (long) p >> 32,
1411                       (u32) (long) p);
1412        /* arg2: mov rsi, rbx <- start time in nsec */
1413        emit_mov_reg(&prog, true, BPF_REG_2, BPF_REG_6);
1414        if (emit_call(&prog, __bpf_prog_exit, prog))
1415                return -EINVAL;
1416
1417        *pprog = prog;
1418        return 0;
1419}
1420
1421static void emit_nops(u8 **pprog, unsigned int len)
1422{
1423        unsigned int i, noplen;
1424        u8 *prog = *pprog;
1425        int cnt = 0;
1426
1427        while (len > 0) {
1428                noplen = len;
1429
1430                if (noplen > ASM_NOP_MAX)
1431                        noplen = ASM_NOP_MAX;
1432
1433                for (i = 0; i < noplen; i++)
1434                        EMIT1(ideal_nops[noplen][i]);
1435                len -= noplen;
1436        }
1437
1438        *pprog = prog;
1439}
1440
1441static void emit_align(u8 **pprog, u32 align)
1442{
1443        u8 *target, *prog = *pprog;
1444
1445        target = PTR_ALIGN(prog, align);
1446        if (target != prog)
1447                emit_nops(&prog, target - prog);
1448
1449        *pprog = prog;
1450}
1451
1452static int emit_cond_near_jump(u8 **pprog, void *func, void *ip, u8 jmp_cond)
1453{
1454        u8 *prog = *pprog;
1455        int cnt = 0;
1456        s64 offset;
1457
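        /* 0x0F <cc> rel32 is 2 opcode bytes plus a 4-byte displacement that
         * is relative to the end of the instruction, hence ip + 2 + 4.
         */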
1458        offset = func - (ip + 2 + 4);
1459        if (!is_simm32(offset)) {
1460                pr_err("Target %p is out of range\n", func);
1461                return -EINVAL;
1462        }
1463        EMIT2_off32(0x0F, jmp_cond + 0x10, offset);
1464        *pprog = prog;
1465        return 0;
1466}
1467
1468static int invoke_bpf(const struct btf_func_model *m, u8 **pprog,
1469                      struct bpf_tramp_progs *tp, int stack_size)
1470{
1471        int i;
1472        u8 *prog = *pprog;
1473
1474        for (i = 0; i < tp->nr_progs; i++) {
1475                if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size, false))
1476                        return -EINVAL;
1477        }
1478        *pprog = prog;
1479        return 0;
1480}
1481
1482static int invoke_bpf_mod_ret(const struct btf_func_model *m, u8 **pprog,
1483                              struct bpf_tramp_progs *tp, int stack_size,
1484                              u8 **branches)
1485{
1486        u8 *prog = *pprog;
1487        int i, cnt = 0;
1488
1489        /* The first fmod_ret program will receive a garbage return value.
1490         * Set this to 0 to avoid confusing the program.
1491         */
1492        emit_mov_imm32(&prog, false, BPF_REG_0, 0);
1493        emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
1494        for (i = 0; i < tp->nr_progs; i++) {
1495                if (invoke_bpf_prog(m, &prog, tp->progs[i], stack_size, true))
1496                        return -EINVAL;
1497
1498                /* mod_ret prog stored return value into [rbp - 8]. Emit:
1499                 * if (*(u64 *)(rbp - 8) !=  0)
1500                 *      goto do_fexit;
1501                 */
1502                /* cmp QWORD PTR [rbp - 0x8], 0x0 */
1503                EMIT4(0x48, 0x83, 0x7d, 0xf8); EMIT1(0x00);
1504
1505                /* Save the location of the branch and generate 6 nops
1506                 * (4 bytes for an offset and 2 bytes for the jump). These nops
1507                 * are replaced with a conditional jump once do_fexit (i.e. the
1508                 * start of the fexit invocation) is finalized.
1509                 */
1510                branches[i] = prog;
1511                emit_nops(&prog, 4 + 2);
1512        }
1513
1514        *pprog = prog;
1515        return 0;
1516}
1517
1518/* Example:
1519 * __be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev);
1520 * its 'struct btf_func_model' will have nr_args=2
1521 * The assembly code when eth_type_trans is executing after trampoline:
1522 *
1523 * push rbp
1524 * mov rbp, rsp
1525 * sub rsp, 16                     // space for skb and dev
1526 * push rbx                        // temp regs to pass start time
1527 * mov qword ptr [rbp - 16], rdi   // save skb pointer to stack
1528 * mov qword ptr [rbp - 8], rsi    // save dev pointer to stack
1529 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
1530 * mov rbx, rax                    // remember start time if bpf stats are enabled
1531 * lea rdi, [rbp - 16]             // R1==ctx of bpf prog
1532 * call addr_of_jited_FENTRY_prog
1533 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
1534 * mov rsi, rbx                    // prog start time
1535 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
1536 * mov rdi, qword ptr [rbp - 16]   // restore skb pointer from stack
1537 * mov rsi, qword ptr [rbp - 8]    // restore dev pointer from stack
1538 * pop rbx
1539 * leave
1540 * ret
1541 *
1542 * eth_type_trans has a 5-byte nop at the beginning. These 5 bytes will be
1543 * replaced with 'call generated_bpf_trampoline'. When it returns,
1544 * eth_type_trans will continue executing with the original skb and dev pointers.
1545 *
1546 * The assembly code when eth_type_trans is called from trampoline:
1547 *
1548 * push rbp
1549 * mov rbp, rsp
1550 * sub rsp, 24                     // space for skb, dev, return value
1551 * push rbx                        // temp regs to pass start time
1552 * mov qword ptr [rbp - 24], rdi   // save skb pointer to stack
1553 * mov qword ptr [rbp - 16], rsi   // save dev pointer to stack
1554 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
1555 * mov rbx, rax                    // remember start time if bpf stats are enabled
1556 * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
1557 * call addr_of_jited_FENTRY_prog  // bpf prog can access skb and dev
1558 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
1559 * mov rsi, rbx                    // prog start time
1560 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
1561 * mov rdi, qword ptr [rbp - 24]   // restore skb pointer from stack
1562 * mov rsi, qword ptr [rbp - 16]   // restore dev pointer from stack
1563 * call eth_type_trans+5           // execute body of eth_type_trans
1564 * mov qword ptr [rbp - 8], rax    // save return value
1565 * call __bpf_prog_enter           // rcu_read_lock and preempt_disable
1566 * mov rbx, rax                    // remember start time if bpf stats are enabled
1567 * lea rdi, [rbp - 24]             // R1==ctx of bpf prog
1568 * call addr_of_jited_FEXIT_prog   // bpf prog can access skb, dev, return value
1569 * movabsq rdi, 64bit_addr_of_struct_bpf_prog  // unused if bpf stats are off
1570 * mov rsi, rbx                    // prog start time
1571 * call __bpf_prog_exit            // rcu_read_unlock, preempt_enable and stats math
1572 * mov rax, qword ptr [rbp - 8]    // restore eth_type_trans's return value
1573 * pop rbx
1574 * leave
1575 * add rsp, 8                      // skip eth_type_trans's frame
1576 * ret                             // return to its caller
1577 */
1578int arch_prepare_bpf_trampoline(void *image, void *image_end,
1579                                const struct btf_func_model *m, u32 flags,
1580                                struct bpf_tramp_progs *tprogs,
1581                                void *orig_call)
1582{
1583        int ret, i, cnt = 0, nr_args = m->nr_args;
1584        int stack_size = nr_args * 8;
1585        struct bpf_tramp_progs *fentry = &tprogs[BPF_TRAMP_FENTRY];
1586        struct bpf_tramp_progs *fexit = &tprogs[BPF_TRAMP_FEXIT];
1587        struct bpf_tramp_progs *fmod_ret = &tprogs[BPF_TRAMP_MODIFY_RETURN];
1588        u8 **branches = NULL;
1589        u8 *prog;
1590
1591        /* x86-64 supports up to 6 arguments. 7+ can be added in the future */
1592        if (nr_args > 6)
1593                return -ENOTSUPP;
1594
1595        if ((flags & BPF_TRAMP_F_RESTORE_REGS) &&
1596            (flags & BPF_TRAMP_F_SKIP_FRAME))
1597                return -EINVAL;
1598
1599        if (flags & BPF_TRAMP_F_CALL_ORIG)
1600                stack_size += 8; /* room for return value of orig_call */
1601
1602        if (flags & BPF_TRAMP_F_SKIP_FRAME)
1603                /* skip patched call instruction and point orig_call to actual
1604                 * body of the kernel function.
1605                 */
1606                orig_call += X86_PATCH_SIZE;
1607
1608        prog = image;
1609
1610        EMIT1(0x55);             /* push rbp */
1611        EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
1612        EMIT4(0x48, 0x83, 0xEC, stack_size); /* sub rsp, stack_size */
1613        EMIT1(0x53);             /* push rbx */
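        /*
         * Frame at this point (cf. the example above): [rbp - stack_size,
         * rbp - 8] is the spill area for the saved arguments and, with
         * BPF_TRAMP_F_CALL_ORIG, the original return value at rbp - 8;
         * rbx is saved just below it.  Note that the "sub rsp" above uses
         * the sign-extended imm8 form (opcode 0x83), which is sufficient
         * since stack_size is at most 56 bytes here.
         */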
1614
1615        save_regs(m, &prog, nr_args, stack_size);
1616
1617        if (fentry->nr_progs)
1618                if (invoke_bpf(m, &prog, fentry, stack_size))
1619                        return -EINVAL;
1620
1621        if (fmod_ret->nr_progs) {
1622                branches = kcalloc(fmod_ret->nr_progs, sizeof(u8 *),
1623                                   GFP_KERNEL);
1624                if (!branches)
1625                        return -ENOMEM;
1626
1627                if (invoke_bpf_mod_ret(m, &prog, fmod_ret, stack_size,
1628                                       branches)) {
1629                        ret = -EINVAL;
1630                        goto cleanup;
1631                }
1632        }
1633
1634        if (flags & BPF_TRAMP_F_CALL_ORIG) {
1635                if (fentry->nr_progs || fmod_ret->nr_progs)
1636                        restore_regs(m, &prog, nr_args, stack_size);
1637
1638                /* call original function */
1639                if (emit_call(&prog, orig_call, prog)) {
1640                        ret = -EINVAL;
1641                        goto cleanup;
1642                }
1643                /* remember the return value on the stack for the bpf prog to access */
1644                emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_0, -8);
1645        }
1646
1647        if (fmod_ret->nr_progs) {
1648                /* From Intel 64 and IA-32 Architectures Optimization
1649                 * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
1650                 * Coding Rule 11: All branch targets should be 16-byte
1651                 * aligned.
1652                 */
1653                emit_align(&prog, 16);
1654                /* Update the branches saved in invoke_bpf_mod_ret with the
1655                 * aligned address of do_fexit.
1656                 */
1657                for (i = 0; i < fmod_ret->nr_progs; i++)
1658                        emit_cond_near_jump(&branches[i], prog, branches[i],
1659                                            X86_JNE);
1660        }
1661
1662        if (fexit->nr_progs)
1663                if (invoke_bpf(m, &prog, fexit, stack_size)) {
1664                        ret = -EINVAL;
1665                        goto cleanup;
1666                }
1667
1668        if (flags & BPF_TRAMP_F_RESTORE_REGS)
1669                restore_regs(m, &prog, nr_args, stack_size);
1670
1671        /* This needs to be done regardless. If there were fmod_ret programs,
1672         * the return value is only updated on the stack and still needs to be
1673         * restored to R0.
1674         */
1675        if (flags & BPF_TRAMP_F_CALL_ORIG)
1676                /* restore original return value back into RAX */
1677                emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, -8);
1678
1679        EMIT1(0x5B); /* pop rbx */
1680        EMIT1(0xC9); /* leave */
1681        if (flags & BPF_TRAMP_F_SKIP_FRAME)
1682                /* skip our return address and return to parent */
1683                EMIT4(0x48, 0x83, 0xC4, 8); /* add rsp, 8 */
1684        EMIT1(0xC3); /* ret */
1685        /* Make sure the trampoline generation logic doesn't overflow */
1686        if (WARN_ON_ONCE(prog > (u8 *)image_end - BPF_INSN_SAFETY)) {
1687                ret = -EFAULT;
1688                goto cleanup;
1689        }
1690        ret = prog - (u8 *)image;
1691
1692cleanup:
1693        kfree(branches);
1694        return ret;
1695}
1696
1697static int emit_fallback_jump(u8 **pprog)
1698{
1699        u8 *prog = *pprog;
1700        int err = 0;
1701
1702#ifdef CONFIG_RETPOLINE
1703        /* Note that this assumes that the compiler uses external
1704         * thunks for indirect calls. Both clang and GCC use the same
1705         * naming convention for external thunks.
1706         */
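        /* The thunk performs "jmp *%rdx" behind the retpoline speculation
         * barrier, i.e. the hardened equivalent of the plain indirect jmp
         * emitted in the !CONFIG_RETPOLINE branch below.
         */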
1707        err = emit_jump(&prog, __x86_indirect_thunk_rdx, prog);
1708#else
1709        int cnt = 0;
1710
1711        EMIT2(0xFF, 0xE2);      /* jmp rdx */
1712#endif
1713        *pprog = prog;
1714        return err;
1715}
1716
1717static int emit_bpf_dispatcher(u8 **pprog, int a, int b, s64 *progs)
1718{
1719        u8 *jg_reloc, *prog = *pprog;
1720        int pivot, err, jg_bytes = 1, cnt = 0;
1721        s64 jg_offset;
1722
1723        if (a == b) {
1724                /* Leaf node of recursion, i.e. not a range of indices
1725                 * anymore.
1726                 */
1727                EMIT1(add_1mod(0x48, BPF_REG_3));       /* cmp rdx,func */
1728                if (!is_simm32(progs[a]))
1729                        return -1;
1730                EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3),
1731                            progs[a]);
1732                err = emit_cond_near_jump(&prog,        /* je func */
1733                                          (void *)progs[a], prog,
1734                                          X86_JE);
1735                if (err)
1736                        return err;
1737
1738                err = emit_fallback_jump(&prog);        /* jmp thunk/indirect */
1739                if (err)
1740                        return err;
1741
1742                *pprog = prog;
1743                return 0;
1744        }
1745
1746        /* Not a leaf node, so we pivot, and recursively descend into
1747         * the lower and upper ranges.
1748         */
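        /*
         * The emitted code is a balanced binary search over the sorted
         * progs[] addresses: each internal node compares rdx (BPF_REG_3,
         * the target address) against the pivot and takes jg into the
         * upper half, otherwise falls through into the lower half; each
         * leaf emits "je <prog>" followed by the indirect-jump fallback.
         */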
1749        pivot = (b - a) / 2;
1750        EMIT1(add_1mod(0x48, BPF_REG_3));               /* cmp rdx,func */
1751        if (!is_simm32(progs[a + pivot]))
1752                return -1;
1753        EMIT2_off32(0x81, add_1reg(0xF8, BPF_REG_3), progs[a + pivot]);
1754
1755        if (pivot > 2) {                                /* jg upper_part */
1756                /* Require near jump. */
1757                jg_bytes = 4;
1758                EMIT2_off32(0x0F, X86_JG + 0x10, 0);
1759        } else {
1760                EMIT2(X86_JG, 0);
1761        }
1762        jg_reloc = prog;
1763
1764        err = emit_bpf_dispatcher(&prog, a, a + pivot,  /* emit lower_part */
1765                                  progs);
1766        if (err)
1767                return err;
1768
1769        /* From Intel 64 and IA-32 Architectures Optimization
1770         * Reference Manual, 3.4.1.4 Code Alignment, Assembly/Compiler
1771         * Coding Rule 11: All branch targets should be 16-byte
1772         * aligned.
1773         */
1774        emit_align(&prog, 16);
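        /* Back-patch the jg displacement now that the size of lower_part
         * (plus alignment padding) is known; the rel8/rel32 field sits
         * immediately before jg_reloc, and the displacement is relative
         * to jg_reloc itself (the end of the jg instruction).
         */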
1775        jg_offset = prog - jg_reloc;
1776        emit_code(jg_reloc - jg_bytes, jg_offset, jg_bytes);
1777
1778        err = emit_bpf_dispatcher(&prog, a + pivot + 1, /* emit upper_part */
1779                                  b, progs);
1780        if (err)
1781                return err;
1782
1783        *pprog = prog;
1784        return 0;
1785}
1786
1787static int cmp_ips(const void *a, const void *b)
1788{
1789        const s64 *ipa = a;
1790        const s64 *ipb = b;
1791
1792        if (*ipa > *ipb)
1793                return 1;
1794        if (*ipa < *ipb)
1795                return -1;
1796        return 0;
1797}
1798
1799int arch_prepare_bpf_dispatcher(void *image, s64 *funcs, int num_funcs)
1800{
1801        u8 *prog = image;
1802
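        /* Sort the target addresses so emit_bpf_dispatcher() can lay them
         * out as a binary search.
         */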
1803        sort(funcs, num_funcs, sizeof(funcs[0]), cmp_ips, NULL);
1804        return emit_bpf_dispatcher(&prog, 0, num_funcs - 1, funcs);
1805}
1806
1807struct x64_jit_data {
1808        struct bpf_binary_header *header;
1809        int *addrs;
1810        u8 *image;
1811        int proglen;
1812        struct jit_context ctx;
1813};
1814
1815struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
1816{
1817        struct bpf_binary_header *header = NULL;
1818        struct bpf_prog *tmp, *orig_prog = prog;
1819        struct x64_jit_data *jit_data;
1820        int proglen, oldproglen = 0;
1821        struct jit_context ctx = {};
1822        bool tmp_blinded = false;
1823        bool extra_pass = false;
1824        u8 *image = NULL;
1825        int *addrs;
1826        int pass;
1827        int i;
1828
1829        if (!prog->jit_requested)
1830                return orig_prog;
1831
1832        tmp = bpf_jit_blind_constants(prog);
1833        /*
1834         * If blinding was requested and we failed during blinding,
1835         * we must fall back to the interpreter.
1836         */
1837        if (IS_ERR(tmp))
1838                return orig_prog;
1839        if (tmp != prog) {
1840                tmp_blinded = true;
1841                prog = tmp;
1842        }
1843
1844        jit_data = prog->aux->jit_data;
1845        if (!jit_data) {
1846                jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
1847                if (!jit_data) {
1848                        prog = orig_prog;
1849                        goto out;
1850                }
1851                prog->aux->jit_data = jit_data;
1852        }
1853        addrs = jit_data->addrs;
1854        if (addrs) {
1855                ctx = jit_data->ctx;
1856                oldproglen = jit_data->proglen;
1857                image = jit_data->image;
1858                header = jit_data->header;
1859                extra_pass = true;
1860                goto skip_init_addrs;
1861        }
1862        addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
1863        if (!addrs) {
1864                prog = orig_prog;
1865                goto out_addrs;
1866        }
1867
1868        /*
1869         * Before the first pass, make a rough estimate of addrs[]:
1870         * each BPF instruction is translated to less than 64 bytes.
1871         */
1872        for (proglen = 0, i = 0; i <= prog->len; i++) {
1873                proglen += 64;
1874                addrs[i] = proglen;
1875        }
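        /* addrs[i] now holds the estimated offset just past instruction i;
         * each do_jit() pass below overwrites it with the real offsets.
         */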
1876        ctx.cleanup_addr = proglen;
1877skip_init_addrs:
1878
1879        /*
1880         * JITed image shrinks with every pass and the loop iterates
1881         * until the image stops shrinking. Very large BPF programs
1882         * may only converge on the last pass. In such a case, do one more
1883         * pass to emit the final image.
1884         */
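        /*
         * Once proglen stops changing, image is allocated and the next
         * pass emits into it; that pass must produce exactly oldproglen
         * bytes, otherwise the JIT bails out.
         */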
1885        for (pass = 0; pass < 20 || image; pass++) {
1886                proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
1887                if (proglen <= 0) {
1888out_image:
1889                        image = NULL;
1890                        if (header)
1891                                bpf_jit_binary_free(header);
1892                        prog = orig_prog;
1893                        goto out_addrs;
1894                }
1895                if (image) {
1896                        if (proglen != oldproglen) {
1897                                pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
1898                                       proglen, oldproglen);
1899                                goto out_image;
1900                        }
1901                        break;
1902                }
1903                if (proglen == oldproglen) {
1904                        /*
1905                         * The number of entries in extable is the number of BPF_LDX
1906                         * insns that access kernel memory via "pointer to BTF type".
1907                         * The verifier changed their opcode from LDX|MEM|size
1908                         * to LDX|PROBE_MEM|size to make JITing easier.
1909                         */
1910                        u32 align = __alignof__(struct exception_table_entry);
1911                        u32 extable_size = prog->aux->num_exentries *
1912                                sizeof(struct exception_table_entry);
1913
1914                        /* allocate module memory for x86 insns and extable */
1915                        header = bpf_jit_binary_alloc(roundup(proglen, align) + extable_size,
1916                                                      &image, align, jit_fill_hole);
1917                        if (!header) {
1918                                prog = orig_prog;
1919                                goto out_addrs;
1920                        }
1921                        prog->aux->extable = (void *) image + roundup(proglen, align);
1922                }
1923                oldproglen = proglen;
1924                cond_resched();
1925        }
1926
1927        if (bpf_jit_enable > 1)
1928                bpf_jit_dump(prog->len, proglen, pass + 1, image);
1929
1930        if (image) {
1931                if (!prog->is_func || extra_pass) {
1932                        bpf_tail_call_direct_fixup(prog);
1933                        bpf_jit_binary_lock_ro(header);
1934                } else {
1935                        jit_data->addrs = addrs;
1936                        jit_data->ctx = ctx;
1937                        jit_data->proglen = proglen;
1938                        jit_data->image = image;
1939                        jit_data->header = header;
1940                }
1941                prog->bpf_func = (void *)image;
1942                prog->jited = 1;
1943                prog->jited_len = proglen;
1944        } else {
1945                prog = orig_prog;
1946        }
1947
1948        if (!image || !prog->is_func || extra_pass) {
1949                if (image)
1950                        bpf_prog_fill_jited_linfo(prog, addrs + 1);
1951out_addrs:
1952                kfree(addrs);
1953                kfree(jit_data);
1954                prog->aux->jit_data = NULL;
1955        }
1956out:
1957        if (tmp_blinded)
1958                bpf_jit_prog_release_other(prog, prog == orig_prog ?
1959                                           tmp : orig_prog);
1960        return prog;
1961}
1962