/* linux/arch/riscv/net/bpf_jit_comp32.c */
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * BPF JIT compiler for RV32G
   4 *
   5 * Copyright (c) 2020 Luke Nelson <luke.r.nels@gmail.com>
   6 * Copyright (c) 2020 Xi Wang <xi.wang@gmail.com>
   7 *
   8 * The code is based on the BPF JIT compiler for RV64G by Björn Töpel and
   9 * the BPF JIT compiler for 32-bit ARM by Shubham Bansal and Mircea Gherzan.
  10 */
  11
  12#include <linux/bpf.h>
  13#include <linux/filter.h>
  14#include "bpf_jit.h"
  15
  16/*
  17 * Stack layout during BPF program execution:
  18 *
  19 *                     high
  20 *     RV32 fp =>  +----------+
  21 *                 | saved ra |
  22 *                 | saved fp | RV32 callee-saved registers
  23 *                 |   ...    |
  24 *                 +----------+ <= (fp - 4 * NR_SAVED_REGISTERS)
  25 *                 |  hi(R6)  |
  26 *                 |  lo(R6)  |
  27 *                 |  hi(R7)  | JIT scratch space for BPF registers
  28 *                 |  lo(R7)  |
  29 *                 |   ...    |
  30 *  BPF_REG_FP =>  +----------+ <= (fp - 4 * NR_SAVED_REGISTERS
  31 *                 |          |        - 4 * BPF_JIT_SCRATCH_REGS)
  32 *                 |          |
  33 *                 |   ...    | BPF program stack
  34 *                 |          |
  35 *     RV32 sp =>  +----------+
  36 *                 |          |
  37 *                 |   ...    | Function call stack
  38 *                 |          |
  39 *                 +----------+
  40 *                     low
  41 */
  42
enum {
	/* Stack layout - these are offsets from top of JIT scratch space. */
	/*
	 * Each stacked BPF register occupies two 32-bit words, upper half
	 * first, matching the {hi, lo} order of the bpf2rv32 pairs below.
	 */
	BPF_R6_HI,
	BPF_R6_LO,
	BPF_R7_HI,
	BPF_R7_LO,
	BPF_R8_HI,
	BPF_R8_LO,
	BPF_R9_HI,
	BPF_R9_LO,
	BPF_AX_HI,
	BPF_AX_LO,
	/* Stack space for BPF_REG_6 through BPF_REG_9 and BPF_REG_AX. */
	BPF_JIT_SCRATCH_REGS,
};
  58
/* Number of callee-saved registers stored to stack: ra, fp, s1--s7. */
#define NR_SAVED_REGISTERS	9

/* Offset from fp for BPF registers stored on stack. */
#define STACK_OFFSET(k) (-4 - (4 * NR_SAVED_REGISTERS) - (4 * (k)))

/* Fake BPF register numbers used by the JIT for stacked-register temporaries. */
#define TMP_REG_1	(MAX_BPF_JIT_REG + 0)
#define TMP_REG_2	(MAX_BPF_JIT_REG + 1)

/* Tail-call counter lives in t6; it is saved in s7 across helper calls. */
#define RV_REG_TCC		RV_REG_T6
#define RV_REG_TCC_SAVED	RV_REG_S7
  70
/*
 * Map from BPF register to a {hi, lo} pair of either RV32 registers or
 * negative fp offsets (see STACK_OFFSET / is_stacked). hi()/lo() below
 * select the halves.
 */
static const s8 bpf2rv32[][2] = {
	/* Return value from in-kernel function, and exit value from eBPF. */
	[BPF_REG_0] = {RV_REG_S2, RV_REG_S1},
	/* Arguments from eBPF program to in-kernel function. */
	[BPF_REG_1] = {RV_REG_A1, RV_REG_A0},
	[BPF_REG_2] = {RV_REG_A3, RV_REG_A2},
	[BPF_REG_3] = {RV_REG_A5, RV_REG_A4},
	[BPF_REG_4] = {RV_REG_A7, RV_REG_A6},
	[BPF_REG_5] = {RV_REG_S4, RV_REG_S3},
	/*
	 * Callee-saved registers that in-kernel function will preserve.
	 * Stored on the stack.
	 */
	[BPF_REG_6] = {STACK_OFFSET(BPF_R6_HI), STACK_OFFSET(BPF_R6_LO)},
	[BPF_REG_7] = {STACK_OFFSET(BPF_R7_HI), STACK_OFFSET(BPF_R7_LO)},
	[BPF_REG_8] = {STACK_OFFSET(BPF_R8_HI), STACK_OFFSET(BPF_R8_LO)},
	[BPF_REG_9] = {STACK_OFFSET(BPF_R9_HI), STACK_OFFSET(BPF_R9_LO)},
	/* Read-only frame pointer to access BPF stack. */
	[BPF_REG_FP] = {RV_REG_S6, RV_REG_S5},
	/* Temporary register for blinding constants. Stored on the stack. */
	[BPF_REG_AX] = {STACK_OFFSET(BPF_AX_HI), STACK_OFFSET(BPF_AX_LO)},
	/*
	 * Temporary registers used by the JIT to operate on registers stored
	 * on the stack. Save t0 and t1 to be used as temporaries in generated
	 * code.
	 */
	[TMP_REG_1] = {RV_REG_T3, RV_REG_T2},
	[TMP_REG_2] = {RV_REG_T5, RV_REG_T4},
};
 100
 101static s8 hi(const s8 *r)
 102{
 103        return r[0];
 104}
 105
 106static s8 lo(const s8 *r)
 107{
 108        return r[1];
 109}
 110
/*
 * Load the 32-bit immediate "imm" into register rd using at most a
 * lui/addi pair. Because addi sign-extends its 12-bit immediate, the
 * upper part is biased by +0x800 to compensate when bit 11 of imm is set.
 */
