riscv, bpf: Fix patch_text implicit declaration
[linux-block.git] / arch / riscv / net / bpf_jit_comp64.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include <asm/patch.h>
14 #include "bpf_jit.h"
15
16 #define RV_REG_TCC RV_REG_A6
17 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
18
19 static const int regmap[] = {
20         [BPF_REG_0] =   RV_REG_A5,
21         [BPF_REG_1] =   RV_REG_A0,
22         [BPF_REG_2] =   RV_REG_A1,
23         [BPF_REG_3] =   RV_REG_A2,
24         [BPF_REG_4] =   RV_REG_A3,
25         [BPF_REG_5] =   RV_REG_A4,
26         [BPF_REG_6] =   RV_REG_S1,
27         [BPF_REG_7] =   RV_REG_S2,
28         [BPF_REG_8] =   RV_REG_S3,
29         [BPF_REG_9] =   RV_REG_S4,
30         [BPF_REG_FP] =  RV_REG_S5,
31         [BPF_REG_AX] =  RV_REG_T0,
32 };
33
34 static const int pt_regmap[] = {
35         [RV_REG_A0] = offsetof(struct pt_regs, a0),
36         [RV_REG_A1] = offsetof(struct pt_regs, a1),
37         [RV_REG_A2] = offsetof(struct pt_regs, a2),
38         [RV_REG_A3] = offsetof(struct pt_regs, a3),
39         [RV_REG_A4] = offsetof(struct pt_regs, a4),
40         [RV_REG_A5] = offsetof(struct pt_regs, a5),
41         [RV_REG_S1] = offsetof(struct pt_regs, s1),
42         [RV_REG_S2] = offsetof(struct pt_regs, s2),
43         [RV_REG_S3] = offsetof(struct pt_regs, s3),
44         [RV_REG_S4] = offsetof(struct pt_regs, s4),
45         [RV_REG_S5] = offsetof(struct pt_regs, s5),
46         [RV_REG_T0] = offsetof(struct pt_regs, t0),
47 };
48
49 enum {
50         RV_CTX_F_SEEN_TAIL_CALL =       0,
51         RV_CTX_F_SEEN_CALL =            RV_REG_RA,
52         RV_CTX_F_SEEN_S1 =              RV_REG_S1,
53         RV_CTX_F_SEEN_S2 =              RV_REG_S2,
54         RV_CTX_F_SEEN_S3 =              RV_REG_S3,
55         RV_CTX_F_SEEN_S4 =              RV_REG_S4,
56         RV_CTX_F_SEEN_S5 =              RV_REG_S5,
57         RV_CTX_F_SEEN_S6 =              RV_REG_S6,
58 };
59
60 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
61 {
62         u8 reg = regmap[bpf_reg];
63
64         switch (reg) {
65         case RV_CTX_F_SEEN_S1:
66         case RV_CTX_F_SEEN_S2:
67         case RV_CTX_F_SEEN_S3:
68         case RV_CTX_F_SEEN_S4:
69         case RV_CTX_F_SEEN_S5:
70         case RV_CTX_F_SEEN_S6:
71                 __set_bit(reg, &ctx->flags);
72         }
73         return reg;
74 };
75
76 static bool seen_reg(int reg, struct rv_jit_context *ctx)
77 {
78         switch (reg) {
79         case RV_CTX_F_SEEN_CALL:
80         case RV_CTX_F_SEEN_S1:
81         case RV_CTX_F_SEEN_S2:
82         case RV_CTX_F_SEEN_S3:
83         case RV_CTX_F_SEEN_S4:
84         case RV_CTX_F_SEEN_S5:
85         case RV_CTX_F_SEEN_S6:
86                 return test_bit(reg, &ctx->flags);
87         }
88         return false;
89 }
90
91 static void mark_fp(struct rv_jit_context *ctx)
92 {
93         __set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
94 }
95
96 static void mark_call(struct rv_jit_context *ctx)
97 {
98         __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
99 }
100
101 static bool seen_call(struct rv_jit_context *ctx)
102 {
103         return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
104 }
105
106 static void mark_tail_call(struct rv_jit_context *ctx)
107 {
108         __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
109 }
110
111 static bool seen_tail_call(struct rv_jit_context *ctx)
112 {
113         return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
114 }
115
116 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
117 {
118         mark_tail_call(ctx);
119
120         if (seen_call(ctx)) {
121                 __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
122                 return RV_REG_S6;
123         }
124         return RV_REG_A6;
125 }
126
127 static bool is_32b_int(s64 val)
128 {
129         return -(1L << 31) <= val && val < (1L << 31);
130 }
131
132 static bool in_auipc_jalr_range(s64 val)
133 {
134         /*
135          * auipc+jalr can reach any signed PC-relative offset in the range
136          * [-2^31 - 2^11, 2^31 - 2^11).
137          */
138         return (-(1L << 31) - (1L << 11)) <= val &&
139                 val < ((1L << 31) - (1L << 11));
140 }
141
142 /* Emit fixed-length instructions for address */
143 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
144 {
145         u64 ip = (u64)(ctx->insns + ctx->ninsns);
146         s64 off = addr - ip;
147         s64 upper = (off + (1 << 11)) >> 12;
148         s64 lower = off & 0xfff;
149
150         if (extra_pass && !in_auipc_jalr_range(off)) {
151                 pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
152                 return -ERANGE;
153         }
154
155         emit(rv_auipc(rd, upper), ctx);
156         emit(rv_addi(rd, rd, lower), ctx);
157         return 0;
158 }
159
160 /* Emit variable-length instructions for 32-bit and 64-bit imm */
161 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
162 {
163         /* Note that the immediate from the add is sign-extended,
164          * which means that we need to compensate this by adding 2^12,
165          * when the 12th bit is set. A simpler way of doing this, and
166          * getting rid of the check, is to just add 2**11 before the
167          * shift. The "Loading a 32-Bit constant" example from the
168          * "Computer Organization and Design, RISC-V edition" book by
169          * Patterson/Hennessy highlights this fact.
170          *
171          * This also means that we need to process LSB to MSB.
172          */
173         s64 upper = (val + (1 << 11)) >> 12;
174         /* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
175          * and addi are signed and RVC checks will perform signed comparisons.
176          */
177         s64 lower = ((val & 0xfff) << 52) >> 52;
178         int shift;
179
180         if (is_32b_int(val)) {
181                 if (upper)
182                         emit_lui(rd, upper, ctx);
183
184                 if (!upper) {
185                         emit_li(rd, lower, ctx);
186                         return;
187                 }
188
189                 emit_addiw(rd, rd, lower, ctx);
190                 return;
191         }
192
193         shift = __ffs(upper);
194         upper >>= shift;
195         shift += 12;
196
197         emit_imm(rd, upper, ctx);
198
199         emit_slli(rd, rd, shift, ctx);
200         if (lower)
201                 emit_addi(rd, rd, lower, ctx);
202 }
203
204 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
205 {
206         int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
207
208         if (seen_reg(RV_REG_RA, ctx)) {
209                 emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
210                 store_offset -= 8;
211         }
212         emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
213         store_offset -= 8;
214         if (seen_reg(RV_REG_S1, ctx)) {
215                 emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
216                 store_offset -= 8;
217         }
218         if (seen_reg(RV_REG_S2, ctx)) {
219                 emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
220                 store_offset -= 8;
221         }
222         if (seen_reg(RV_REG_S3, ctx)) {
223                 emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
224                 store_offset -= 8;
225         }
226         if (seen_reg(RV_REG_S4, ctx)) {
227                 emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
228                 store_offset -= 8;
229         }
230         if (seen_reg(RV_REG_S5, ctx)) {
231                 emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
232                 store_offset -= 8;
233         }
234         if (seen_reg(RV_REG_S6, ctx)) {
235                 emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
236                 store_offset -= 8;
237         }
238
239         emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
240         /* Set return value. */
241         if (!is_tail_call)
242                 emit_mv(RV_REG_A0, RV_REG_A5, ctx);
243         emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
244                   is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
245                   ctx);
246 }
247
248 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
249                      struct rv_jit_context *ctx)
250 {
251         switch (cond) {
252         case BPF_JEQ:
253                 emit(rv_beq(rd, rs, rvoff >> 1), ctx);
254                 return;
255         case BPF_JGT:
256                 emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
257                 return;
258         case BPF_JLT:
259                 emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
260                 return;
261         case BPF_JGE:
262                 emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
263                 return;
264         case BPF_JLE:
265                 emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
266                 return;
267         case BPF_JNE:
268                 emit(rv_bne(rd, rs, rvoff >> 1), ctx);
269                 return;
270         case BPF_JSGT:
271                 emit(rv_blt(rs, rd, rvoff >> 1), ctx);
272                 return;
273         case BPF_JSLT:
274                 emit(rv_blt(rd, rs, rvoff >> 1), ctx);
275                 return;
276         case BPF_JSGE:
277                 emit(rv_bge(rd, rs, rvoff >> 1), ctx);
278                 return;
279         case BPF_JSLE:
280                 emit(rv_bge(rs, rd, rvoff >> 1), ctx);
281         }
282 }
283
284 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
285                         struct rv_jit_context *ctx)
286 {
287         s64 upper, lower;
288
289         if (is_13b_int(rvoff)) {
290                 emit_bcc(cond, rd, rs, rvoff, ctx);
291                 return;
292         }
293
294         /* Adjust for jal */
295         rvoff -= 4;
296
297         /* Transform, e.g.:
298          *   bne rd,rs,foo
299          * to
300          *   beq rd,rs,<.L1>
301          *   (auipc foo)
302          *   jal(r) foo
303          * .L1
304          */
305         cond = invert_bpf_cond(cond);
306         if (is_21b_int(rvoff)) {
307                 emit_bcc(cond, rd, rs, 8, ctx);
308                 emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
309                 return;
310         }
311
312         /* 32b No need for an additional rvoff adjustment, since we
313          * get that from the auipc at PC', where PC = PC' + 4.
314          */
315         upper = (rvoff + (1 << 11)) >> 12;
316         lower = rvoff & 0xfff;
317
318         emit_bcc(cond, rd, rs, 12, ctx);
319         emit(rv_auipc(RV_REG_T1, upper), ctx);
320         emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
321 }
322
323 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
324 {
325         emit_slli(reg, reg, 32, ctx);
326         emit_srli(reg, reg, 32, ctx);
327 }
328
329 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
330 {
331         int tc_ninsn, off, start_insn = ctx->ninsns;
332         u8 tcc = rv_tail_call_reg(ctx);
333
334         /* a0: &ctx
335          * a1: &array
336          * a2: index
337          *
338          * if (index >= array->map.max_entries)
339          *      goto out;
340          */
341         tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
342                    ctx->offset[0];
343         emit_zext_32(RV_REG_A2, ctx);
344
345         off = offsetof(struct bpf_array, map.max_entries);
346         if (is_12b_check(off, insn))
347                 return -1;
348         emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
349         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
350         emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
351
352         /* if (--TCC < 0)
353          *     goto out;
354          */
355         emit_addi(RV_REG_TCC, tcc, -1, ctx);
356         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
357         emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
358
359         /* prog = array->ptrs[index];
360          * if (!prog)
361          *     goto out;
362          */
363         emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
364         emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
365         off = offsetof(struct bpf_array, ptrs);
366         if (is_12b_check(off, insn))
367                 return -1;
368         emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
369         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
370         emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
371
372         /* goto *(prog->bpf_func + 4); */
373         off = offsetof(struct bpf_prog, bpf_func);
374         if (is_12b_check(off, insn))
375                 return -1;
376         emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
377         __build_epilogue(true, ctx);
378         return 0;
379 }
380
381 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
382                       struct rv_jit_context *ctx)
383 {
384         u8 code = insn->code;
385
386         switch (code) {
387         case BPF_JMP | BPF_JA:
388         case BPF_JMP | BPF_CALL:
389         case BPF_JMP | BPF_EXIT:
390         case BPF_JMP | BPF_TAIL_CALL:
391                 break;
392         default:
393                 *rd = bpf_to_rv_reg(insn->dst_reg, ctx);
394         }
395
396         if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
397             code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
398             code & BPF_LDX || code & BPF_STX)
399                 *rs = bpf_to_rv_reg(insn->src_reg, ctx);
400 }
401
402 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
403 {
404         emit_mv(RV_REG_T2, *rd, ctx);
405         emit_zext_32(RV_REG_T2, ctx);
406         emit_mv(RV_REG_T1, *rs, ctx);
407         emit_zext_32(RV_REG_T1, ctx);
408         *rd = RV_REG_T2;
409         *rs = RV_REG_T1;
410 }
411
412 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
413 {
414         emit_addiw(RV_REG_T2, *rd, 0, ctx);
415         emit_addiw(RV_REG_T1, *rs, 0, ctx);
416         *rd = RV_REG_T2;
417         *rs = RV_REG_T1;
418 }
419
420 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
421 {
422         emit_mv(RV_REG_T2, *rd, ctx);
423         emit_zext_32(RV_REG_T2, ctx);
424         emit_zext_32(RV_REG_T1, ctx);
425         *rd = RV_REG_T2;
426 }
427
428 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
429 {
430         emit_addiw(RV_REG_T2, *rd, 0, ctx);
431         *rd = RV_REG_T2;
432 }
433
434 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
435                               struct rv_jit_context *ctx)
436 {
437         s64 upper, lower;
438
439         if (rvoff && fixed_addr && is_21b_int(rvoff)) {
440                 emit(rv_jal(rd, rvoff >> 1), ctx);
441                 return 0;
442         } else if (in_auipc_jalr_range(rvoff)) {
443                 upper = (rvoff + (1 << 11)) >> 12;
444                 lower = rvoff & 0xfff;
445                 emit(rv_auipc(RV_REG_T1, upper), ctx);
446                 emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
447                 return 0;
448         }
449
450         pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
451         return -ERANGE;
452 }
453
454 static bool is_signed_bpf_cond(u8 cond)
455 {
456         return cond == BPF_JSGT || cond == BPF_JSLT ||
457                 cond == BPF_JSGE || cond == BPF_JSLE;
458 }
459
460 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
461 {
462         s64 off = 0;
463         u64 ip;
464
465         if (addr && ctx->insns) {
466                 ip = (u64)(long)(ctx->insns + ctx->ninsns);
467                 off = addr - ip;
468         }
469
470         return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
471 }
472
473 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
474                         struct rv_jit_context *ctx)
475 {
476         u8 r0;
477         int jmp_offset;
478
479         if (off) {
480                 if (is_12b_int(off)) {
481                         emit_addi(RV_REG_T1, rd, off, ctx);
482                 } else {
483                         emit_imm(RV_REG_T1, off, ctx);
484                         emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
485                 }
486                 rd = RV_REG_T1;
487         }
488
489         switch (imm) {
490         /* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
491         case BPF_ADD:
492                 emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
493                      rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
494                 break;
495         case BPF_AND:
496                 emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
497                      rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
498                 break;
499         case BPF_OR:
500                 emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
501                      rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
502                 break;
503         case BPF_XOR:
504                 emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
505                      rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
506                 break;
507         /* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
508         case BPF_ADD | BPF_FETCH:
509                 emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
510                      rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
511                 if (!is64)
512                         emit_zext_32(rs, ctx);
513                 break;
514         case BPF_AND | BPF_FETCH:
515                 emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
516                      rv_amoand_w(rs, rs, rd, 0, 0), ctx);
517                 if (!is64)
518                         emit_zext_32(rs, ctx);
519                 break;
520         case BPF_OR | BPF_FETCH:
521                 emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
522                      rv_amoor_w(rs, rs, rd, 0, 0), ctx);
523                 if (!is64)
524                         emit_zext_32(rs, ctx);
525                 break;
526         case BPF_XOR | BPF_FETCH:
527                 emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
528                      rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
529                 if (!is64)
530                         emit_zext_32(rs, ctx);
531                 break;
532         /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
533         case BPF_XCHG:
534                 emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
535                      rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
536                 if (!is64)
537                         emit_zext_32(rs, ctx);
538                 break;
539         /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
540         case BPF_CMPXCHG:
541                 r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
542                 emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
543                      rv_addiw(RV_REG_T2, r0, 0), ctx);
544                 emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
545                      rv_lr_w(r0, 0, rd, 0, 0), ctx);
546                 jmp_offset = ninsns_rvoff(8);
547                 emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
548                 emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
549                      rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
550                 jmp_offset = ninsns_rvoff(-6);
551                 emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
552                 emit(rv_fence(0x3, 0x3), ctx);
553                 break;
554         }
555 }
556
557 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
558 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
559
560 bool ex_handler_bpf(const struct exception_table_entry *ex,
561                     struct pt_regs *regs)
562 {
563         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
564         int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
565
566         *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
567         regs->epc = (unsigned long)&ex->fixup - offset;
568
569         return true;
570 }
571
572 /* For accesses to BTF pointers, add an entry to the exception table */
573 static int add_exception_handler(const struct bpf_insn *insn,
574                                  struct rv_jit_context *ctx,
575                                  int dst_reg, int insn_len)
576 {
577         struct exception_table_entry *ex;
578         unsigned long pc;
579         off_t offset;
580
581         if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
582                 return 0;
583
584         if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
585                 return -EINVAL;
586
587         if (WARN_ON_ONCE(insn_len > ctx->ninsns))
588                 return -EINVAL;
589
590         if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
591                 return -EINVAL;
592
593         ex = &ctx->prog->aux->extable[ctx->nexentries];
594         pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
595
596         offset = pc - (long)&ex->insn;
597         if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
598                 return -ERANGE;
599         ex->insn = offset;
600
601         /*
602          * Since the extable follows the program, the fixup offset is always
603          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
604          * to keep things simple, and put the destination register in the upper
605          * bits. We don't need to worry about buildtime or runtime sort
606          * modifying the upper bits because the table is already sorted, and
607          * isn't part of the main exception table.
608          */
609         offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
610         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
611                 return -ERANGE;
612
613         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
614                 FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
615         ex->type = EX_TYPE_BPF;
616
617         ctx->nexentries++;
618         return 0;
619 }
620
621 static int gen_call_or_nops(void *target, void *ip, u32 *insns)
622 {
623         s64 rvoff;
624         int i, ret;
625         struct rv_jit_context ctx;
626
627         ctx.ninsns = 0;
628         ctx.insns = (u16 *)insns;
629
630         if (!target) {
631                 for (i = 0; i < 4; i++)
632                         emit(rv_nop(), &ctx);
633                 return 0;
634         }
635
636         rvoff = (s64)(target - (ip + 4));
637         emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
638         ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
639         if (ret)
640                 return ret;
641         emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);
642
643         return 0;
644 }
645
646 static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
647 {
648         s64 rvoff;
649         struct rv_jit_context ctx;
650
651         ctx.ninsns = 0;
652         ctx.insns = (u16 *)insns;
653
654         if (!target) {
655                 emit(rv_nop(), &ctx);
656                 emit(rv_nop(), &ctx);
657                 return 0;
658         }
659
660         rvoff = (s64)(target - ip);
661         return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
662 }
663
664 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
665                        void *old_addr, void *new_addr)
666 {
667         u32 old_insns[4], new_insns[4];
668         bool is_call = poke_type == BPF_MOD_CALL;
669         int (*gen_insns)(void *target, void *ip, u32 *insns);
670         int ninsns = is_call ? 4 : 2;
671         int ret;
672
673         if (!is_bpf_text_address((unsigned long)ip))
674                 return -ENOTSUPP;
675
676         gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;
677
678         ret = gen_insns(old_addr, ip, old_insns);
679         if (ret)
680                 return ret;
681
682         if (memcmp(ip, old_insns, ninsns * 4))
683                 return -EFAULT;
684
685         ret = gen_insns(new_addr, ip, new_insns);
686         if (ret)
687                 return ret;
688
689         cpus_read_lock();
690         mutex_lock(&text_mutex);
691         if (memcmp(ip, new_insns, ninsns * 4))
692                 ret = patch_text(ip, new_insns, ninsns);
693         mutex_unlock(&text_mutex);
694         cpus_read_unlock();
695
696         return ret;
697 }
698
699 static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
700 {
701         int i;
702
703         for (i = 0; i < nregs; i++) {
704                 emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
705                 args_off -= 8;
706         }
707 }
708
709 static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
710 {
711         int i;
712
713         for (i = 0; i < nregs; i++) {
714                 emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
715                 args_off -= 8;
716         }
717 }
718
719 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
720                            int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
721 {
722         int ret, branch_off;
723         struct bpf_prog *p = l->link.prog;
724         int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
725
726         if (l->cookie) {
727                 emit_imm(RV_REG_T1, l->cookie, ctx);
728                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
729         } else {
730                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
731         }
732
733         /* arg1: prog */
734         emit_imm(RV_REG_A0, (const s64)p, ctx);
735         /* arg2: &run_ctx */
736         emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
737         ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
738         if (ret)
739                 return ret;
740
741         /* if (__bpf_prog_enter(prog) == 0)
742          *      goto skip_exec_of_prog;
743          */
744         branch_off = ctx->ninsns;
745         /* nop reserved for conditional jump */
746         emit(rv_nop(), ctx);
747
748         /* store prog start time */
749         emit_mv(RV_REG_S1, RV_REG_A0, ctx);
750
751         /* arg1: &args_off */
752         emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
753         if (!p->jited)
754                 /* arg2: progs[i]->insnsi for interpreter */
755                 emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
756         ret = emit_call((const u64)p->bpf_func, true, ctx);
757         if (ret)
758                 return ret;
759
760         if (save_ret)
761                 emit_sd(RV_REG_FP, -retval_off, regmap[BPF_REG_0], ctx);
762
763         /* update branch with beqz */
764         if (ctx->insns) {
765                 int offset = ninsns_rvoff(ctx->ninsns - branch_off);
766                 u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
767                 *(u32 *)(ctx->insns + branch_off) = insn;
768         }
769
770         /* arg1: prog */
771         emit_imm(RV_REG_A0, (const s64)p, ctx);
772         /* arg2: prog start time */
773         emit_mv(RV_REG_A1, RV_REG_S1, ctx);
774         /* arg3: &run_ctx */
775         emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
776         ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
777
778         return ret;
779 }
780
781 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
782                                          const struct btf_func_model *m,
783                                          struct bpf_tramp_links *tlinks,
784                                          void *func_addr, u32 flags,
785                                          struct rv_jit_context *ctx)
786 {
787         int i, ret, offset;
788         int *branches_off = NULL;
789         int stack_size = 0, nregs = m->nr_args;
790         int retaddr_off, fp_off, retval_off, args_off;
791         int nregs_off, ip_off, run_ctx_off, sreg_off;
792         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
793         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
794         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
795         void *orig_call = func_addr;
796         bool save_ret;
797         u32 insn;
798
799         /* Generated trampoline stack layout:
800          *
801          * FP - 8           [ RA of parent func ] return address of parent
802          *                                        function
803          * FP - retaddr_off [ RA of traced func ] return address of traced
804          *                                        function
805          * FP - fp_off      [ FP of parent func ]
806          *
807          * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
808          *                                        BPF_TRAMP_F_RET_FENTRY_RET
809          *                  [ argN              ]
810          *                  [ ...               ]
811          * FP - args_off    [ arg1              ]
812          *
813          * FP - nregs_off   [ regs count        ]
814          *
815          * FP - ip_off      [ traced func       ] BPF_TRAMP_F_IP_ARG
816          *
817          * FP - run_ctx_off [ bpf_tramp_run_ctx ]
818          *
819          * FP - sreg_off    [ callee saved reg  ]
820          *
821          *                  [ pads              ] pads for 16 bytes alignment
822          */
823
824         if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
825                 return -ENOTSUPP;
826
827         /* extra regiters for struct arguments */
828         for (i = 0; i < m->nr_args; i++)
829                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
830                         nregs += round_up(m->arg_size[i], 8) / 8 - 1;
831
832         /* 8 arguments passed by registers */
833         if (nregs > 8)
834                 return -ENOTSUPP;
835
836         /* room for parent function return address */
837         stack_size += 8;
838
839         stack_size += 8;
840         retaddr_off = stack_size;
841
842         stack_size += 8;
843         fp_off = stack_size;
844
845         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
846         if (save_ret) {
847                 stack_size += 8;
848                 retval_off = stack_size;
849         }
850
851         stack_size += nregs * 8;
852         args_off = stack_size;
853
854         stack_size += 8;
855         nregs_off = stack_size;
856
857         if (flags & BPF_TRAMP_F_IP_ARG) {
858                 stack_size += 8;
859                 ip_off = stack_size;
860         }
861
862         stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
863         run_ctx_off = stack_size;
864
865         stack_size += 8;
866         sreg_off = stack_size;
867
868         stack_size = round_up(stack_size, 16);
869
870         emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
871
872         emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
873         emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);
874
875         emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
876
877         /* callee saved register S1 to pass start time */
878         emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
879
880         /* store ip address of the traced function */
881         if (flags & BPF_TRAMP_F_IP_ARG) {
882                 emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
883                 emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
884         }
885
886         emit_li(RV_REG_T1, nregs, ctx);
887         emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
888
889         store_args(nregs, args_off, ctx);
890
891         /* skip to actual body of traced function */
892         if (flags & BPF_TRAMP_F_SKIP_FRAME)
893                 orig_call += 16;
894
895         if (flags & BPF_TRAMP_F_CALL_ORIG) {
896                 emit_imm(RV_REG_A0, (const s64)im, ctx);
897                 ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
898                 if (ret)
899                         return ret;
900         }
901
902         for (i = 0; i < fentry->nr_links; i++) {
903                 ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
904                                       flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
905                 if (ret)
906                         return ret;
907         }
908
909         if (fmod_ret->nr_links) {
910                 branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
911                 if (!branches_off)
912                         return -ENOMEM;
913
914                 /* cleanup to avoid garbage return value confusion */
915                 emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
916                 for (i = 0; i < fmod_ret->nr_links; i++) {
917                         ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
918                                               run_ctx_off, true, ctx);
919                         if (ret)
920                                 goto out;
921                         emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
922                         branches_off[i] = ctx->ninsns;
923                         /* nop reserved for conditional jump */
924                         emit(rv_nop(), ctx);
925                 }
926         }
927
928         if (flags & BPF_TRAMP_F_CALL_ORIG) {
929                 restore_args(nregs, args_off, ctx);
930                 ret = emit_call((const u64)orig_call, true, ctx);
931                 if (ret)
932                         goto out;
933                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
934                 im->ip_after_call = ctx->insns + ctx->ninsns;
935                 /* 2 nops reserved for auipc+jalr pair */
936                 emit(rv_nop(), ctx);
937                 emit(rv_nop(), ctx);
938         }
939
940         /* update branches saved in invoke_bpf_mod_ret with bnez */
941         for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
942                 offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
943                 insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
944                 *(u32 *)(ctx->insns + branches_off[i]) = insn;
945         }
946
947         for (i = 0; i < fexit->nr_links; i++) {
948                 ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
949                                       run_ctx_off, false, ctx);
950                 if (ret)
951                         goto out;
952         }
953
954         if (flags & BPF_TRAMP_F_CALL_ORIG) {
955                 im->ip_epilogue = ctx->insns + ctx->ninsns;
956                 emit_imm(RV_REG_A0, (const s64)im, ctx);
957                 ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
958                 if (ret)
959                         goto out;
960         }
961
962         if (flags & BPF_TRAMP_F_RESTORE_REGS)
963                 restore_args(nregs, args_off, ctx);
964
965         if (save_ret)
966                 emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
967
968         emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
969
970         if (flags & BPF_TRAMP_F_SKIP_FRAME)
971                 /* return address of parent function */
972                 emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
973         else
974                 /* return address of traced function */
975                 emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
976
977         emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
978         emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
979
980         emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
981
982         ret = ctx->ninsns;
983 out:
984         kfree(branches_off);
985         return ret;
986 }
987
988 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
989                                 void *image_end, const struct btf_func_model *m,
990                                 u32 flags, struct bpf_tramp_links *tlinks,
991                                 void *func_addr)
992 {
993         int ret;
994         struct rv_jit_context ctx;
995
996         ctx.ninsns = 0;
997         ctx.insns = NULL;
998         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
999         if (ret < 0)
1000                 return ret;
1001
1002         if (ninsns_rvoff(ret) > (long)image_end - (long)image)
1003                 return -EFBIG;
1004
1005         ctx.ninsns = 0;
1006         ctx.insns = image;
1007         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1008         if (ret < 0)
1009                 return ret;
1010
1011         bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1012
1013         return ninsns_rvoff(ret);
1014 }
1015
1016 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1017                       bool extra_pass)
1018 {
1019         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1020                     BPF_CLASS(insn->code) == BPF_JMP;
1021         int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1022         struct bpf_prog_aux *aux = ctx->prog->aux;
1023         u8 rd = -1, rs = -1, code = insn->code;
1024         s16 off = insn->off;
1025         s32 imm = insn->imm;
1026
1027         init_regs(&rd, &rs, insn, ctx);
1028
1029         switch (code) {
1030         /* dst = src */
1031         case BPF_ALU | BPF_MOV | BPF_X:
1032         case BPF_ALU64 | BPF_MOV | BPF_X:
1033                 if (imm == 1) {
1034                         /* Special mov32 for zext */
1035                         emit_zext_32(rd, ctx);
1036                         break;
1037                 }
1038                 emit_mv(rd, rs, ctx);
1039                 if (!is64 && !aux->verifier_zext)
1040                         emit_zext_32(rd, ctx);
1041                 break;
1042
1043         /* dst = dst OP src */
1044         case BPF_ALU | BPF_ADD | BPF_X:
1045         case BPF_ALU64 | BPF_ADD | BPF_X:
1046                 emit_add(rd, rd, rs, ctx);
1047                 if (!is64 && !aux->verifier_zext)
1048                         emit_zext_32(rd, ctx);
1049                 break;
1050         case BPF_ALU | BPF_SUB | BPF_X:
1051         case BPF_ALU64 | BPF_SUB | BPF_X:
1052                 if (is64)
1053                         emit_sub(rd, rd, rs, ctx);
1054                 else
1055                         emit_subw(rd, rd, rs, ctx);
1056
1057                 if (!is64 && !aux->verifier_zext)
1058                         emit_zext_32(rd, ctx);
1059                 break;
1060         case BPF_ALU | BPF_AND | BPF_X:
1061         case BPF_ALU64 | BPF_AND | BPF_X:
1062                 emit_and(rd, rd, rs, ctx);
1063                 if (!is64 && !aux->verifier_zext)
1064                         emit_zext_32(rd, ctx);
1065                 break;
1066         case BPF_ALU | BPF_OR | BPF_X:
1067         case BPF_ALU64 | BPF_OR | BPF_X:
1068                 emit_or(rd, rd, rs, ctx);
1069                 if (!is64 && !aux->verifier_zext)
1070                         emit_zext_32(rd, ctx);
1071                 break;
1072         case BPF_ALU | BPF_XOR | BPF_X:
1073         case BPF_ALU64 | BPF_XOR | BPF_X:
1074                 emit_xor(rd, rd, rs, ctx);
1075                 if (!is64 && !aux->verifier_zext)
1076                         emit_zext_32(rd, ctx);
1077                 break;
1078         case BPF_ALU | BPF_MUL | BPF_X:
1079         case BPF_ALU64 | BPF_MUL | BPF_X:
1080                 emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1081                 if (!is64 && !aux->verifier_zext)
1082                         emit_zext_32(rd, ctx);
1083                 break;
1084         case BPF_ALU | BPF_DIV | BPF_X:
1085         case BPF_ALU64 | BPF_DIV | BPF_X:
1086                 emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1087                 if (!is64 && !aux->verifier_zext)
1088                         emit_zext_32(rd, ctx);
1089                 break;
1090         case BPF_ALU | BPF_MOD | BPF_X:
1091         case BPF_ALU64 | BPF_MOD | BPF_X:
1092                 emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1093                 if (!is64 && !aux->verifier_zext)
1094                         emit_zext_32(rd, ctx);
1095                 break;
1096         case BPF_ALU | BPF_LSH | BPF_X:
1097         case BPF_ALU64 | BPF_LSH | BPF_X:
1098                 emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1099                 if (!is64 && !aux->verifier_zext)
1100                         emit_zext_32(rd, ctx);
1101                 break;
1102         case BPF_ALU | BPF_RSH | BPF_X:
1103         case BPF_ALU64 | BPF_RSH | BPF_X:
1104                 emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1105                 if (!is64 && !aux->verifier_zext)
1106                         emit_zext_32(rd, ctx);
1107                 break;
1108         case BPF_ALU | BPF_ARSH | BPF_X:
1109         case BPF_ALU64 | BPF_ARSH | BPF_X:
1110                 emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1111                 if (!is64 && !aux->verifier_zext)
1112                         emit_zext_32(rd, ctx);
1113                 break;
1114
1115         /* dst = -dst */
1116         case BPF_ALU | BPF_NEG:
1117         case BPF_ALU64 | BPF_NEG:
1118                 emit_sub(rd, RV_REG_ZERO, rd, ctx);
1119                 if (!is64 && !aux->verifier_zext)
1120                         emit_zext_32(rd, ctx);
1121                 break;
1122
1123         /* dst = BSWAP##imm(dst) */
1124         case BPF_ALU | BPF_END | BPF_FROM_LE:
1125                 switch (imm) {
1126                 case 16:
1127                         emit_slli(rd, rd, 48, ctx);
1128                         emit_srli(rd, rd, 48, ctx);
1129                         break;
1130                 case 32:
1131                         if (!aux->verifier_zext)
1132                                 emit_zext_32(rd, ctx);
1133                         break;
1134                 case 64:
1135                         /* Do nothing */
1136                         break;
1137                 }
1138                 break;
1139
1140         case BPF_ALU | BPF_END | BPF_FROM_BE:
1141                 emit_li(RV_REG_T2, 0, ctx);
1142
1143                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1144                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1145                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1146                 emit_srli(rd, rd, 8, ctx);
1147                 if (imm == 16)
1148                         goto out_be;
1149
1150                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1151                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1152                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1153                 emit_srli(rd, rd, 8, ctx);
1154
1155                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1156                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1157                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1158                 emit_srli(rd, rd, 8, ctx);
1159                 if (imm == 32)
1160                         goto out_be;
1161
1162                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1163                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1164                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1165                 emit_srli(rd, rd, 8, ctx);
1166
1167                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1168                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1169                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1170                 emit_srli(rd, rd, 8, ctx);
1171
1172                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1173                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1174                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1175                 emit_srli(rd, rd, 8, ctx);
1176
1177                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1178                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1179                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1180                 emit_srli(rd, rd, 8, ctx);
1181 out_be:
1182                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1183                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1184
1185                 emit_mv(rd, RV_REG_T2, ctx);
1186                 break;
1187
1188         /* dst = imm */
1189         case BPF_ALU | BPF_MOV | BPF_K:
1190         case BPF_ALU64 | BPF_MOV | BPF_K:
1191                 emit_imm(rd, imm, ctx);
1192                 if (!is64 && !aux->verifier_zext)
1193                         emit_zext_32(rd, ctx);
1194                 break;
1195
1196         /* dst = dst OP imm */
1197         case BPF_ALU | BPF_ADD | BPF_K:
1198         case BPF_ALU64 | BPF_ADD | BPF_K:
1199                 if (is_12b_int(imm)) {
1200                         emit_addi(rd, rd, imm, ctx);
1201                 } else {
1202                         emit_imm(RV_REG_T1, imm, ctx);
1203                         emit_add(rd, rd, RV_REG_T1, ctx);
1204                 }
1205                 if (!is64 && !aux->verifier_zext)
1206                         emit_zext_32(rd, ctx);
1207                 break;
1208         case BPF_ALU | BPF_SUB | BPF_K:
1209         case BPF_ALU64 | BPF_SUB | BPF_K:
1210                 if (is_12b_int(-imm)) {
1211                         emit_addi(rd, rd, -imm, ctx);
1212                 } else {
1213                         emit_imm(RV_REG_T1, imm, ctx);
1214                         emit_sub(rd, rd, RV_REG_T1, ctx);
1215                 }
1216                 if (!is64 && !aux->verifier_zext)
1217                         emit_zext_32(rd, ctx);
1218                 break;
1219         case BPF_ALU | BPF_AND | BPF_K:
1220         case BPF_ALU64 | BPF_AND | BPF_K:
1221                 if (is_12b_int(imm)) {
1222                         emit_andi(rd, rd, imm, ctx);
1223                 } else {
1224                         emit_imm(RV_REG_T1, imm, ctx);
1225                         emit_and(rd, rd, RV_REG_T1, ctx);
1226                 }
1227                 if (!is64 && !aux->verifier_zext)
1228                         emit_zext_32(rd, ctx);
1229                 break;
1230         case BPF_ALU | BPF_OR | BPF_K:
1231         case BPF_ALU64 | BPF_OR | BPF_K:
1232                 if (is_12b_int(imm)) {
1233                         emit(rv_ori(rd, rd, imm), ctx);
1234                 } else {
1235                         emit_imm(RV_REG_T1, imm, ctx);
1236                         emit_or(rd, rd, RV_REG_T1, ctx);
1237                 }
1238                 if (!is64 && !aux->verifier_zext)
1239                         emit_zext_32(rd, ctx);
1240                 break;
1241         case BPF_ALU | BPF_XOR | BPF_K:
1242         case BPF_ALU64 | BPF_XOR | BPF_K:
1243                 if (is_12b_int(imm)) {
1244                         emit(rv_xori(rd, rd, imm), ctx);
1245                 } else {
1246                         emit_imm(RV_REG_T1, imm, ctx);
1247                         emit_xor(rd, rd, RV_REG_T1, ctx);
1248                 }
1249                 if (!is64 && !aux->verifier_zext)
1250                         emit_zext_32(rd, ctx);
1251                 break;
1252         case BPF_ALU | BPF_MUL | BPF_K:
1253         case BPF_ALU64 | BPF_MUL | BPF_K:
1254                 emit_imm(RV_REG_T1, imm, ctx);
1255                 emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1256                      rv_mulw(rd, rd, RV_REG_T1), ctx);
1257                 if (!is64 && !aux->verifier_zext)
1258                         emit_zext_32(rd, ctx);
1259                 break;
1260         case BPF_ALU | BPF_DIV | BPF_K:
1261         case BPF_ALU64 | BPF_DIV | BPF_K:
1262                 emit_imm(RV_REG_T1, imm, ctx);
1263                 emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1264                      rv_divuw(rd, rd, RV_REG_T1), ctx);
1265                 if (!is64 && !aux->verifier_zext)
1266                         emit_zext_32(rd, ctx);
1267                 break;
1268         case BPF_ALU | BPF_MOD | BPF_K:
1269         case BPF_ALU64 | BPF_MOD | BPF_K:
1270                 emit_imm(RV_REG_T1, imm, ctx);
1271                 emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1272                      rv_remuw(rd, rd, RV_REG_T1), ctx);
1273                 if (!is64 && !aux->verifier_zext)
1274                         emit_zext_32(rd, ctx);
1275                 break;
1276         case BPF_ALU | BPF_LSH | BPF_K:
1277         case BPF_ALU64 | BPF_LSH | BPF_K:
1278                 emit_slli(rd, rd, imm, ctx);
1279
1280                 if (!is64 && !aux->verifier_zext)
1281                         emit_zext_32(rd, ctx);
1282                 break;
1283         case BPF_ALU | BPF_RSH | BPF_K:
1284         case BPF_ALU64 | BPF_RSH | BPF_K:
1285                 if (is64)
1286                         emit_srli(rd, rd, imm, ctx);
1287                 else
1288                         emit(rv_srliw(rd, rd, imm), ctx);
1289
1290                 if (!is64 && !aux->verifier_zext)
1291                         emit_zext_32(rd, ctx);
1292                 break;
1293         case BPF_ALU | BPF_ARSH | BPF_K:
1294         case BPF_ALU64 | BPF_ARSH | BPF_K:
1295                 if (is64)
1296                         emit_srai(rd, rd, imm, ctx);
1297                 else
1298                         emit(rv_sraiw(rd, rd, imm), ctx);
1299
1300                 if (!is64 && !aux->verifier_zext)
1301                         emit_zext_32(rd, ctx);
1302                 break;
1303
1304         /* JUMP off */
1305         case BPF_JMP | BPF_JA:
1306                 rvoff = rv_offset(i, off, ctx);
1307                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1308                 if (ret)
1309                         return ret;
1310                 break;
1311
1312         /* IF (dst COND src) JUMP off */
1313         case BPF_JMP | BPF_JEQ | BPF_X:
1314         case BPF_JMP32 | BPF_JEQ | BPF_X:
1315         case BPF_JMP | BPF_JGT | BPF_X:
1316         case BPF_JMP32 | BPF_JGT | BPF_X:
1317         case BPF_JMP | BPF_JLT | BPF_X:
1318         case BPF_JMP32 | BPF_JLT | BPF_X:
1319         case BPF_JMP | BPF_JGE | BPF_X:
1320         case BPF_JMP32 | BPF_JGE | BPF_X:
1321         case BPF_JMP | BPF_JLE | BPF_X:
1322         case BPF_JMP32 | BPF_JLE | BPF_X:
1323         case BPF_JMP | BPF_JNE | BPF_X:
1324         case BPF_JMP32 | BPF_JNE | BPF_X:
1325         case BPF_JMP | BPF_JSGT | BPF_X:
1326         case BPF_JMP32 | BPF_JSGT | BPF_X:
1327         case BPF_JMP | BPF_JSLT | BPF_X:
1328         case BPF_JMP32 | BPF_JSLT | BPF_X:
1329         case BPF_JMP | BPF_JSGE | BPF_X:
1330         case BPF_JMP32 | BPF_JSGE | BPF_X:
1331         case BPF_JMP | BPF_JSLE | BPF_X:
1332         case BPF_JMP32 | BPF_JSLE | BPF_X:
1333         case BPF_JMP | BPF_JSET | BPF_X:
1334         case BPF_JMP32 | BPF_JSET | BPF_X:
1335                 rvoff = rv_offset(i, off, ctx);
1336                 if (!is64) {
1337                         s = ctx->ninsns;
1338                         if (is_signed_bpf_cond(BPF_OP(code)))
1339                                 emit_sext_32_rd_rs(&rd, &rs, ctx);
1340                         else
1341                                 emit_zext_32_rd_rs(&rd, &rs, ctx);
1342                         e = ctx->ninsns;
1343
1344                         /* Adjust for extra insns */
1345                         rvoff -= ninsns_rvoff(e - s);
1346                 }
1347
1348                 if (BPF_OP(code) == BPF_JSET) {
1349                         /* Adjust for and */
1350                         rvoff -= 4;
1351                         emit_and(RV_REG_T1, rd, rs, ctx);
1352                         emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
1353                                     ctx);
1354                 } else {
1355                         emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1356                 }
1357                 break;
1358
1359         /* IF (dst COND imm) JUMP off */
1360         case BPF_JMP | BPF_JEQ | BPF_K:
1361         case BPF_JMP32 | BPF_JEQ | BPF_K:
1362         case BPF_JMP | BPF_JGT | BPF_K:
1363         case BPF_JMP32 | BPF_JGT | BPF_K:
1364         case BPF_JMP | BPF_JLT | BPF_K:
1365         case BPF_JMP32 | BPF_JLT | BPF_K:
1366         case BPF_JMP | BPF_JGE | BPF_K:
1367         case BPF_JMP32 | BPF_JGE | BPF_K:
1368         case BPF_JMP | BPF_JLE | BPF_K:
1369         case BPF_JMP32 | BPF_JLE | BPF_K:
1370         case BPF_JMP | BPF_JNE | BPF_K:
1371         case BPF_JMP32 | BPF_JNE | BPF_K:
1372         case BPF_JMP | BPF_JSGT | BPF_K:
1373         case BPF_JMP32 | BPF_JSGT | BPF_K:
1374         case BPF_JMP | BPF_JSLT | BPF_K:
1375         case BPF_JMP32 | BPF_JSLT | BPF_K:
1376         case BPF_JMP | BPF_JSGE | BPF_K:
1377         case BPF_JMP32 | BPF_JSGE | BPF_K:
1378         case BPF_JMP | BPF_JSLE | BPF_K:
1379         case BPF_JMP32 | BPF_JSLE | BPF_K:
1380                 rvoff = rv_offset(i, off, ctx);
1381                 s = ctx->ninsns;
1382                 if (imm) {
1383                         emit_imm(RV_REG_T1, imm, ctx);
1384                         rs = RV_REG_T1;
1385                 } else {
1386                         /* If imm is 0, simply use zero register. */
1387                         rs = RV_REG_ZERO;
1388                 }
1389                 if (!is64) {
1390                         if (is_signed_bpf_cond(BPF_OP(code)))
1391                                 emit_sext_32_rd(&rd, ctx);
1392                         else
1393                                 emit_zext_32_rd_t1(&rd, ctx);
1394                 }
1395                 e = ctx->ninsns;
1396
1397                 /* Adjust for extra insns */
1398                 rvoff -= ninsns_rvoff(e - s);
1399                 emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1400                 break;
1401
1402         case BPF_JMP | BPF_JSET | BPF_K:
1403         case BPF_JMP32 | BPF_JSET | BPF_K:
1404                 rvoff = rv_offset(i, off, ctx);
1405                 s = ctx->ninsns;
1406                 if (is_12b_int(imm)) {
1407                         emit_andi(RV_REG_T1, rd, imm, ctx);
1408                 } else {
1409                         emit_imm(RV_REG_T1, imm, ctx);
1410                         emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1411                 }
1412                 /* For jset32, we should clear the upper 32 bits of t1, but
1413                  * sign-extension is sufficient here and saves one instruction,
1414                  * as t1 is used only in comparison against zero.
1415                  */
1416                 if (!is64 && imm < 0)
1417                         emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1418                 e = ctx->ninsns;
1419                 rvoff -= ninsns_rvoff(e - s);
1420                 emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1421                 break;
1422
1423         /* function call */
1424         case BPF_JMP | BPF_CALL:
1425         {
1426                 bool fixed_addr;
1427                 u64 addr;
1428
1429                 mark_call(ctx);
1430                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1431                                             &addr, &fixed_addr);
1432                 if (ret < 0)
1433                         return ret;
1434
1435                 ret = emit_call(addr, fixed_addr, ctx);
1436                 if (ret)
1437                         return ret;
1438
1439                 emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1440                 break;
1441         }
1442         /* tail call */
1443         case BPF_JMP | BPF_TAIL_CALL:
1444                 if (emit_bpf_tail_call(i, ctx))
1445                         return -1;
1446                 break;
1447
1448         /* function return */
1449         case BPF_JMP | BPF_EXIT:
1450                 if (i == ctx->prog->len - 1)
1451                         break;
1452
1453                 rvoff = epilogue_offset(ctx);
1454                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1455                 if (ret)
1456                         return ret;
1457                 break;
1458
1459         /* dst = imm64 */
1460         case BPF_LD | BPF_IMM | BPF_DW:
1461         {
1462                 struct bpf_insn insn1 = insn[1];
1463                 u64 imm64;
1464
1465                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
1466                 if (bpf_pseudo_func(insn)) {
1467                         /* fixed-length insns for extra jit pass */
1468                         ret = emit_addr(rd, imm64, extra_pass, ctx);
1469                         if (ret)
1470                                 return ret;
1471                 } else {
1472                         emit_imm(rd, imm64, ctx);
1473                 }
1474
1475                 return 1;
1476         }
1477
1478         /* LDX: dst = *(size *)(src + off) */
1479         case BPF_LDX | BPF_MEM | BPF_B:
1480         case BPF_LDX | BPF_MEM | BPF_H:
1481         case BPF_LDX | BPF_MEM | BPF_W:
1482         case BPF_LDX | BPF_MEM | BPF_DW:
1483         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1484         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1485         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1486         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1487         {
1488                 int insn_len, insns_start;
1489
1490                 switch (BPF_SIZE(code)) {
1491                 case BPF_B:
1492                         if (is_12b_int(off)) {
1493                                 insns_start = ctx->ninsns;
1494                                 emit(rv_lbu(rd, off, rs), ctx);
1495                                 insn_len = ctx->ninsns - insns_start;
1496                                 break;
1497                         }
1498
1499                         emit_imm(RV_REG_T1, off, ctx);
1500                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1501                         insns_start = ctx->ninsns;
1502                         emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1503                         insn_len = ctx->ninsns - insns_start;
1504                         if (insn_is_zext(&insn[1]))
1505                                 return 1;
1506                         break;
1507                 case BPF_H:
1508                         if (is_12b_int(off)) {
1509                                 insns_start = ctx->ninsns;
1510                                 emit(rv_lhu(rd, off, rs), ctx);
1511                                 insn_len = ctx->ninsns - insns_start;
1512                                 break;
1513                         }
1514
1515                         emit_imm(RV_REG_T1, off, ctx);
1516                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1517                         insns_start = ctx->ninsns;
1518                         emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1519                         insn_len = ctx->ninsns - insns_start;
1520                         if (insn_is_zext(&insn[1]))
1521                                 return 1;
1522                         break;
1523                 case BPF_W:
1524                         if (is_12b_int(off)) {
1525                                 insns_start = ctx->ninsns;
1526                                 emit(rv_lwu(rd, off, rs), ctx);
1527                                 insn_len = ctx->ninsns - insns_start;
1528                                 break;
1529                         }
1530
1531                         emit_imm(RV_REG_T1, off, ctx);
1532                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1533                         insns_start = ctx->ninsns;
1534                         emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1535                         insn_len = ctx->ninsns - insns_start;
1536                         if (insn_is_zext(&insn[1]))
1537                                 return 1;
1538                         break;
1539                 case BPF_DW:
1540                         if (is_12b_int(off)) {
1541                                 insns_start = ctx->ninsns;
1542                                 emit_ld(rd, off, rs, ctx);
1543                                 insn_len = ctx->ninsns - insns_start;
1544                                 break;
1545                         }
1546
1547                         emit_imm(RV_REG_T1, off, ctx);
1548                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1549                         insns_start = ctx->ninsns;
1550                         emit_ld(rd, 0, RV_REG_T1, ctx);
1551                         insn_len = ctx->ninsns - insns_start;
1552                         break;
1553                 }
1554
1555                 ret = add_exception_handler(insn, ctx, rd, insn_len);
1556                 if (ret)
1557                         return ret;
1558                 break;
1559         }
1560         /* speculation barrier */
1561         case BPF_ST | BPF_NOSPEC:
1562                 break;
1563
1564         /* ST: *(size *)(dst + off) = imm */
1565         case BPF_ST | BPF_MEM | BPF_B:
1566                 emit_imm(RV_REG_T1, imm, ctx);
1567                 if (is_12b_int(off)) {
1568                         emit(rv_sb(rd, off, RV_REG_T1), ctx);
1569                         break;
1570                 }
1571
1572                 emit_imm(RV_REG_T2, off, ctx);
1573                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1574                 emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1575                 break;
1576
1577         case BPF_ST | BPF_MEM | BPF_H:
1578                 emit_imm(RV_REG_T1, imm, ctx);
1579                 if (is_12b_int(off)) {
1580                         emit(rv_sh(rd, off, RV_REG_T1), ctx);
1581                         break;
1582                 }
1583
1584                 emit_imm(RV_REG_T2, off, ctx);
1585                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1586                 emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1587                 break;
1588         case BPF_ST | BPF_MEM | BPF_W:
1589                 emit_imm(RV_REG_T1, imm, ctx);
1590                 if (is_12b_int(off)) {
1591                         emit_sw(rd, off, RV_REG_T1, ctx);
1592                         break;
1593                 }
1594
1595                 emit_imm(RV_REG_T2, off, ctx);
1596                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1597                 emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1598                 break;
1599         case BPF_ST | BPF_MEM | BPF_DW:
1600                 emit_imm(RV_REG_T1, imm, ctx);
1601                 if (is_12b_int(off)) {
1602                         emit_sd(rd, off, RV_REG_T1, ctx);
1603                         break;
1604                 }
1605
1606                 emit_imm(RV_REG_T2, off, ctx);
1607                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1608                 emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1609                 break;
1610
1611         /* STX: *(size *)(dst + off) = src */
1612         case BPF_STX | BPF_MEM | BPF_B:
1613                 if (is_12b_int(off)) {
1614                         emit(rv_sb(rd, off, rs), ctx);
1615                         break;
1616                 }
1617
1618                 emit_imm(RV_REG_T1, off, ctx);
1619                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1620                 emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1621                 break;
1622         case BPF_STX | BPF_MEM | BPF_H:
1623                 if (is_12b_int(off)) {
1624                         emit(rv_sh(rd, off, rs), ctx);
1625                         break;
1626                 }
1627
1628                 emit_imm(RV_REG_T1, off, ctx);
1629                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1630                 emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1631                 break;
1632         case BPF_STX | BPF_MEM | BPF_W:
1633                 if (is_12b_int(off)) {
1634                         emit_sw(rd, off, rs, ctx);
1635                         break;
1636                 }
1637
1638                 emit_imm(RV_REG_T1, off, ctx);
1639                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1640                 emit_sw(RV_REG_T1, 0, rs, ctx);
1641                 break;
1642         case BPF_STX | BPF_MEM | BPF_DW:
1643                 if (is_12b_int(off)) {
1644                         emit_sd(rd, off, rs, ctx);
1645                         break;
1646                 }
1647
1648                 emit_imm(RV_REG_T1, off, ctx);
1649                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1650                 emit_sd(RV_REG_T1, 0, rs, ctx);
1651                 break;
1652         case BPF_STX | BPF_ATOMIC | BPF_W:
1653         case BPF_STX | BPF_ATOMIC | BPF_DW:
1654                 emit_atomic(rd, rs, off, imm,
1655                             BPF_SIZE(code) == BPF_DW, ctx);
1656                 break;
1657         default:
1658                 pr_err("bpf-jit: unknown opcode %02x\n", code);
1659                 return -EINVAL;
1660         }
1661
1662         return 0;
1663 }
1664
1665 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1666 {
1667         int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1668
1669         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1670         if (bpf_stack_adjust)
1671                 mark_fp(ctx);
1672
1673         if (seen_reg(RV_REG_RA, ctx))
1674                 stack_adjust += 8;
1675         stack_adjust += 8; /* RV_REG_FP */
1676         if (seen_reg(RV_REG_S1, ctx))
1677                 stack_adjust += 8;
1678         if (seen_reg(RV_REG_S2, ctx))
1679                 stack_adjust += 8;
1680         if (seen_reg(RV_REG_S3, ctx))
1681                 stack_adjust += 8;
1682         if (seen_reg(RV_REG_S4, ctx))
1683                 stack_adjust += 8;
1684         if (seen_reg(RV_REG_S5, ctx))
1685                 stack_adjust += 8;
1686         if (seen_reg(RV_REG_S6, ctx))
1687                 stack_adjust += 8;
1688
1689         stack_adjust = round_up(stack_adjust, 16);
1690         stack_adjust += bpf_stack_adjust;
1691
1692         store_offset = stack_adjust - 8;
1693
1694         /* reserve 4 nop insns */
1695         for (i = 0; i < 4; i++)
1696                 emit(rv_nop(), ctx);
1697
1698         /* First instruction is always setting the tail-call-counter
1699          * (TCC) register. This instruction is skipped for tail calls.
1700          * Force using a 4-byte (non-compressed) instruction.
1701          */
1702         emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1703
1704         emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1705
1706         if (seen_reg(RV_REG_RA, ctx)) {
1707                 emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1708                 store_offset -= 8;
1709         }
1710         emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1711         store_offset -= 8;
1712         if (seen_reg(RV_REG_S1, ctx)) {
1713                 emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1714                 store_offset -= 8;
1715         }
1716         if (seen_reg(RV_REG_S2, ctx)) {
1717                 emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1718                 store_offset -= 8;
1719         }
1720         if (seen_reg(RV_REG_S3, ctx)) {
1721                 emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1722                 store_offset -= 8;
1723         }
1724         if (seen_reg(RV_REG_S4, ctx)) {
1725                 emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1726                 store_offset -= 8;
1727         }
1728         if (seen_reg(RV_REG_S5, ctx)) {
1729                 emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1730                 store_offset -= 8;
1731         }
1732         if (seen_reg(RV_REG_S6, ctx)) {
1733                 emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1734                 store_offset -= 8;
1735         }
1736
1737         emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1738
1739         if (bpf_stack_adjust)
1740                 emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1741
1742         /* Program contains calls and tail calls, so RV_REG_TCC need
1743          * to be saved across calls.
1744          */
1745         if (seen_tail_call(ctx) && seen_call(ctx))
1746                 emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1747
1748         ctx->stack_size = stack_adjust;
1749 }
1750
1751 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1752 {
1753         __build_epilogue(false, ctx);
1754 }