static void emit_imm(const s8 rd, s32 imm, struct rv_jit_context *ctx)
{
	u32 upper = (imm + (1 << 11)) >> 12;
	u32 lower = imm & 0xfff;

	if (upper) {
		emit(rv_lui(rd, upper), ctx);
		emit(rv_addi(rd, rd, lower), ctx);
	} else {
		/* imm fits in 12 bits: a single addi from x0 suffices. */
		emit(rv_addi(rd, RV_REG_ZERO, lower), ctx);
	}
}
 123
 124static void emit_imm32(const s8 *rd, s32 imm, struct rv_jit_context *ctx)
 125{
 126        /* Emit immediate into lower bits. */
 127        emit_imm(lo(rd), imm, ctx);
 128
 129        /* Sign-extend into upper bits. */
 130        if (imm >= 0)
 131                emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
 132        else
 133                emit(rv_addi(hi(rd), RV_REG_ZERO, -1), ctx);
 134}
 135
 136static void emit_imm64(const s8 *rd, s32 imm_hi, s32 imm_lo,
 137                       struct rv_jit_context *ctx)
 138{
 139        emit_imm(lo(rd), imm_lo, ctx);
 140        emit_imm(hi(rd), imm_hi, ctx);
 141}
 142
 143static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
 144{
 145        int stack_adjust = ctx->stack_size;
 146        const s8 *r0 = bpf2rv32[BPF_REG_0];
 147
 148        /* Set return value if not tail call. */
 149        if (!is_tail_call) {
 150                emit(rv_addi(RV_REG_A0, lo(r0), 0), ctx);
 151                emit(rv_addi(RV_REG_A1, hi(r0), 0), ctx);
 152        }
 153
 154        /* Restore callee-saved registers. */
 155        emit(rv_lw(RV_REG_RA, stack_adjust - 4, RV_REG_SP), ctx);
 156        emit(rv_lw(RV_REG_FP, stack_adjust - 8, RV_REG_SP), ctx);
 157        emit(rv_lw(RV_REG_S1, stack_adjust - 12, RV_REG_SP), ctx);
 158        emit(rv_lw(RV_REG_S2, stack_adjust - 16, RV_REG_SP), ctx);
 159        emit(rv_lw(RV_REG_S3, stack_adjust - 20, RV_REG_SP), ctx);
 160        emit(rv_lw(RV_REG_S4, stack_adjust - 24, RV_REG_SP), ctx);
 161        emit(rv_lw(RV_REG_S5, stack_adjust - 28, RV_REG_SP), ctx);
 162        emit(rv_lw(RV_REG_S6, stack_adjust - 32, RV_REG_SP), ctx);
 163        emit(rv_lw(RV_REG_S7, stack_adjust - 36, RV_REG_SP), ctx);
 164
 165        emit(rv_addi(RV_REG_SP, RV_REG_SP, stack_adjust), ctx);
 166
 167        if (is_tail_call) {
 168                /*
 169                 * goto *(t0 + 4);
 170                 * Skips first instruction of prologue which initializes tail
 171                 * call counter. Assumes t0 contains address of target program,
 172                 * see emit_bpf_tail_call.
 173                 */
 174                emit(rv_jalr(RV_REG_ZERO, RV_REG_T0, 4), ctx);
 175        } else {
 176                emit(rv_jalr(RV_REG_ZERO, RV_REG_RA, 0), ctx);
 177        }
 178}
 179
 180static bool is_stacked(s8 reg)
 181{
 182        return reg < 0;
 183}
 184
 185static const s8 *bpf_get_reg64(const s8 *reg, const s8 *tmp,
 186                               struct rv_jit_context *ctx)
 187{
 188        if (is_stacked(hi(reg))) {
 189                emit(rv_lw(hi(tmp), hi(reg), RV_REG_FP), ctx);
 190                emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
 191                reg = tmp;
 192        }
 193        return reg;
 194}
 195
 196static void bpf_put_reg64(const s8 *reg, const s8 *src,
 197                          struct rv_jit_context *ctx)
 198{
 199        if (is_stacked(hi(reg))) {
 200                emit(rv_sw(RV_REG_FP, hi(reg), hi(src)), ctx);
 201                emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
 202        }
 203}
 204
 205static const s8 *bpf_get_reg32(const s8 *reg, const s8 *tmp,
 206                               struct rv_jit_context *ctx)
 207{
 208        if (is_stacked(lo(reg))) {
 209                emit(rv_lw(lo(tmp), lo(reg), RV_REG_FP), ctx);
 210                reg = tmp;
 211        }
 212        return reg;
 213}
 214
/*
 * Write the low 32 bits of "src" back to BPF register "reg". 32-bit ALU
 * results must be zero-extended to 64 bits, so also clear the high word --
 * unless the verifier has proven zero-extension is unnecessary
 * (verifier_zext).
 */
static void bpf_put_reg32(const s8 *reg, const s8 *src,
			  struct rv_jit_context *ctx)
{
	if (is_stacked(lo(reg))) {
		emit(rv_sw(RV_REG_FP, lo(reg), lo(src)), ctx);
		if (!ctx->prog->aux->verifier_zext)
			emit(rv_sw(RV_REG_FP, hi(reg), RV_REG_ZERO), ctx);
	} else if (!ctx->prog->aux->verifier_zext) {
		/* Destination lives in registers: clear the high word directly. */
		emit(rv_addi(hi(reg), RV_REG_ZERO, 0), ctx);
	}
}
 226
/*
 * Emit a pc-relative jump by "rvoff" bytes, writing the return address to
 * rd. Uses a single jal when the offset fits jal's 21-bit range, unless the
 * caller forces the two-instruction auipc+jalr form (needed when the number
 * of emitted instructions must not depend on the offset).
 * NOTE(review): rvoff == 0 also takes the auipc+jalr path -- presumably to
 * keep a zero offset well-defined; confirm against callers.
 */
static void emit_jump_and_link(u8 rd, s32 rvoff, bool force_jalr,
			       struct rv_jit_context *ctx)
{
	s32 upper, lower;

	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
		/* jal encodes its offset in units of 2 bytes. */
		emit(rv_jal(rd, rvoff >> 1), ctx);
		return;
	}

	/*
	 * auipc+jalr reaches any 32-bit pc-relative offset; bias the upper
	 * part to compensate for jalr's sign-extended 12-bit immediate.
	 */
	upper = (rvoff + (1 << 11)) >> 12;
	lower = rvoff & 0xfff;
	emit(rv_auipc(RV_REG_T1, upper), ctx);
	emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
}
 242
/*
 * Emit a 64-bit ALU operation with a 32-bit immediate (the immediate is
 * sign-extended to 64 bits per BPF semantics): dst = dst op imm.
 */
static void emit_alu_i64(const s8 *dst, s32 imm,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

	switch (op) {
	case BPF_MOV:
		emit_imm32(rd, imm, ctx);
		break;
	case BPF_AND:
		if (is_12b_int(imm)) {
			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		/*
		 * Sign-extended imm: non-negative imm has a zero upper word
		 * (so hi becomes 0); negative imm has an all-ones upper word
		 * (hi is unchanged).
		 */
		if (imm >= 0)
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		break;
	case BPF_OR:
		if (is_12b_int(imm)) {
			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		/* Negative imm ORs an all-ones upper word into hi. */
		if (imm < 0)
			emit(rv_ori(hi(rd), RV_REG_ZERO, -1), ctx);
		break;
	case BPF_XOR:
		if (is_12b_int(imm)) {
			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		/* Negative imm XORs all-ones into hi, i.e. inverts it. */
		if (imm < 0)
			emit(rv_xori(hi(rd), hi(rd), -1), ctx);
		break;
	case BPF_LSH:
		if (imm >= 32) {
			/* Low word is shifted entirely into the high word. */
			emit(rv_slli(hi(rd), lo(rd), imm - 32), ctx);
			emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
		} else if (imm == 0) {
			/* Do nothing. */
		} else {
			/* Carry the top bits of lo into hi, then shift both. */
			emit(rv_srli(RV_REG_T0, lo(rd), 32 - imm), ctx);
			emit(rv_slli(hi(rd), hi(rd), imm), ctx);
			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
		}
		break;
	case BPF_RSH:
		if (imm >= 32) {
			/* High word is shifted entirely into the low word. */
			emit(rv_srli(lo(rd), hi(rd), imm - 32), ctx);
			emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		} else if (imm == 0) {
			/* Do nothing. */
		} else {
			/* Carry the bottom bits of hi into lo, then shift both. */
			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
			emit(rv_srli(hi(rd), hi(rd), imm), ctx);
		}
		break;
	case BPF_ARSH:
		if (imm >= 32) {
			/* lo takes hi >> (imm - 32); hi becomes the sign fill. */
			emit(rv_srai(lo(rd), hi(rd), imm - 32), ctx);
			emit(rv_srai(hi(rd), hi(rd), 31), ctx);
		} else if (imm == 0) {
			/* Do nothing. */
		} else {
			/* As RSH, but the high word shifts arithmetically. */
			emit(rv_slli(RV_REG_T0, hi(rd), 32 - imm), ctx);
			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
			emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
			emit(rv_srai(hi(rd), hi(rd), imm), ctx);
		}
		break;
	}

	bpf_put_reg64(dst, rd, ctx);
}
 326
/*
 * Emit a 32-bit ALU operation with an immediate: lo(dst) = lo(dst) op imm.
 * Each case uses the I-type instruction when imm fits in 12 bits and
 * otherwise materializes imm in t0 first. Zero-extension of the result is
 * handled by bpf_put_reg32.
 */
static void emit_alu_i32(const s8 *dst, s32 imm,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);

	switch (op) {
	case BPF_MOV:
		emit_imm(lo(rd), imm, ctx);
		break;
	case BPF_ADD:
		if (is_12b_int(imm)) {
			emit(rv_addi(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_add(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_SUB:
		/* No subi instruction: subtract by adding -imm when it fits. */
		if (is_12b_int(-imm)) {
			emit(rv_addi(lo(rd), lo(rd), -imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_sub(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_AND:
		if (is_12b_int(imm)) {
			emit(rv_andi(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_and(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_OR:
		if (is_12b_int(imm)) {
			emit(rv_ori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_or(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_XOR:
		if (is_12b_int(imm)) {
			emit(rv_xori(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_xor(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_LSH:
		if (is_12b_int(imm)) {
			emit(rv_slli(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_sll(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_RSH:
		if (is_12b_int(imm)) {
			emit(rv_srli(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_srl(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_ARSH:
		if (is_12b_int(imm)) {
			emit(rv_srai(lo(rd), lo(rd), imm), ctx);
		} else {
			emit_imm(RV_REG_T0, imm, ctx);
			emit(rv_sra(lo(rd), lo(rd), RV_REG_T0), ctx);
		}
		break;
	}

	bpf_put_reg32(dst, rd, ctx);
}
 405
/*
 * Emit a 64-bit register-register ALU operation: dst = dst op src.
 * The 64-bit values live in {hi, lo} register pairs; carries, borrows and
 * cross-word shift bits are synthesized with t0/t1.
 */
static void emit_alu_r64(const s8 *dst, const s8 *src,
			 struct rv_jit_context *ctx, const u8 op)
{
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];
	const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
	const s8 *rs = bpf_get_reg64(src, tmp2, ctx);

	switch (op) {
	case BPF_MOV:
		emit(rv_addi(lo(rd), lo(rs), 0), ctx);
		emit(rv_addi(hi(rd), hi(rs), 0), ctx);
		break;
	case BPF_ADD:
		if (rd == rs) {
			/*
			 * dst and src are the same register pair:
			 * dst + dst == dst << 1. The generic sequence below
			 * would compute the carry with sltu(lo, lo), which
			 * can never be set, so use the shift form instead.
			 */
			emit(rv_srli(RV_REG_T0, lo(rd), 31), ctx);
			emit(rv_slli(hi(rd), hi(rd), 1), ctx);
			emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
			emit(rv_slli(lo(rd), lo(rd), 1), ctx);
		} else {
			/* Carry out of the low add: lo(rd) < lo(rs) after it. */
			emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
			emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
			emit(rv_add(hi(rd), hi(rd), hi(rs)), ctx);
			emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
		}
		break;
	case BPF_SUB:
		/* Borrow from the low subtract: lo(rd) < lo(rs) before it. */
		emit(rv_sub(RV_REG_T1, hi(rd), hi(rs)), ctx);
		emit(rv_sltu(RV_REG_T0, lo(rd), lo(rs)), ctx);
		emit(rv_sub(hi(rd), RV_REG_T1, RV_REG_T0), ctx);
		emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_AND:
		emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_and(hi(rd), hi(rd), hi(rs)), ctx);
		break;
	case BPF_OR:
		emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_or(hi(rd), hi(rd), hi(rs)), ctx);
		break;
	case BPF_XOR:
		emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_xor(hi(rd), hi(rd), hi(rs)), ctx);
		break;
	case BPF_MUL:
		/*
		 * 64x64->64 multiply from 32-bit pieces:
		 * hi = hi(rs)*lo(rd) + hi(rd)*lo(rs) + carry(lo(rd)*lo(rs)),
		 * lo = lo(rd)*lo(rs). The lo(rd)*lo(rs) products are done
		 * last so lo(rd) is still live for them.
		 */
		emit(rv_mul(RV_REG_T0, hi(rs), lo(rd)), ctx);
		emit(rv_mul(hi(rd), hi(rd), lo(rs)), ctx);
		emit(rv_mulhu(RV_REG_T1, lo(rd), lo(rs)), ctx);
		emit(rv_add(hi(rd), hi(rd), RV_REG_T0), ctx);
		emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_add(hi(rd), hi(rd), RV_REG_T1), ctx);
		break;
	case BPF_LSH:
		/*
		 * t0 = shamt - 32. If t0 >= 0 (fall through) the low word is
		 * shifted entirely into the high word; otherwise branch
		 * (offsets are in 2-byte units) to the shamt < 32 sequence,
		 * which carries the top bits of lo into hi. The jal skips
		 * over that second sequence.
		 */
		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
		emit(rv_sll(hi(rd), lo(rd), RV_REG_T0), ctx);
		emit(rv_addi(lo(rd), RV_REG_ZERO, 0), ctx);
		emit(rv_jal(RV_REG_ZERO, 16), ctx);
		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
		emit(rv_srli(RV_REG_T0, lo(rd), 1), ctx);
		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
		emit(rv_srl(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
		emit(rv_sll(hi(rd), hi(rd), lo(rs)), ctx);
		emit(rv_or(hi(rd), RV_REG_T0, hi(rd)), ctx);
		emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
		break;
	case BPF_RSH:
		/* Mirror of BPF_LSH: high word feeds into the low word. */
		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
		emit(rv_srl(lo(rd), hi(rd), RV_REG_T0), ctx);
		emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
		emit(rv_jal(RV_REG_ZERO, 16), ctx);
		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
		emit(rv_srl(hi(rd), hi(rd), lo(rs)), ctx);
		break;
	case BPF_ARSH:
		/* As BPF_RSH, but the high word shifts arithmetically. */
		emit(rv_addi(RV_REG_T0, lo(rs), -32), ctx);
		emit(rv_blt(RV_REG_T0, RV_REG_ZERO, 8), ctx);
		emit(rv_sra(lo(rd), hi(rd), RV_REG_T0), ctx);
		emit(rv_srai(hi(rd), hi(rd), 31), ctx);
		emit(rv_jal(RV_REG_ZERO, 16), ctx);
		emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 31), ctx);
		emit(rv_slli(RV_REG_T0, hi(rd), 1), ctx);
		emit(rv_sub(RV_REG_T1, RV_REG_T1, lo(rs)), ctx);
		emit(rv_sll(RV_REG_T0, RV_REG_T0, RV_REG_T1), ctx);
		emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
		emit(rv_or(lo(rd), RV_REG_T0, lo(rd)), ctx);
		emit(rv_sra(hi(rd), hi(rd), lo(rs)), ctx);
		break;
	case BPF_NEG:
		/* -x == ~x + 1: negate lo, then hi minus the borrow (lo != 0). */
		emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
		emit(rv_sltu(RV_REG_T0, RV_REG_ZERO, lo(rd)), ctx);
		emit(rv_sub(hi(rd), RV_REG_ZERO, hi(rd)), ctx);
		emit(rv_sub(hi(rd), hi(rd), RV_REG_T0), ctx);
		break;
	}

	bpf_put_reg64(dst, rd, ctx);
}
 510
 511static void emit_alu_r32(const s8 *dst, const s8 *src,
 512                         struct rv_jit_context *ctx, const u8 op)
 513{
 514        const s8 *tmp1 = bpf2rv32[TMP_REG_1];
 515        const s8 *tmp2 = bpf2rv32[TMP_REG_2];
 516        const s8 *rd = bpf_get_reg32(dst, tmp1, ctx);
 517        const s8 *rs = bpf_get_reg32(src, tmp2, ctx);
 518
 519        switch (op) {
 520        case BPF_MOV:
 521                emit(rv_addi(lo(rd), lo(rs), 0), ctx);
 522                break;
 523        case BPF_ADD:
 524                emit(rv_add(lo(rd), lo(rd), lo(rs)), ctx);
 525                break;
 526        case BPF_SUB:
 527                emit(rv_sub(lo(rd), lo(rd), lo(rs)), ctx);
 528                break;
 529        case BPF_AND:
 530                emit(rv_and(lo(rd), lo(rd), lo(rs)), ctx);
 531                break;
 532        case BPF_OR:
 533                emit(rv_or(lo(rd), lo(rd), lo(rs)), ctx);
 534                break;
 535        case BPF_XOR:
 536                emit(rv_xor(lo(rd), lo(rd), lo(rs)), ctx);
 537                break;
 538        case BPF_MUL:
 539                emit(rv_mul(lo(rd), lo(rd), lo(rs)), ctx);
 540                break;
 541        case BPF_DIV:
 542                emit(rv_divu(lo(rd), lo(rd), lo(rs)), ctx);
 543                break;
 544        case BPF_MOD:
 545                emit(rv_remu(lo(rd), lo(rd), lo(rs)), ctx);
 546                break;
 547        case BPF_LSH:
 548                emit(rv_sll(lo(rd), lo(rd), lo(rs)), ctx);
 549                break;
 550        case BPF_RSH:
 551                emit(rv_srl(lo(rd), lo(rd), lo(rs)), ctx);
 552                break;
 553        case BPF_ARSH:
 554                emit(rv_sra(lo(rd), lo(rd), lo(rs)), ctx);
 555                break;
 556        case BPF_NEG:
 557                emit(rv_sub(lo(rd), RV_REG_ZERO, lo(rd)), ctx);
 558                break;
 559        }
 560
 561        bpf_put_reg32(dst, rd, ctx);
 562}
 563
/*
 * Emit a 64-bit conditional branch comparing register pairs src1 and src2,
 * jumping by "rvoff" bytes when the condition holds. Comparisons are
 * decided on the high words first and fall back to an unsigned comparison
 * of the low words on a tie.
 */
static int emit_branch_r64(const s8 *src1, const s8 *src2, s32 rvoff,
			   struct rv_jit_context *ctx, const u8 op)
{
	int e, s = ctx->ninsns;
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];

	const s8 *rs1 = bpf_get_reg64(src1, tmp1, ctx);
	const s8 *rs2 = bpf_get_reg64(src2, tmp2, ctx);

	/*
	 * NO_JUMP skips over the rest of the instructions and the
	 * emit_jump_and_link, meaning the BPF branch is not taken.
	 * JUMP skips directly to the emit_jump_and_link, meaning
	 * the BPF branch is taken.
	 *
	 * The fallthrough case results in the BPF branch being taken.
	 *
	 * "idx" counts how many comparison instructions remain after the
	 * current one; offsets are in 2-byte units, two per instruction.
	 */
#define NO_JUMP(idx) (6 + (2 * (idx)))
#define JUMP(idx) (2 + (2 * (idx)))

	switch (op) {
	case BPF_JEQ:
		emit(rv_bne(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bne(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JGT:
		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JLT:
		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JGE:
		emit(rv_bgtu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bltu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JLE:
		emit(rv_bltu(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgtu(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JNE:
		emit(rv_bne(hi(rs1), hi(rs2), JUMP(1)), ctx);
		emit(rv_beq(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSGT:
		/* Signed compare on hi words, unsigned on lo words on a tie. */
		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bleu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSLT:
		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgeu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSGE:
		emit(rv_bgt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_blt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bltu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSLE:
		emit(rv_blt(hi(rs1), hi(rs2), JUMP(2)), ctx);
		emit(rv_bgt(hi(rs1), hi(rs2), NO_JUMP(1)), ctx);
		emit(rv_bgtu(lo(rs1), lo(rs2), NO_JUMP(0)), ctx);
		break;
	case BPF_JSET:
		/* Taken if either the high or low AND is non-zero. */
		emit(rv_and(RV_REG_T0, hi(rs1), hi(rs2)), ctx);
		emit(rv_bne(RV_REG_T0, RV_REG_ZERO, JUMP(2)), ctx);
		emit(rv_and(RV_REG_T0, lo(rs1), lo(rs2)), ctx);
		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, NO_JUMP(0)), ctx);
		break;
	}

#undef NO_JUMP
#undef JUMP

	e = ctx->ninsns;
	/* Adjust for extra insns. */
	rvoff -= ninsns_rvoff(e - s);
	emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
	return 0;
}
 651
/*
 * Emit a 32-bit conditional branch (rd op rs) to pc-relative byte offset
 * "rvoff". Near targets use a single branch; far targets invert the
 * condition and branch over an unconditional jump instead.
 */
static int emit_bcc(u8 op, u8 rd, u8 rs, int rvoff, struct rv_jit_context *ctx)
{
	int e, s = ctx->ninsns;
	bool far = false;
	int off;

	if (op == BPF_JSET) {
		/*
		 * BPF_JSET is a special case: it has no inverse so we always
		 * treat it as a far branch.
		 */
		far = true;
	} else if (!is_13b_int(rvoff)) {
		/* Out of B-type range: branch on the inverted condition. */
		op = invert_bpf_cond(op);
		far = true;
	}

	/*
	 * For a far branch, the condition is negated and we jump over the
	 * branch itself, and the two instructions from emit_jump_and_link.
	 * For a near branch, just use rvoff.
	 *
	 * Branch offsets are encoded in 2-byte units, hence the >> 1.
	 */
	off = far ? 6 : (rvoff >> 1);

	switch (op) {
	case BPF_JEQ:
		emit(rv_beq(rd, rs, off), ctx);
		break;
	case BPF_JGT:
		emit(rv_bgtu(rd, rs, off), ctx);
		break;
	case BPF_JLT:
		emit(rv_bltu(rd, rs, off), ctx);
		break;
	case BPF_JGE:
		emit(rv_bgeu(rd, rs, off), ctx);
		break;
	case BPF_JLE:
		emit(rv_bleu(rd, rs, off), ctx);
		break;
	case BPF_JNE:
		emit(rv_bne(rd, rs, off), ctx);
		break;
	case BPF_JSGT:
		emit(rv_bgt(rd, rs, off), ctx);
		break;
	case BPF_JSLT:
		emit(rv_blt(rd, rs, off), ctx);
		break;
	case BPF_JSGE:
		emit(rv_bge(rd, rs, off), ctx);
		break;
	case BPF_JSLE:
		emit(rv_ble(rd, rs, off), ctx);
		break;
	case BPF_JSET:
		/* Taken when (rd & rs) != 0; beq skips the far jump otherwise. */
		emit(rv_and(RV_REG_T0, rd, rs), ctx);
		emit(rv_beq(RV_REG_T0, RV_REG_ZERO, off), ctx);
		break;
	}

	if (far) {
		e = ctx->ninsns;
		/* Adjust for extra insns. */
		rvoff -= ninsns_rvoff(e - s);
		emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
	}
	return 0;
}
 721
 722static int emit_branch_r32(const s8 *src1, const s8 *src2, s32 rvoff,
 723                           struct rv_jit_context *ctx, const u8 op)
 724{
 725        int e, s = ctx->ninsns;
 726        const s8 *tmp1 = bpf2rv32[TMP_REG_1];
 727        const s8 *tmp2 = bpf2rv32[TMP_REG_2];
 728
 729        const s8 *rs1 = bpf_get_reg32(src1, tmp1, ctx);
 730        const s8 *rs2 = bpf_get_reg32(src2, tmp2, ctx);
 731
 732        e = ctx->ninsns;
 733        /* Adjust for extra insns. */
 734        rvoff -= ninsns_rvoff(e - s);
 735
 736        if (emit_bcc(op, lo(rs1), lo(rs2), rvoff, ctx))
 737                return -1;
 738
 739        return 0;
 740}
 741
/*
 * Emit a call to the function at absolute address "addr" (a BPF helper or
 * kernel function). BPF arguments R1-R4 are already in a0-a7 per bpf2rv32;
 * R5 is passed on the stack. BPF R0 is set from the a1:a0 return pair.
 * NOTE(review): the "fixed" parameter is unused here -- confirm callers.
 */
static void emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
{
	const s8 *r0 = bpf2rv32[BPF_REG_0];
	const s8 *r5 = bpf2rv32[BPF_REG_5];
	u32 upper = ((u32)addr + (1 << 11)) >> 12;
	u32 lower = addr & 0xfff;

	/* R1-R4 already in correct registers---need to push R5 to stack. */
	emit(rv_addi(RV_REG_SP, RV_REG_SP, -16), ctx);
	emit(rv_sw(RV_REG_SP, 0, lo(r5)), ctx);
	emit(rv_sw(RV_REG_SP, 4, hi(r5)), ctx);

	/* Backup TCC. */
	/* TCC lives in t6, which the callee may clobber; stash it in s7. */
	emit(rv_addi(RV_REG_TCC_SAVED, RV_REG_TCC, 0), ctx);

	/*
	 * Use lui/jalr pair to jump to absolute address. Don't use emit_imm as
	 * the number of emitted instructions should not depend on the value of
	 * addr.
	 */
	emit(rv_lui(RV_REG_T1, upper), ctx);
	emit(rv_jalr(RV_REG_RA, RV_REG_T1, lower), ctx);

	/* Restore TCC. */
	emit(rv_addi(RV_REG_TCC, RV_REG_TCC_SAVED, 0), ctx);

	/* Set return value and restore stack. */
	emit(rv_addi(lo(r0), RV_REG_A0, 0), ctx);
	emit(rv_addi(hi(r0), RV_REG_A1, 0), ctx);
	emit(rv_addi(RV_REG_SP, RV_REG_SP, 16), ctx);
}
 773
/*
 * Emit the body of a BPF tail call:
 *
 *   if (index >= array->map.max_entries) goto out;
 *   if (--TCC < 0) goto out;
 *   prog = array->ptrs[index];
 *   if (!prog) goto out;
 *   goto prog->bpf_func + 4;   // skip the callee's TCC-init instruction
 *
 * @insn is the index of this instruction in the BPF program; together
 * with ctx->offset[] it gives the distance to "out", the first
 * instruction after this tail call's JITed code.
 *
 * Returns 0 on success, -1 if a struct offset does not fit in 12 bits.
 */
static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
{
	/*
	 * R1 -> &ctx
	 * R2 -> &array
	 * R3 -> index
	 */
	int tc_ninsn, off, start_insn = ctx->ninsns;
	const s8 *arr_reg = bpf2rv32[BPF_REG_2];
	const s8 *idx_reg = bpf2rv32[BPF_REG_3];

	/* Total RV insns this BPF insn maps to, from the offset table. */
	tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
		ctx->offset[0];

	/* max_entries = array->map.max_entries; */
	off = offsetof(struct bpf_array, map.max_entries);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lw(RV_REG_T1, off, lo(arr_reg)), ctx);

	/*
	 * if (index >= max_entries)
	 *   goto out;
	 *
	 * BPF_JGE is an unsigned compare; only the low word of the index
	 * is used. The branch distance to "out" is recomputed before each
	 * branch because ctx->ninsns grows as we emit.
	 */
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_bcc(BPF_JGE, lo(idx_reg), RV_REG_T1, off, ctx);

	/*
	 * temp_tcc = tcc - 1;
	 * if (tcc < 0)
	 *   goto out;
	 *
	 * TCC itself is only overwritten (below) once the call is certain
	 * to be taken, so "out" still sees the old count.
	 */
	emit(rv_addi(RV_REG_T1, RV_REG_TCC, -1), ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_bcc(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);

	/*
	 * prog = array->ptrs[index];
	 * if (!prog)
	 *   goto out;
	 */
	emit(rv_slli(RV_REG_T0, lo(idx_reg), 2), ctx);
	emit(rv_add(RV_REG_T0, RV_REG_T0, lo(arr_reg)), ctx);
	off = offsetof(struct bpf_array, ptrs);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
	off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
	emit_bcc(BPF_JEQ, RV_REG_T0, RV_REG_ZERO, off, ctx);

	/*
	 * tcc = temp_tcc;
	 * goto *(prog->bpf_func + 4);
	 */
	off = offsetof(struct bpf_prog, bpf_func);
	if (is_12b_check(off, insn))
		return -1;
	emit(rv_lw(RV_REG_T0, off, RV_REG_T0), ctx);
	emit(rv_addi(RV_REG_TCC, RV_REG_T1, 0), ctx);
	/* Epilogue jumps to *(t0 + 4). */
	__build_epilogue(true, ctx);
	return 0;
}
 837
 838static int emit_load_r64(const s8 *dst, const s8 *src, s16 off,
 839                         struct rv_jit_context *ctx, const u8 size)
 840{
 841        const s8 *tmp1 = bpf2rv32[TMP_REG_1];
 842        const s8 *tmp2 = bpf2rv32[TMP_REG_2];
 843        const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
 844        const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
 845
 846        emit_imm(RV_REG_T0, off, ctx);
 847        emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rs)), ctx);
 848
 849        switch (size) {
 850        case BPF_B:
 851                emit(rv_lbu(lo(rd), 0, RV_REG_T0), ctx);
 852                if (!ctx->prog->aux->verifier_zext)
 853                        emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
 854                break;
 855        case BPF_H:
 856                emit(rv_lhu(lo(rd), 0, RV_REG_T0), ctx);
 857                if (!ctx->prog->aux->verifier_zext)
 858                        emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
 859                break;
 860        case BPF_W:
 861                emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
 862                if (!ctx->prog->aux->verifier_zext)
 863                        emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
 864                break;
 865        case BPF_DW:
 866                emit(rv_lw(lo(rd), 0, RV_REG_T0), ctx);
 867                emit(rv_lw(hi(rd), 4, RV_REG_T0), ctx);
 868                break;
 869        }
 870
 871        bpf_put_reg64(dst, rd, ctx);
 872        return 0;
 873}
 874
 875static int emit_store_r64(const s8 *dst, const s8 *src, s16 off,
 876                          struct rv_jit_context *ctx, const u8 size,
 877                          const u8 mode)
 878{
 879        const s8 *tmp1 = bpf2rv32[TMP_REG_1];
 880        const s8 *tmp2 = bpf2rv32[TMP_REG_2];
 881        const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);
 882        const s8 *rs = bpf_get_reg64(src, tmp2, ctx);
 883
 884        if (mode == BPF_ATOMIC && size != BPF_W)
 885                return -1;
 886
 887        emit_imm(RV_REG_T0, off, ctx);
 888        emit(rv_add(RV_REG_T0, RV_REG_T0, lo(rd)), ctx);
 889
 890        switch (size) {
 891        case BPF_B:
 892                emit(rv_sb(RV_REG_T0, 0, lo(rs)), ctx);
 893                break;
 894        case BPF_H:
 895                emit(rv_sh(RV_REG_T0, 0, lo(rs)), ctx);
 896                break;
 897        case BPF_W:
 898                switch (mode) {
 899                case BPF_MEM:
 900                        emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
 901                        break;
 902                case BPF_ATOMIC: /* Only BPF_ADD supported */
 903                        emit(rv_amoadd_w(RV_REG_ZERO, lo(rs), RV_REG_T0, 0, 0),
 904                             ctx);
 905                        break;
 906                }
 907                break;
 908        case BPF_DW:
 909                emit(rv_sw(RV_REG_T0, 0, lo(rs)), ctx);
 910                emit(rv_sw(RV_REG_T0, 4, hi(rs)), ctx);
 911                break;
 912        }
 913
 914        return 0;
 915}
 916
/*
 * Byte-swap the low 16 bits of @rd in place; the upper 16 bits of the
 * result are zero (any previous upper contents are discarded).
 * Clobbers T1. Example trace for rd = 0x0000A1B2:
 */
static void emit_rev16(const s8 rd, struct rv_jit_context *ctx)
{
	emit(rv_slli(rd, rd, 16), ctx);			/* rd = A1B20000 */
	emit(rv_slli(RV_REG_T1, rd, 8), ctx);		/* T1 = B2000000 */
	emit(rv_srli(rd, rd, 8), ctx);			/* rd = 00A1B200 */
	emit(rv_add(RV_REG_T1, rd, RV_REG_T1), ctx);	/* T1 = B2A1B200 */
	emit(rv_srli(rd, RV_REG_T1, 16), ctx);		/* rd = 0000B2A1 */
}
 925
 926static void emit_rev32(const s8 rd, struct rv_jit_context *ctx)
 927{
 928        emit(rv_addi(RV_REG_T1, RV_REG_ZERO, 0), ctx);
 929        emit(rv_andi(RV_REG_T0, rd, 255), ctx);
 930        emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
 931        emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
 932        emit(rv_srli(rd, rd, 8), ctx);
 933        emit(rv_andi(RV_REG_T0, rd, 255), ctx);
 934        emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
 935        emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
 936        emit(rv_srli(rd, rd, 8), ctx);
 937        emit(rv_andi(RV_REG_T0, rd, 255), ctx);
 938        emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
 939        emit(rv_slli(RV_REG_T1, RV_REG_T1, 8), ctx);
 940        emit(rv_srli(rd, rd, 8), ctx);
 941        emit(rv_andi(RV_REG_T0, rd, 255), ctx);
 942        emit(rv_add(RV_REG_T1, RV_REG_T1, RV_REG_T0), ctx);
 943        emit(rv_addi(rd, RV_REG_T1, 0), ctx);
 944}
 945
 946static void emit_zext64(const s8 *dst, struct rv_jit_context *ctx)
 947{
 948        const s8 *rd;
 949        const s8 *tmp1 = bpf2rv32[TMP_REG_1];
 950
 951        rd = bpf_get_reg64(dst, tmp1, ctx);
 952        emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
 953        bpf_put_reg64(dst, rd, ctx);
 954}
 955
/*
 * Translate one BPF instruction into RV32 instructions.
 *
 * Return values:
 *   0  - success, instruction consumed one slot;
 *   1  - success, instruction consumed two slots
 *        (only BPF_LD | BPF_IMM | BPF_DW);
 *   <0 - failure: unsupported opcode (-EFAULT), unknown opcode
 *        (-EINVAL), or an error from a helper (-1 or propagated).
 *
 * @extra_pass is forwarded to bpf_jit_get_func_addr() when resolving
 * BPF_CALL targets.
 */
int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
		      bool extra_pass)
{
	/* BPF_ALU64 and BPF_JMP operate on the full 64-bit register pair. */
	bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
		BPF_CLASS(insn->code) == BPF_JMP;
	int s, e, rvoff, i = insn - ctx->prog->insnsi;
	u8 code = insn->code;
	s16 off = insn->off;
	s32 imm = insn->imm;

	const s8 *dst = bpf2rv32[insn->dst_reg];
	const s8 *src = bpf2rv32[insn->src_reg];
	const s8 *tmp1 = bpf2rv32[TMP_REG_1];
	const s8 *tmp2 = bpf2rv32[TMP_REG_2];

	switch (code) {
	case BPF_ALU64 | BPF_MOV | BPF_X:

	case BPF_ALU64 | BPF_ADD | BPF_X:
	case BPF_ALU64 | BPF_ADD | BPF_K:

	case BPF_ALU64 | BPF_SUB | BPF_X:
	case BPF_ALU64 | BPF_SUB | BPF_K:

	case BPF_ALU64 | BPF_AND | BPF_X:
	case BPF_ALU64 | BPF_OR | BPF_X:
	case BPF_ALU64 | BPF_XOR | BPF_X:

	case BPF_ALU64 | BPF_MUL | BPF_X:
	case BPF_ALU64 | BPF_MUL | BPF_K:

	case BPF_ALU64 | BPF_LSH | BPF_X:
	case BPF_ALU64 | BPF_RSH | BPF_X:
	case BPF_ALU64 | BPF_ARSH | BPF_X:
		/* BPF_K forms: materialize imm into tmp2 and use it as src. */
		if (BPF_SRC(code) == BPF_K) {
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
		}
		emit_alu_r64(dst, src, ctx, BPF_OP(code));
		break;

	case BPF_ALU64 | BPF_NEG:
		/* src is ignored---tmp2 is a dummy (see the ALU32 NEG case). */
		emit_alu_r64(dst, tmp2, ctx, BPF_OP(code));
		break;

	case BPF_ALU64 | BPF_DIV | BPF_X:
	case BPF_ALU64 | BPF_DIV | BPF_K:
	case BPF_ALU64 | BPF_MOD | BPF_X:
	case BPF_ALU64 | BPF_MOD | BPF_K:
		/* 64-bit div/mod are not implemented on RV32. */
		goto notsupported;

	case BPF_ALU64 | BPF_MOV | BPF_K:
	case BPF_ALU64 | BPF_AND | BPF_K:
	case BPF_ALU64 | BPF_OR | BPF_K:
	case BPF_ALU64 | BPF_XOR | BPF_K:
	case BPF_ALU64 | BPF_LSH | BPF_K:
	case BPF_ALU64 | BPF_RSH | BPF_K:
	case BPF_ALU64 | BPF_ARSH | BPF_K:
		emit_alu_i64(dst, imm, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_MOV | BPF_X:
		if (imm == 1) {
			/* Special mov32 for zext. */
			emit_zext64(dst, ctx);
			break;
		}
		fallthrough;

	case BPF_ALU | BPF_ADD | BPF_X:
	case BPF_ALU | BPF_SUB | BPF_X:
	case BPF_ALU | BPF_AND | BPF_X:
	case BPF_ALU | BPF_OR | BPF_X:
	case BPF_ALU | BPF_XOR | BPF_X:

	case BPF_ALU | BPF_MUL | BPF_X:
	case BPF_ALU | BPF_MUL | BPF_K:

	case BPF_ALU | BPF_DIV | BPF_X:
	case BPF_ALU | BPF_DIV | BPF_K:

	case BPF_ALU | BPF_MOD | BPF_X:
	case BPF_ALU | BPF_MOD | BPF_K:

	case BPF_ALU | BPF_LSH | BPF_X:
	case BPF_ALU | BPF_RSH | BPF_X:
	case BPF_ALU | BPF_ARSH | BPF_X:
		/* BPF_K forms: materialize imm into tmp2 and use it as src. */
		if (BPF_SRC(code) == BPF_K) {
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
		}
		emit_alu_r32(dst, src, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_MOV | BPF_K:
	case BPF_ALU | BPF_ADD | BPF_K:
	case BPF_ALU | BPF_SUB | BPF_K:
	case BPF_ALU | BPF_AND | BPF_K:
	case BPF_ALU | BPF_OR | BPF_K:
	case BPF_ALU | BPF_XOR | BPF_K:
	case BPF_ALU | BPF_LSH | BPF_K:
	case BPF_ALU | BPF_RSH | BPF_K:
	case BPF_ALU | BPF_ARSH | BPF_K:
		/*
		 * mul,div,mod are handled in the BPF_X case since there are
		 * no RISC-V I-type equivalents.
		 */
		emit_alu_i32(dst, imm, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_NEG:
		/*
		 * src is ignored---choose tmp2 as a dummy register since it
		 * is not on the stack.
		 */
		emit_alu_r32(dst, tmp2, ctx, BPF_OP(code));
		break;

	case BPF_ALU | BPF_END | BPF_FROM_LE:
	{
		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

		/* Little-endian target: just truncate to imm bits. */
		switch (imm) {
		case 16:
			/* Clear bits 31:16 of the low word. */
			emit(rv_slli(lo(rd), lo(rd), 16), ctx);
			emit(rv_srli(lo(rd), lo(rd), 16), ctx);
			fallthrough;
		case 32:
			if (!ctx->prog->aux->verifier_zext)
				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
			break;
		case 64:
			/* Do nothing. */
			break;
		default:
			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
			return -1;
		}

		bpf_put_reg64(dst, rd, ctx);
		break;
	}

	case BPF_ALU | BPF_END | BPF_FROM_BE:
	{
		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

		/* Little-endian target: byte-swap the low imm bits. */
		switch (imm) {
		case 16:
			emit_rev16(lo(rd), ctx);
			if (!ctx->prog->aux->verifier_zext)
				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
			break;
		case 32:
			emit_rev32(lo(rd), ctx);
			if (!ctx->prog->aux->verifier_zext)
				emit(rv_addi(hi(rd), RV_REG_ZERO, 0), ctx);
			break;
		case 64:
			/* Swap upper and lower halves. */
			emit(rv_addi(RV_REG_T0, lo(rd), 0), ctx);
			emit(rv_addi(lo(rd), hi(rd), 0), ctx);
			emit(rv_addi(hi(rd), RV_REG_T0, 0), ctx);

			/* Swap each half. */
			emit_rev32(lo(rd), ctx);
			emit_rev32(hi(rd), ctx);
			break;
		default:
			pr_err("bpf-jit: BPF_END imm %d invalid\n", imm);
			return -1;
		}

		bpf_put_reg64(dst, rd, ctx);
		break;
	}

	case BPF_JMP | BPF_JA:
		rvoff = rv_offset(i, off, ctx);
		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
		break;

	case BPF_JMP | BPF_CALL:
	{
		bool fixed;
		int ret;
		u64 addr;

		ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass, &addr,
					    &fixed);
		if (ret < 0)
			return ret;
		emit_call(fixed, addr, ctx);
		break;
	}

	case BPF_JMP | BPF_TAIL_CALL:
		if (emit_bpf_tail_call(i, ctx))
			return -1;
		break;

	case BPF_JMP | BPF_JEQ | BPF_X:
	case BPF_JMP | BPF_JEQ | BPF_K:
	case BPF_JMP32 | BPF_JEQ | BPF_X:
	case BPF_JMP32 | BPF_JEQ | BPF_K:

	case BPF_JMP | BPF_JNE | BPF_X:
	case BPF_JMP | BPF_JNE | BPF_K:
	case BPF_JMP32 | BPF_JNE | BPF_X:
	case BPF_JMP32 | BPF_JNE | BPF_K:

	case BPF_JMP | BPF_JLE | BPF_X:
	case BPF_JMP | BPF_JLE | BPF_K:
	case BPF_JMP32 | BPF_JLE | BPF_X:
	case BPF_JMP32 | BPF_JLE | BPF_K:

	case BPF_JMP | BPF_JLT | BPF_X:
	case BPF_JMP | BPF_JLT | BPF_K:
	case BPF_JMP32 | BPF_JLT | BPF_X:
	case BPF_JMP32 | BPF_JLT | BPF_K:

	case BPF_JMP | BPF_JGE | BPF_X:
	case BPF_JMP | BPF_JGE | BPF_K:
	case BPF_JMP32 | BPF_JGE | BPF_X:
	case BPF_JMP32 | BPF_JGE | BPF_K:

	case BPF_JMP | BPF_JGT | BPF_X:
	case BPF_JMP | BPF_JGT | BPF_K:
	case BPF_JMP32 | BPF_JGT | BPF_X:
	case BPF_JMP32 | BPF_JGT | BPF_K:

	case BPF_JMP | BPF_JSLE | BPF_X:
	case BPF_JMP | BPF_JSLE | BPF_K:
	case BPF_JMP32 | BPF_JSLE | BPF_X:
	case BPF_JMP32 | BPF_JSLE | BPF_K:

	case BPF_JMP | BPF_JSLT | BPF_X:
	case BPF_JMP | BPF_JSLT | BPF_K:
	case BPF_JMP32 | BPF_JSLT | BPF_X:
	case BPF_JMP32 | BPF_JSLT | BPF_K:

	case BPF_JMP | BPF_JSGE | BPF_X:
	case BPF_JMP | BPF_JSGE | BPF_K:
	case BPF_JMP32 | BPF_JSGE | BPF_X:
	case BPF_JMP32 | BPF_JSGE | BPF_K:

	case BPF_JMP | BPF_JSGT | BPF_X:
	case BPF_JMP | BPF_JSGT | BPF_K:
	case BPF_JMP32 | BPF_JSGT | BPF_X:
	case BPF_JMP32 | BPF_JSGT | BPF_K:

	case BPF_JMP | BPF_JSET | BPF_X:
	case BPF_JMP | BPF_JSET | BPF_K:
	case BPF_JMP32 | BPF_JSET | BPF_X:
	case BPF_JMP32 | BPF_JSET | BPF_K:
		rvoff = rv_offset(i, off, ctx);
		if (BPF_SRC(code) == BPF_K) {
			/*
			 * Materialize imm into tmp2, then shrink the branch
			 * offset by the number of insns that took.
			 */
			s = ctx->ninsns;
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
			e = ctx->ninsns;
			rvoff -= ninsns_rvoff(e - s);
		}

		if (is64)
			emit_branch_r64(dst, src, rvoff, ctx, BPF_OP(code));
		else
			emit_branch_r32(dst, src, rvoff, ctx, BPF_OP(code));
		break;

	case BPF_JMP | BPF_EXIT:
		/* The final instruction falls through into the epilogue. */
		if (i == ctx->prog->len - 1)
			break;

		rvoff = epilogue_offset(ctx);
		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
		break;

	case BPF_LD | BPF_IMM | BPF_DW:
	{
		/*
		 * Two-slot instruction: the 64-bit immediate is split
		 * across insn->imm (low) and insn[1].imm (high). Return 1
		 * so the caller skips the second slot.
		 */
		struct bpf_insn insn1 = insn[1];
		s32 imm_lo = imm;
		s32 imm_hi = insn1.imm;
		const s8 *rd = bpf_get_reg64(dst, tmp1, ctx);

		emit_imm64(rd, imm_hi, imm_lo, ctx);
		bpf_put_reg64(dst, rd, ctx);
		return 1;
	}

	case BPF_LDX | BPF_MEM | BPF_B:
	case BPF_LDX | BPF_MEM | BPF_H:
	case BPF_LDX | BPF_MEM | BPF_W:
	case BPF_LDX | BPF_MEM | BPF_DW:
		if (emit_load_r64(dst, src, off, ctx, BPF_SIZE(code)))
			return -1;
		break;

	/* speculation barrier */
	case BPF_ST | BPF_NOSPEC:
		break;

	case BPF_ST | BPF_MEM | BPF_B:
	case BPF_ST | BPF_MEM | BPF_H:
	case BPF_ST | BPF_MEM | BPF_W:
	case BPF_ST | BPF_MEM | BPF_DW:

	case BPF_STX | BPF_MEM | BPF_B:
	case BPF_STX | BPF_MEM | BPF_H:
	case BPF_STX | BPF_MEM | BPF_W:
	case BPF_STX | BPF_MEM | BPF_DW:
		/* BPF_ST stores an immediate: stage it in tmp2 as the source. */
		if (BPF_CLASS(code) == BPF_ST) {
			emit_imm32(tmp2, imm, ctx);
			src = tmp2;
		}

		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
				   BPF_MODE(code)))
			return -1;
		break;

	case BPF_STX | BPF_ATOMIC | BPF_W:
		if (insn->imm != BPF_ADD) {
			pr_info_once(
				"bpf-jit: not supported: atomic operation %02x ***\n",
				insn->imm);
			return -EFAULT;
		}

		if (emit_store_r64(dst, src, off, ctx, BPF_SIZE(code),
				   BPF_MODE(code)))
			return -1;
		break;

	/* No hardware support for 8-byte atomics in RV32. */
	case BPF_STX | BPF_ATOMIC | BPF_DW:
		/* Fallthrough. */

notsupported:
		pr_info_once("bpf-jit: not supported: opcode %02x ***\n", code);
		return -EFAULT;

	default:
		pr_err("bpf-jit: unknown opcode %02x\n", code);
		return -EINVAL;
	}

	return 0;
}
1305
/*
 * Build the program prologue: set up the RV32 stack frame described in
 * the comment at the top of this file, save callee-saved registers, and
 * initialize the BPF frame-pointer and context registers.
 */
void bpf_jit_build_prologue(struct rv_jit_context *ctx)
{
	const s8 *fp = bpf2rv32[BPF_REG_FP];
	const s8 *r1 = bpf2rv32[BPF_REG_1];
	int stack_adjust = 0;
	int bpf_stack_adjust =
		round_up(ctx->prog->aux->stack_depth, STACK_ALIGN);

	/* Make space for callee-saved registers. */
	stack_adjust += NR_SAVED_REGISTERS * sizeof(u32);
	/* Make space for BPF registers on stack. */
	stack_adjust += BPF_JIT_SCRATCH_REGS * sizeof(u32);
	/* Make space for BPF stack. */
	stack_adjust += bpf_stack_adjust;
	/* Round up for stack alignment. */
	stack_adjust = round_up(stack_adjust, STACK_ALIGN);

	/*
	 * The first instruction sets the tail-call-counter (TCC) register.
	 * This instruction is skipped by tail calls.
	 */
	emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);

	emit(rv_addi(RV_REG_SP, RV_REG_SP, -stack_adjust), ctx);

	/* Save callee-save registers (ra, fp, s1-s7 at the frame top). */
	emit(rv_sw(RV_REG_SP, stack_adjust - 4, RV_REG_RA), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 8, RV_REG_FP), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 12, RV_REG_S1), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 16, RV_REG_S2), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 20, RV_REG_S3), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 24, RV_REG_S4), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 28, RV_REG_S5), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 32, RV_REG_S6), ctx);
	emit(rv_sw(RV_REG_SP, stack_adjust - 36, RV_REG_S7), ctx);

	/* Set fp: used as the base address for stacked BPF registers. */
	emit(rv_addi(RV_REG_FP, RV_REG_SP, stack_adjust), ctx);

	/* Set up BPF frame pointer (high half zero: pointers are 32-bit). */
	emit(rv_addi(lo(fp), RV_REG_SP, bpf_stack_adjust), ctx);
	emit(rv_addi(hi(fp), RV_REG_ZERO, 0), ctx);

	/* Set up BPF context pointer (R1 = a0, the program argument). */
	emit(rv_addi(lo(r1), RV_REG_A0, 0), ctx);
	emit(rv_addi(hi(r1), RV_REG_ZERO, 0), ctx);

	/* Remember the frame size for the epilogue. */
	ctx->stack_size = stack_adjust;
}
1355
/*
 * Build the standard program epilogue. The "false" argument selects the
 * normal return path, in contrast to emit_bpf_tail_call(), which passes
 * true to make the epilogue jump via t0 instead.
 */
void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
{
	__build_epilogue(false, ctx);
}
1360