Documentation: litmus-tests: Correct spelling
[linux-block.git] / arch / riscv / net / bpf_jit_comp64.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* BPF JIT compiler for RV64G
3  *
4  * Copyright(c) 2019 Björn Töpel <bjorn.topel@gmail.com>
5  *
6  */
7
8 #include <linux/bitfield.h>
9 #include <linux/bpf.h>
10 #include <linux/filter.h>
11 #include <linux/memory.h>
12 #include <linux/stop_machine.h>
13 #include "bpf_jit.h"
14
15 #define RV_REG_TCC RV_REG_A6
16 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
17
18 static const int regmap[] = {
19         [BPF_REG_0] =   RV_REG_A5,
20         [BPF_REG_1] =   RV_REG_A0,
21         [BPF_REG_2] =   RV_REG_A1,
22         [BPF_REG_3] =   RV_REG_A2,
23         [BPF_REG_4] =   RV_REG_A3,
24         [BPF_REG_5] =   RV_REG_A4,
25         [BPF_REG_6] =   RV_REG_S1,
26         [BPF_REG_7] =   RV_REG_S2,
27         [BPF_REG_8] =   RV_REG_S3,
28         [BPF_REG_9] =   RV_REG_S4,
29         [BPF_REG_FP] =  RV_REG_S5,
30         [BPF_REG_AX] =  RV_REG_T0,
31 };
32
33 static const int pt_regmap[] = {
34         [RV_REG_A0] = offsetof(struct pt_regs, a0),
35         [RV_REG_A1] = offsetof(struct pt_regs, a1),
36         [RV_REG_A2] = offsetof(struct pt_regs, a2),
37         [RV_REG_A3] = offsetof(struct pt_regs, a3),
38         [RV_REG_A4] = offsetof(struct pt_regs, a4),
39         [RV_REG_A5] = offsetof(struct pt_regs, a5),
40         [RV_REG_S1] = offsetof(struct pt_regs, s1),
41         [RV_REG_S2] = offsetof(struct pt_regs, s2),
42         [RV_REG_S3] = offsetof(struct pt_regs, s3),
43         [RV_REG_S4] = offsetof(struct pt_regs, s4),
44         [RV_REG_S5] = offsetof(struct pt_regs, s5),
45         [RV_REG_T0] = offsetof(struct pt_regs, t0),
46 };
47
48 enum {
49         RV_CTX_F_SEEN_TAIL_CALL =       0,
50         RV_CTX_F_SEEN_CALL =            RV_REG_RA,
51         RV_CTX_F_SEEN_S1 =              RV_REG_S1,
52         RV_CTX_F_SEEN_S2 =              RV_REG_S2,
53         RV_CTX_F_SEEN_S3 =              RV_REG_S3,
54         RV_CTX_F_SEEN_S4 =              RV_REG_S4,
55         RV_CTX_F_SEEN_S5 =              RV_REG_S5,
56         RV_CTX_F_SEEN_S6 =              RV_REG_S6,
57 };
58
59 static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
60 {
61         u8 reg = regmap[bpf_reg];
62
63         switch (reg) {
64         case RV_CTX_F_SEEN_S1:
65         case RV_CTX_F_SEEN_S2:
66         case RV_CTX_F_SEEN_S3:
67         case RV_CTX_F_SEEN_S4:
68         case RV_CTX_F_SEEN_S5:
69         case RV_CTX_F_SEEN_S6:
70                 __set_bit(reg, &ctx->flags);
71         }
72         return reg;
73 };
74
75 static bool seen_reg(int reg, struct rv_jit_context *ctx)
76 {
77         switch (reg) {
78         case RV_CTX_F_SEEN_CALL:
79         case RV_CTX_F_SEEN_S1:
80         case RV_CTX_F_SEEN_S2:
81         case RV_CTX_F_SEEN_S3:
82         case RV_CTX_F_SEEN_S4:
83         case RV_CTX_F_SEEN_S5:
84         case RV_CTX_F_SEEN_S6:
85                 return test_bit(reg, &ctx->flags);
86         }
87         return false;
88 }
89
90 static void mark_fp(struct rv_jit_context *ctx)
91 {
92         __set_bit(RV_CTX_F_SEEN_S5, &ctx->flags);
93 }
94
95 static void mark_call(struct rv_jit_context *ctx)
96 {
97         __set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
98 }
99
100 static bool seen_call(struct rv_jit_context *ctx)
101 {
102         return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
103 }
104
105 static void mark_tail_call(struct rv_jit_context *ctx)
106 {
107         __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
108 }
109
110 static bool seen_tail_call(struct rv_jit_context *ctx)
111 {
112         return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
113 }
114
115 static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
116 {
117         mark_tail_call(ctx);
118
119         if (seen_call(ctx)) {
120                 __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
121                 return RV_REG_S6;
122         }
123         return RV_REG_A6;
124 }
125
126 static bool is_32b_int(s64 val)
127 {
128         return -(1L << 31) <= val && val < (1L << 31);
129 }
130
131 static bool in_auipc_jalr_range(s64 val)
132 {
133         /*
134          * auipc+jalr can reach any signed PC-relative offset in the range
135          * [-2^31 - 2^11, 2^31 - 2^11).
136          */
137         return (-(1L << 31) - (1L << 11)) <= val &&
138                 val < ((1L << 31) - (1L << 11));
139 }
140
141 /* Emit fixed-length instructions for address */
142 static int emit_addr(u8 rd, u64 addr, bool extra_pass, struct rv_jit_context *ctx)
143 {
144         u64 ip = (u64)(ctx->insns + ctx->ninsns);
145         s64 off = addr - ip;
146         s64 upper = (off + (1 << 11)) >> 12;
147         s64 lower = off & 0xfff;
148
149         if (extra_pass && !in_auipc_jalr_range(off)) {
150                 pr_err("bpf-jit: target offset 0x%llx is out of range\n", off);
151                 return -ERANGE;
152         }
153
154         emit(rv_auipc(rd, upper), ctx);
155         emit(rv_addi(rd, rd, lower), ctx);
156         return 0;
157 }
158
159 /* Emit variable-length instructions for 32-bit and 64-bit imm */
160 static void emit_imm(u8 rd, s64 val, struct rv_jit_context *ctx)
161 {
162         /* Note that the immediate from the add is sign-extended,
163          * which means that we need to compensate this by adding 2^12,
164          * when the 12th bit is set. A simpler way of doing this, and
165          * getting rid of the check, is to just add 2**11 before the
166          * shift. The "Loading a 32-Bit constant" example from the
167          * "Computer Organization and Design, RISC-V edition" book by
168          * Patterson/Hennessy highlights this fact.
169          *
170          * This also means that we need to process LSB to MSB.
171          */
172         s64 upper = (val + (1 << 11)) >> 12;
173         /* Sign-extend lower 12 bits to 64 bits since immediates for li, addiw,
174          * and addi are signed and RVC checks will perform signed comparisons.
175          */
176         s64 lower = ((val & 0xfff) << 52) >> 52;
177         int shift;
178
179         if (is_32b_int(val)) {
180                 if (upper)
181                         emit_lui(rd, upper, ctx);
182
183                 if (!upper) {
184                         emit_li(rd, lower, ctx);
185                         return;
186                 }
187
188                 emit_addiw(rd, rd, lower, ctx);
189                 return;
190         }
191
192         shift = __ffs(upper);
193         upper >>= shift;
194         shift += 12;
195
196         emit_imm(rd, upper, ctx);
197
198         emit_slli(rd, rd, shift, ctx);
199         if (lower)
200                 emit_addi(rd, rd, lower, ctx);
201 }
202
203 static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
204 {
205         int stack_adjust = ctx->stack_size, store_offset = stack_adjust - 8;
206
207         if (seen_reg(RV_REG_RA, ctx)) {
208                 emit_ld(RV_REG_RA, store_offset, RV_REG_SP, ctx);
209                 store_offset -= 8;
210         }
211         emit_ld(RV_REG_FP, store_offset, RV_REG_SP, ctx);
212         store_offset -= 8;
213         if (seen_reg(RV_REG_S1, ctx)) {
214                 emit_ld(RV_REG_S1, store_offset, RV_REG_SP, ctx);
215                 store_offset -= 8;
216         }
217         if (seen_reg(RV_REG_S2, ctx)) {
218                 emit_ld(RV_REG_S2, store_offset, RV_REG_SP, ctx);
219                 store_offset -= 8;
220         }
221         if (seen_reg(RV_REG_S3, ctx)) {
222                 emit_ld(RV_REG_S3, store_offset, RV_REG_SP, ctx);
223                 store_offset -= 8;
224         }
225         if (seen_reg(RV_REG_S4, ctx)) {
226                 emit_ld(RV_REG_S4, store_offset, RV_REG_SP, ctx);
227                 store_offset -= 8;
228         }
229         if (seen_reg(RV_REG_S5, ctx)) {
230                 emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
231                 store_offset -= 8;
232         }
233         if (seen_reg(RV_REG_S6, ctx)) {
234                 emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
235                 store_offset -= 8;
236         }
237
238         emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
239         /* Set return value. */
240         if (!is_tail_call)
241                 emit_mv(RV_REG_A0, RV_REG_A5, ctx);
242         emit_jalr(RV_REG_ZERO, is_tail_call ? RV_REG_T3 : RV_REG_RA,
243                   is_tail_call ? 20 : 0, /* skip reserved nops and TCC init */
244                   ctx);
245 }
246
247 static void emit_bcc(u8 cond, u8 rd, u8 rs, int rvoff,
248                      struct rv_jit_context *ctx)
249 {
250         switch (cond) {
251         case BPF_JEQ:
252                 emit(rv_beq(rd, rs, rvoff >> 1), ctx);
253                 return;
254         case BPF_JGT:
255                 emit(rv_bltu(rs, rd, rvoff >> 1), ctx);
256                 return;
257         case BPF_JLT:
258                 emit(rv_bltu(rd, rs, rvoff >> 1), ctx);
259                 return;
260         case BPF_JGE:
261                 emit(rv_bgeu(rd, rs, rvoff >> 1), ctx);
262                 return;
263         case BPF_JLE:
264                 emit(rv_bgeu(rs, rd, rvoff >> 1), ctx);
265                 return;
266         case BPF_JNE:
267                 emit(rv_bne(rd, rs, rvoff >> 1), ctx);
268                 return;
269         case BPF_JSGT:
270                 emit(rv_blt(rs, rd, rvoff >> 1), ctx);
271                 return;
272         case BPF_JSLT:
273                 emit(rv_blt(rd, rs, rvoff >> 1), ctx);
274                 return;
275         case BPF_JSGE:
276                 emit(rv_bge(rd, rs, rvoff >> 1), ctx);
277                 return;
278         case BPF_JSLE:
279                 emit(rv_bge(rs, rd, rvoff >> 1), ctx);
280         }
281 }
282
283 static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
284                         struct rv_jit_context *ctx)
285 {
286         s64 upper, lower;
287
288         if (is_13b_int(rvoff)) {
289                 emit_bcc(cond, rd, rs, rvoff, ctx);
290                 return;
291         }
292
293         /* Adjust for jal */
294         rvoff -= 4;
295
296         /* Transform, e.g.:
297          *   bne rd,rs,foo
298          * to
299          *   beq rd,rs,<.L1>
300          *   (auipc foo)
301          *   jal(r) foo
302          * .L1
303          */
304         cond = invert_bpf_cond(cond);
305         if (is_21b_int(rvoff)) {
306                 emit_bcc(cond, rd, rs, 8, ctx);
307                 emit(rv_jal(RV_REG_ZERO, rvoff >> 1), ctx);
308                 return;
309         }
310
311         /* 32b No need for an additional rvoff adjustment, since we
312          * get that from the auipc at PC', where PC = PC' + 4.
313          */
314         upper = (rvoff + (1 << 11)) >> 12;
315         lower = rvoff & 0xfff;
316
317         emit_bcc(cond, rd, rs, 12, ctx);
318         emit(rv_auipc(RV_REG_T1, upper), ctx);
319         emit(rv_jalr(RV_REG_ZERO, RV_REG_T1, lower), ctx);
320 }
321
322 static void emit_zext_32(u8 reg, struct rv_jit_context *ctx)
323 {
324         emit_slli(reg, reg, 32, ctx);
325         emit_srli(reg, reg, 32, ctx);
326 }
327
328 static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
329 {
330         int tc_ninsn, off, start_insn = ctx->ninsns;
331         u8 tcc = rv_tail_call_reg(ctx);
332
333         /* a0: &ctx
334          * a1: &array
335          * a2: index
336          *
337          * if (index >= array->map.max_entries)
338          *      goto out;
339          */
340         tc_ninsn = insn ? ctx->offset[insn] - ctx->offset[insn - 1] :
341                    ctx->offset[0];
342         emit_zext_32(RV_REG_A2, ctx);
343
344         off = offsetof(struct bpf_array, map.max_entries);
345         if (is_12b_check(off, insn))
346                 return -1;
347         emit(rv_lwu(RV_REG_T1, off, RV_REG_A1), ctx);
348         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
349         emit_branch(BPF_JGE, RV_REG_A2, RV_REG_T1, off, ctx);
350
351         /* if (--TCC < 0)
352          *     goto out;
353          */
354         emit_addi(RV_REG_TCC, tcc, -1, ctx);
355         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
356         emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
357
358         /* prog = array->ptrs[index];
359          * if (!prog)
360          *     goto out;
361          */
362         emit_slli(RV_REG_T2, RV_REG_A2, 3, ctx);
363         emit_add(RV_REG_T2, RV_REG_T2, RV_REG_A1, ctx);
364         off = offsetof(struct bpf_array, ptrs);
365         if (is_12b_check(off, insn))
366                 return -1;
367         emit_ld(RV_REG_T2, off, RV_REG_T2, ctx);
368         off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
369         emit_branch(BPF_JEQ, RV_REG_T2, RV_REG_ZERO, off, ctx);
370
371         /* goto *(prog->bpf_func + 4); */
372         off = offsetof(struct bpf_prog, bpf_func);
373         if (is_12b_check(off, insn))
374                 return -1;
375         emit_ld(RV_REG_T3, off, RV_REG_T2, ctx);
376         __build_epilogue(true, ctx);
377         return 0;
378 }
379
380 static void init_regs(u8 *rd, u8 *rs, const struct bpf_insn *insn,
381                       struct rv_jit_context *ctx)
382 {
383         u8 code = insn->code;
384
385         switch (code) {
386         case BPF_JMP | BPF_JA:
387         case BPF_JMP | BPF_CALL:
388         case BPF_JMP | BPF_EXIT:
389         case BPF_JMP | BPF_TAIL_CALL:
390                 break;
391         default:
392                 *rd = bpf_to_rv_reg(insn->dst_reg, ctx);
393         }
394
395         if (code & (BPF_ALU | BPF_X) || code & (BPF_ALU64 | BPF_X) ||
396             code & (BPF_JMP | BPF_X) || code & (BPF_JMP32 | BPF_X) ||
397             code & BPF_LDX || code & BPF_STX)
398                 *rs = bpf_to_rv_reg(insn->src_reg, ctx);
399 }
400
401 static void emit_zext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
402 {
403         emit_mv(RV_REG_T2, *rd, ctx);
404         emit_zext_32(RV_REG_T2, ctx);
405         emit_mv(RV_REG_T1, *rs, ctx);
406         emit_zext_32(RV_REG_T1, ctx);
407         *rd = RV_REG_T2;
408         *rs = RV_REG_T1;
409 }
410
411 static void emit_sext_32_rd_rs(u8 *rd, u8 *rs, struct rv_jit_context *ctx)
412 {
413         emit_addiw(RV_REG_T2, *rd, 0, ctx);
414         emit_addiw(RV_REG_T1, *rs, 0, ctx);
415         *rd = RV_REG_T2;
416         *rs = RV_REG_T1;
417 }
418
419 static void emit_zext_32_rd_t1(u8 *rd, struct rv_jit_context *ctx)
420 {
421         emit_mv(RV_REG_T2, *rd, ctx);
422         emit_zext_32(RV_REG_T2, ctx);
423         emit_zext_32(RV_REG_T1, ctx);
424         *rd = RV_REG_T2;
425 }
426
427 static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
428 {
429         emit_addiw(RV_REG_T2, *rd, 0, ctx);
430         *rd = RV_REG_T2;
431 }
432
433 static int emit_jump_and_link(u8 rd, s64 rvoff, bool fixed_addr,
434                               struct rv_jit_context *ctx)
435 {
436         s64 upper, lower;
437
438         if (rvoff && fixed_addr && is_21b_int(rvoff)) {
439                 emit(rv_jal(rd, rvoff >> 1), ctx);
440                 return 0;
441         } else if (in_auipc_jalr_range(rvoff)) {
442                 upper = (rvoff + (1 << 11)) >> 12;
443                 lower = rvoff & 0xfff;
444                 emit(rv_auipc(RV_REG_T1, upper), ctx);
445                 emit(rv_jalr(rd, RV_REG_T1, lower), ctx);
446                 return 0;
447         }
448
449         pr_err("bpf-jit: target offset 0x%llx is out of range\n", rvoff);
450         return -ERANGE;
451 }
452
453 static bool is_signed_bpf_cond(u8 cond)
454 {
455         return cond == BPF_JSGT || cond == BPF_JSLT ||
456                 cond == BPF_JSGE || cond == BPF_JSLE;
457 }
458
459 static int emit_call(u64 addr, bool fixed_addr, struct rv_jit_context *ctx)
460 {
461         s64 off = 0;
462         u64 ip;
463
464         if (addr && ctx->insns) {
465                 ip = (u64)(long)(ctx->insns + ctx->ninsns);
466                 off = addr - ip;
467         }
468
469         return emit_jump_and_link(RV_REG_RA, off, fixed_addr, ctx);
470 }
471
472 static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
473                         struct rv_jit_context *ctx)
474 {
475         u8 r0;
476         int jmp_offset;
477
478         if (off) {
479                 if (is_12b_int(off)) {
480                         emit_addi(RV_REG_T1, rd, off, ctx);
481                 } else {
482                         emit_imm(RV_REG_T1, off, ctx);
483                         emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
484                 }
485                 rd = RV_REG_T1;
486         }
487
488         switch (imm) {
489         /* lock *(u32/u64 *)(dst_reg + off16) <op>= src_reg */
490         case BPF_ADD:
491                 emit(is64 ? rv_amoadd_d(RV_REG_ZERO, rs, rd, 0, 0) :
492                      rv_amoadd_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
493                 break;
494         case BPF_AND:
495                 emit(is64 ? rv_amoand_d(RV_REG_ZERO, rs, rd, 0, 0) :
496                      rv_amoand_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
497                 break;
498         case BPF_OR:
499                 emit(is64 ? rv_amoor_d(RV_REG_ZERO, rs, rd, 0, 0) :
500                      rv_amoor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
501                 break;
502         case BPF_XOR:
503                 emit(is64 ? rv_amoxor_d(RV_REG_ZERO, rs, rd, 0, 0) :
504                      rv_amoxor_w(RV_REG_ZERO, rs, rd, 0, 0), ctx);
505                 break;
506         /* src_reg = atomic_fetch_<op>(dst_reg + off16, src_reg) */
507         case BPF_ADD | BPF_FETCH:
508                 emit(is64 ? rv_amoadd_d(rs, rs, rd, 0, 0) :
509                      rv_amoadd_w(rs, rs, rd, 0, 0), ctx);
510                 if (!is64)
511                         emit_zext_32(rs, ctx);
512                 break;
513         case BPF_AND | BPF_FETCH:
514                 emit(is64 ? rv_amoand_d(rs, rs, rd, 0, 0) :
515                      rv_amoand_w(rs, rs, rd, 0, 0), ctx);
516                 if (!is64)
517                         emit_zext_32(rs, ctx);
518                 break;
519         case BPF_OR | BPF_FETCH:
520                 emit(is64 ? rv_amoor_d(rs, rs, rd, 0, 0) :
521                      rv_amoor_w(rs, rs, rd, 0, 0), ctx);
522                 if (!is64)
523                         emit_zext_32(rs, ctx);
524                 break;
525         case BPF_XOR | BPF_FETCH:
526                 emit(is64 ? rv_amoxor_d(rs, rs, rd, 0, 0) :
527                      rv_amoxor_w(rs, rs, rd, 0, 0), ctx);
528                 if (!is64)
529                         emit_zext_32(rs, ctx);
530                 break;
531         /* src_reg = atomic_xchg(dst_reg + off16, src_reg); */
532         case BPF_XCHG:
533                 emit(is64 ? rv_amoswap_d(rs, rs, rd, 0, 0) :
534                      rv_amoswap_w(rs, rs, rd, 0, 0), ctx);
535                 if (!is64)
536                         emit_zext_32(rs, ctx);
537                 break;
538         /* r0 = atomic_cmpxchg(dst_reg + off16, r0, src_reg); */
539         case BPF_CMPXCHG:
540                 r0 = bpf_to_rv_reg(BPF_REG_0, ctx);
541                 emit(is64 ? rv_addi(RV_REG_T2, r0, 0) :
542                      rv_addiw(RV_REG_T2, r0, 0), ctx);
543                 emit(is64 ? rv_lr_d(r0, 0, rd, 0, 0) :
544                      rv_lr_w(r0, 0, rd, 0, 0), ctx);
545                 jmp_offset = ninsns_rvoff(8);
546                 emit(rv_bne(RV_REG_T2, r0, jmp_offset >> 1), ctx);
547                 emit(is64 ? rv_sc_d(RV_REG_T3, rs, rd, 0, 0) :
548                      rv_sc_w(RV_REG_T3, rs, rd, 0, 0), ctx);
549                 jmp_offset = ninsns_rvoff(-6);
550                 emit(rv_bne(RV_REG_T3, 0, jmp_offset >> 1), ctx);
551                 emit(rv_fence(0x3, 0x3), ctx);
552                 break;
553         }
554 }
555
556 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
557 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
558
559 bool ex_handler_bpf(const struct exception_table_entry *ex,
560                     struct pt_regs *regs)
561 {
562         off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
563         int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
564
565         *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
566         regs->epc = (unsigned long)&ex->fixup - offset;
567
568         return true;
569 }
570
571 /* For accesses to BTF pointers, add an entry to the exception table */
572 static int add_exception_handler(const struct bpf_insn *insn,
573                                  struct rv_jit_context *ctx,
574                                  int dst_reg, int insn_len)
575 {
576         struct exception_table_entry *ex;
577         unsigned long pc;
578         off_t offset;
579
580         if (!ctx->insns || !ctx->prog->aux->extable || BPF_MODE(insn->code) != BPF_PROBE_MEM)
581                 return 0;
582
583         if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
584                 return -EINVAL;
585
586         if (WARN_ON_ONCE(insn_len > ctx->ninsns))
587                 return -EINVAL;
588
589         if (WARN_ON_ONCE(!rvc_enabled() && insn_len == 1))
590                 return -EINVAL;
591
592         ex = &ctx->prog->aux->extable[ctx->nexentries];
593         pc = (unsigned long)&ctx->insns[ctx->ninsns - insn_len];
594
595         offset = pc - (long)&ex->insn;
596         if (WARN_ON_ONCE(offset >= 0 || offset < INT_MIN))
597                 return -ERANGE;
598         ex->insn = offset;
599
600         /*
601          * Since the extable follows the program, the fixup offset is always
602          * negative and limited to BPF_JIT_REGION_SIZE. Store a positive value
603          * to keep things simple, and put the destination register in the upper
604          * bits. We don't need to worry about buildtime or runtime sort
605          * modifying the upper bits because the table is already sorted, and
606          * isn't part of the main exception table.
607          */
608         offset = (long)&ex->fixup - (pc + insn_len * sizeof(u16));
609         if (!FIELD_FIT(BPF_FIXUP_OFFSET_MASK, offset))
610                 return -ERANGE;
611
612         ex->fixup = FIELD_PREP(BPF_FIXUP_OFFSET_MASK, offset) |
613                 FIELD_PREP(BPF_FIXUP_REG_MASK, dst_reg);
614         ex->type = EX_TYPE_BPF;
615
616         ctx->nexentries++;
617         return 0;
618 }
619
620 static int gen_call_or_nops(void *target, void *ip, u32 *insns)
621 {
622         s64 rvoff;
623         int i, ret;
624         struct rv_jit_context ctx;
625
626         ctx.ninsns = 0;
627         ctx.insns = (u16 *)insns;
628
629         if (!target) {
630                 for (i = 0; i < 4; i++)
631                         emit(rv_nop(), &ctx);
632                 return 0;
633         }
634
635         rvoff = (s64)(target - (ip + 4));
636         emit(rv_sd(RV_REG_SP, -8, RV_REG_RA), &ctx);
637         ret = emit_jump_and_link(RV_REG_RA, rvoff, false, &ctx);
638         if (ret)
639                 return ret;
640         emit(rv_ld(RV_REG_RA, -8, RV_REG_SP), &ctx);
641
642         return 0;
643 }
644
645 static int gen_jump_or_nops(void *target, void *ip, u32 *insns)
646 {
647         s64 rvoff;
648         struct rv_jit_context ctx;
649
650         ctx.ninsns = 0;
651         ctx.insns = (u16 *)insns;
652
653         if (!target) {
654                 emit(rv_nop(), &ctx);
655                 emit(rv_nop(), &ctx);
656                 return 0;
657         }
658
659         rvoff = (s64)(target - ip);
660         return emit_jump_and_link(RV_REG_ZERO, rvoff, false, &ctx);
661 }
662
663 int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
664                        void *old_addr, void *new_addr)
665 {
666         u32 old_insns[4], new_insns[4];
667         bool is_call = poke_type == BPF_MOD_CALL;
668         int (*gen_insns)(void *target, void *ip, u32 *insns);
669         int ninsns = is_call ? 4 : 2;
670         int ret;
671
672         if (!is_bpf_text_address((unsigned long)ip))
673                 return -ENOTSUPP;
674
675         gen_insns = is_call ? gen_call_or_nops : gen_jump_or_nops;
676
677         ret = gen_insns(old_addr, ip, old_insns);
678         if (ret)
679                 return ret;
680
681         if (memcmp(ip, old_insns, ninsns * 4))
682                 return -EFAULT;
683
684         ret = gen_insns(new_addr, ip, new_insns);
685         if (ret)
686                 return ret;
687
688         cpus_read_lock();
689         mutex_lock(&text_mutex);
690         if (memcmp(ip, new_insns, ninsns * 4))
691                 ret = patch_text(ip, new_insns, ninsns);
692         mutex_unlock(&text_mutex);
693         cpus_read_unlock();
694
695         return ret;
696 }
697
698 static void store_args(int nregs, int args_off, struct rv_jit_context *ctx)
699 {
700         int i;
701
702         for (i = 0; i < nregs; i++) {
703                 emit_sd(RV_REG_FP, -args_off, RV_REG_A0 + i, ctx);
704                 args_off -= 8;
705         }
706 }
707
708 static void restore_args(int nregs, int args_off, struct rv_jit_context *ctx)
709 {
710         int i;
711
712         for (i = 0; i < nregs; i++) {
713                 emit_ld(RV_REG_A0 + i, -args_off, RV_REG_FP, ctx);
714                 args_off -= 8;
715         }
716 }
717
718 static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_off,
719                            int run_ctx_off, bool save_ret, struct rv_jit_context *ctx)
720 {
721         int ret, branch_off;
722         struct bpf_prog *p = l->link.prog;
723         int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
724
725         if (l->cookie) {
726                 emit_imm(RV_REG_T1, l->cookie, ctx);
727                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_T1, ctx);
728         } else {
729                 emit_sd(RV_REG_FP, -run_ctx_off + cookie_off, RV_REG_ZERO, ctx);
730         }
731
732         /* arg1: prog */
733         emit_imm(RV_REG_A0, (const s64)p, ctx);
734         /* arg2: &run_ctx */
735         emit_addi(RV_REG_A1, RV_REG_FP, -run_ctx_off, ctx);
736         ret = emit_call((const u64)bpf_trampoline_enter(p), true, ctx);
737         if (ret)
738                 return ret;
739
740         /* if (__bpf_prog_enter(prog) == 0)
741          *      goto skip_exec_of_prog;
742          */
743         branch_off = ctx->ninsns;
744         /* nop reserved for conditional jump */
745         emit(rv_nop(), ctx);
746
747         /* store prog start time */
748         emit_mv(RV_REG_S1, RV_REG_A0, ctx);
749
750         /* arg1: &args_off */
751         emit_addi(RV_REG_A0, RV_REG_FP, -args_off, ctx);
752         if (!p->jited)
753                 /* arg2: progs[i]->insnsi for interpreter */
754                 emit_imm(RV_REG_A1, (const s64)p->insnsi, ctx);
755         ret = emit_call((const u64)p->bpf_func, true, ctx);
756         if (ret)
757                 return ret;
758
759         if (save_ret)
760                 emit_sd(RV_REG_FP, -retval_off, regmap[BPF_REG_0], ctx);
761
762         /* update branch with beqz */
763         if (ctx->insns) {
764                 int offset = ninsns_rvoff(ctx->ninsns - branch_off);
765                 u32 insn = rv_beq(RV_REG_A0, RV_REG_ZERO, offset >> 1);
766                 *(u32 *)(ctx->insns + branch_off) = insn;
767         }
768
769         /* arg1: prog */
770         emit_imm(RV_REG_A0, (const s64)p, ctx);
771         /* arg2: prog start time */
772         emit_mv(RV_REG_A1, RV_REG_S1, ctx);
773         /* arg3: &run_ctx */
774         emit_addi(RV_REG_A2, RV_REG_FP, -run_ctx_off, ctx);
775         ret = emit_call((const u64)bpf_trampoline_exit(p), true, ctx);
776
777         return ret;
778 }
779
780 static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
781                                          const struct btf_func_model *m,
782                                          struct bpf_tramp_links *tlinks,
783                                          void *func_addr, u32 flags,
784                                          struct rv_jit_context *ctx)
785 {
786         int i, ret, offset;
787         int *branches_off = NULL;
788         int stack_size = 0, nregs = m->nr_args;
789         int retaddr_off, fp_off, retval_off, args_off;
790         int nregs_off, ip_off, run_ctx_off, sreg_off;
791         struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
792         struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
793         struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
794         void *orig_call = func_addr;
795         bool save_ret;
796         u32 insn;
797
798         /* Generated trampoline stack layout:
799          *
800          * FP - 8           [ RA of parent func ] return address of parent
801          *                                        function
802          * FP - retaddr_off [ RA of traced func ] return address of traced
803          *                                        function
804          * FP - fp_off      [ FP of parent func ]
805          *
806          * FP - retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
807          *                                        BPF_TRAMP_F_RET_FENTRY_RET
808          *                  [ argN              ]
809          *                  [ ...               ]
810          * FP - args_off    [ arg1              ]
811          *
812          * FP - nregs_off   [ regs count        ]
813          *
814          * FP - ip_off      [ traced func       ] BPF_TRAMP_F_IP_ARG
815          *
816          * FP - run_ctx_off [ bpf_tramp_run_ctx ]
817          *
818          * FP - sreg_off    [ callee saved reg  ]
819          *
820          *                  [ pads              ] pads for 16 bytes alignment
821          */
822
823         if (flags & (BPF_TRAMP_F_ORIG_STACK | BPF_TRAMP_F_SHARE_IPMODIFY))
824                 return -ENOTSUPP;
825
826         /* extra regiters for struct arguments */
827         for (i = 0; i < m->nr_args; i++)
828                 if (m->arg_flags[i] & BTF_FMODEL_STRUCT_ARG)
829                         nregs += round_up(m->arg_size[i], 8) / 8 - 1;
830
831         /* 8 arguments passed by registers */
832         if (nregs > 8)
833                 return -ENOTSUPP;
834
835         /* room for parent function return address */
836         stack_size += 8;
837
838         stack_size += 8;
839         retaddr_off = stack_size;
840
841         stack_size += 8;
842         fp_off = stack_size;
843
844         save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
845         if (save_ret) {
846                 stack_size += 8;
847                 retval_off = stack_size;
848         }
849
850         stack_size += nregs * 8;
851         args_off = stack_size;
852
853         stack_size += 8;
854         nregs_off = stack_size;
855
856         if (flags & BPF_TRAMP_F_IP_ARG) {
857                 stack_size += 8;
858                 ip_off = stack_size;
859         }
860
861         stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
862         run_ctx_off = stack_size;
863
864         stack_size += 8;
865         sreg_off = stack_size;
866
867         stack_size = round_up(stack_size, 16);
868
869         emit_addi(RV_REG_SP, RV_REG_SP, -stack_size, ctx);
870
871         emit_sd(RV_REG_SP, stack_size - retaddr_off, RV_REG_RA, ctx);
872         emit_sd(RV_REG_SP, stack_size - fp_off, RV_REG_FP, ctx);
873
874         emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
875
876         /* callee saved register S1 to pass start time */
877         emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
878
879         /* store ip address of the traced function */
880         if (flags & BPF_TRAMP_F_IP_ARG) {
881                 emit_imm(RV_REG_T1, (const s64)func_addr, ctx);
882                 emit_sd(RV_REG_FP, -ip_off, RV_REG_T1, ctx);
883         }
884
885         emit_li(RV_REG_T1, nregs, ctx);
886         emit_sd(RV_REG_FP, -nregs_off, RV_REG_T1, ctx);
887
888         store_args(nregs, args_off, ctx);
889
890         /* skip to actual body of traced function */
891         if (flags & BPF_TRAMP_F_SKIP_FRAME)
892                 orig_call += 16;
893
894         if (flags & BPF_TRAMP_F_CALL_ORIG) {
895                 emit_imm(RV_REG_A0, (const s64)im, ctx);
896                 ret = emit_call((const u64)__bpf_tramp_enter, true, ctx);
897                 if (ret)
898                         return ret;
899         }
900
901         for (i = 0; i < fentry->nr_links; i++) {
902                 ret = invoke_bpf_prog(fentry->links[i], args_off, retval_off, run_ctx_off,
903                                       flags & BPF_TRAMP_F_RET_FENTRY_RET, ctx);
904                 if (ret)
905                         return ret;
906         }
907
908         if (fmod_ret->nr_links) {
909                 branches_off = kcalloc(fmod_ret->nr_links, sizeof(int), GFP_KERNEL);
910                 if (!branches_off)
911                         return -ENOMEM;
912
913                 /* cleanup to avoid garbage return value confusion */
914                 emit_sd(RV_REG_FP, -retval_off, RV_REG_ZERO, ctx);
915                 for (i = 0; i < fmod_ret->nr_links; i++) {
916                         ret = invoke_bpf_prog(fmod_ret->links[i], args_off, retval_off,
917                                               run_ctx_off, true, ctx);
918                         if (ret)
919                                 goto out;
920                         emit_ld(RV_REG_T1, -retval_off, RV_REG_FP, ctx);
921                         branches_off[i] = ctx->ninsns;
922                         /* nop reserved for conditional jump */
923                         emit(rv_nop(), ctx);
924                 }
925         }
926
927         if (flags & BPF_TRAMP_F_CALL_ORIG) {
928                 restore_args(nregs, args_off, ctx);
929                 ret = emit_call((const u64)orig_call, true, ctx);
930                 if (ret)
931                         goto out;
932                 emit_sd(RV_REG_FP, -retval_off, RV_REG_A0, ctx);
933                 im->ip_after_call = ctx->insns + ctx->ninsns;
934                 /* 2 nops reserved for auipc+jalr pair */
935                 emit(rv_nop(), ctx);
936                 emit(rv_nop(), ctx);
937         }
938
939         /* update branches saved in invoke_bpf_mod_ret with bnez */
940         for (i = 0; ctx->insns && i < fmod_ret->nr_links; i++) {
941                 offset = ninsns_rvoff(ctx->ninsns - branches_off[i]);
942                 insn = rv_bne(RV_REG_T1, RV_REG_ZERO, offset >> 1);
943                 *(u32 *)(ctx->insns + branches_off[i]) = insn;
944         }
945
946         for (i = 0; i < fexit->nr_links; i++) {
947                 ret = invoke_bpf_prog(fexit->links[i], args_off, retval_off,
948                                       run_ctx_off, false, ctx);
949                 if (ret)
950                         goto out;
951         }
952
953         if (flags & BPF_TRAMP_F_CALL_ORIG) {
954                 im->ip_epilogue = ctx->insns + ctx->ninsns;
955                 emit_imm(RV_REG_A0, (const s64)im, ctx);
956                 ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
957                 if (ret)
958                         goto out;
959         }
960
961         if (flags & BPF_TRAMP_F_RESTORE_REGS)
962                 restore_args(nregs, args_off, ctx);
963
964         if (save_ret)
965                 emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx);
966
967         emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);
968
969         if (flags & BPF_TRAMP_F_SKIP_FRAME)
970                 /* return address of parent function */
971                 emit_ld(RV_REG_RA, stack_size - 8, RV_REG_SP, ctx);
972         else
973                 /* return address of traced function */
974                 emit_ld(RV_REG_RA, stack_size - retaddr_off, RV_REG_SP, ctx);
975
976         emit_ld(RV_REG_FP, stack_size - fp_off, RV_REG_SP, ctx);
977         emit_addi(RV_REG_SP, RV_REG_SP, stack_size, ctx);
978
979         emit_jalr(RV_REG_ZERO, RV_REG_RA, 0, ctx);
980
981         ret = ctx->ninsns;
982 out:
983         kfree(branches_off);
984         return ret;
985 }
986
987 int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
988                                 void *image_end, const struct btf_func_model *m,
989                                 u32 flags, struct bpf_tramp_links *tlinks,
990                                 void *func_addr)
991 {
992         int ret;
993         struct rv_jit_context ctx;
994
995         ctx.ninsns = 0;
996         ctx.insns = NULL;
997         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
998         if (ret < 0)
999                 return ret;
1000
1001         if (ninsns_rvoff(ret) > (long)image_end - (long)image)
1002                 return -EFBIG;
1003
1004         ctx.ninsns = 0;
1005         ctx.insns = image;
1006         ret = __arch_prepare_bpf_trampoline(im, m, tlinks, func_addr, flags, &ctx);
1007         if (ret < 0)
1008                 return ret;
1009
1010         bpf_flush_icache(ctx.insns, ctx.insns + ctx.ninsns);
1011
1012         return ninsns_rvoff(ret);
1013 }
1014
1015 int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
1016                       bool extra_pass)
1017 {
1018         bool is64 = BPF_CLASS(insn->code) == BPF_ALU64 ||
1019                     BPF_CLASS(insn->code) == BPF_JMP;
1020         int s, e, rvoff, ret, i = insn - ctx->prog->insnsi;
1021         struct bpf_prog_aux *aux = ctx->prog->aux;
1022         u8 rd = -1, rs = -1, code = insn->code;
1023         s16 off = insn->off;
1024         s32 imm = insn->imm;
1025
1026         init_regs(&rd, &rs, insn, ctx);
1027
1028         switch (code) {
1029         /* dst = src */
1030         case BPF_ALU | BPF_MOV | BPF_X:
1031         case BPF_ALU64 | BPF_MOV | BPF_X:
1032                 if (imm == 1) {
1033                         /* Special mov32 for zext */
1034                         emit_zext_32(rd, ctx);
1035                         break;
1036                 }
1037                 emit_mv(rd, rs, ctx);
1038                 if (!is64 && !aux->verifier_zext)
1039                         emit_zext_32(rd, ctx);
1040                 break;
1041
1042         /* dst = dst OP src */
1043         case BPF_ALU | BPF_ADD | BPF_X:
1044         case BPF_ALU64 | BPF_ADD | BPF_X:
1045                 emit_add(rd, rd, rs, ctx);
1046                 if (!is64 && !aux->verifier_zext)
1047                         emit_zext_32(rd, ctx);
1048                 break;
1049         case BPF_ALU | BPF_SUB | BPF_X:
1050         case BPF_ALU64 | BPF_SUB | BPF_X:
1051                 if (is64)
1052                         emit_sub(rd, rd, rs, ctx);
1053                 else
1054                         emit_subw(rd, rd, rs, ctx);
1055
1056                 if (!is64 && !aux->verifier_zext)
1057                         emit_zext_32(rd, ctx);
1058                 break;
1059         case BPF_ALU | BPF_AND | BPF_X:
1060         case BPF_ALU64 | BPF_AND | BPF_X:
1061                 emit_and(rd, rd, rs, ctx);
1062                 if (!is64 && !aux->verifier_zext)
1063                         emit_zext_32(rd, ctx);
1064                 break;
1065         case BPF_ALU | BPF_OR | BPF_X:
1066         case BPF_ALU64 | BPF_OR | BPF_X:
1067                 emit_or(rd, rd, rs, ctx);
1068                 if (!is64 && !aux->verifier_zext)
1069                         emit_zext_32(rd, ctx);
1070                 break;
1071         case BPF_ALU | BPF_XOR | BPF_X:
1072         case BPF_ALU64 | BPF_XOR | BPF_X:
1073                 emit_xor(rd, rd, rs, ctx);
1074                 if (!is64 && !aux->verifier_zext)
1075                         emit_zext_32(rd, ctx);
1076                 break;
1077         case BPF_ALU | BPF_MUL | BPF_X:
1078         case BPF_ALU64 | BPF_MUL | BPF_X:
1079                 emit(is64 ? rv_mul(rd, rd, rs) : rv_mulw(rd, rd, rs), ctx);
1080                 if (!is64 && !aux->verifier_zext)
1081                         emit_zext_32(rd, ctx);
1082                 break;
1083         case BPF_ALU | BPF_DIV | BPF_X:
1084         case BPF_ALU64 | BPF_DIV | BPF_X:
1085                 emit(is64 ? rv_divu(rd, rd, rs) : rv_divuw(rd, rd, rs), ctx);
1086                 if (!is64 && !aux->verifier_zext)
1087                         emit_zext_32(rd, ctx);
1088                 break;
1089         case BPF_ALU | BPF_MOD | BPF_X:
1090         case BPF_ALU64 | BPF_MOD | BPF_X:
1091                 emit(is64 ? rv_remu(rd, rd, rs) : rv_remuw(rd, rd, rs), ctx);
1092                 if (!is64 && !aux->verifier_zext)
1093                         emit_zext_32(rd, ctx);
1094                 break;
1095         case BPF_ALU | BPF_LSH | BPF_X:
1096         case BPF_ALU64 | BPF_LSH | BPF_X:
1097                 emit(is64 ? rv_sll(rd, rd, rs) : rv_sllw(rd, rd, rs), ctx);
1098                 if (!is64 && !aux->verifier_zext)
1099                         emit_zext_32(rd, ctx);
1100                 break;
1101         case BPF_ALU | BPF_RSH | BPF_X:
1102         case BPF_ALU64 | BPF_RSH | BPF_X:
1103                 emit(is64 ? rv_srl(rd, rd, rs) : rv_srlw(rd, rd, rs), ctx);
1104                 if (!is64 && !aux->verifier_zext)
1105                         emit_zext_32(rd, ctx);
1106                 break;
1107         case BPF_ALU | BPF_ARSH | BPF_X:
1108         case BPF_ALU64 | BPF_ARSH | BPF_X:
1109                 emit(is64 ? rv_sra(rd, rd, rs) : rv_sraw(rd, rd, rs), ctx);
1110                 if (!is64 && !aux->verifier_zext)
1111                         emit_zext_32(rd, ctx);
1112                 break;
1113
1114         /* dst = -dst */
1115         case BPF_ALU | BPF_NEG:
1116         case BPF_ALU64 | BPF_NEG:
1117                 emit_sub(rd, RV_REG_ZERO, rd, ctx);
1118                 if (!is64 && !aux->verifier_zext)
1119                         emit_zext_32(rd, ctx);
1120                 break;
1121
1122         /* dst = BSWAP##imm(dst) */
1123         case BPF_ALU | BPF_END | BPF_FROM_LE:
1124                 switch (imm) {
1125                 case 16:
1126                         emit_slli(rd, rd, 48, ctx);
1127                         emit_srli(rd, rd, 48, ctx);
1128                         break;
1129                 case 32:
1130                         if (!aux->verifier_zext)
1131                                 emit_zext_32(rd, ctx);
1132                         break;
1133                 case 64:
1134                         /* Do nothing */
1135                         break;
1136                 }
1137                 break;
1138
1139         case BPF_ALU | BPF_END | BPF_FROM_BE:
1140                 emit_li(RV_REG_T2, 0, ctx);
1141
1142                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1143                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1144                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1145                 emit_srli(rd, rd, 8, ctx);
1146                 if (imm == 16)
1147                         goto out_be;
1148
1149                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1150                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1151                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1152                 emit_srli(rd, rd, 8, ctx);
1153
1154                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1155                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1156                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1157                 emit_srli(rd, rd, 8, ctx);
1158                 if (imm == 32)
1159                         goto out_be;
1160
1161                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1162                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1163                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1164                 emit_srli(rd, rd, 8, ctx);
1165
1166                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1167                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1168                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1169                 emit_srli(rd, rd, 8, ctx);
1170
1171                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1172                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1173                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1174                 emit_srli(rd, rd, 8, ctx);
1175
1176                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1177                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1178                 emit_slli(RV_REG_T2, RV_REG_T2, 8, ctx);
1179                 emit_srli(rd, rd, 8, ctx);
1180 out_be:
1181                 emit_andi(RV_REG_T1, rd, 0xff, ctx);
1182                 emit_add(RV_REG_T2, RV_REG_T2, RV_REG_T1, ctx);
1183
1184                 emit_mv(rd, RV_REG_T2, ctx);
1185                 break;
1186
1187         /* dst = imm */
1188         case BPF_ALU | BPF_MOV | BPF_K:
1189         case BPF_ALU64 | BPF_MOV | BPF_K:
1190                 emit_imm(rd, imm, ctx);
1191                 if (!is64 && !aux->verifier_zext)
1192                         emit_zext_32(rd, ctx);
1193                 break;
1194
1195         /* dst = dst OP imm */
1196         case BPF_ALU | BPF_ADD | BPF_K:
1197         case BPF_ALU64 | BPF_ADD | BPF_K:
1198                 if (is_12b_int(imm)) {
1199                         emit_addi(rd, rd, imm, ctx);
1200                 } else {
1201                         emit_imm(RV_REG_T1, imm, ctx);
1202                         emit_add(rd, rd, RV_REG_T1, ctx);
1203                 }
1204                 if (!is64 && !aux->verifier_zext)
1205                         emit_zext_32(rd, ctx);
1206                 break;
1207         case BPF_ALU | BPF_SUB | BPF_K:
1208         case BPF_ALU64 | BPF_SUB | BPF_K:
1209                 if (is_12b_int(-imm)) {
1210                         emit_addi(rd, rd, -imm, ctx);
1211                 } else {
1212                         emit_imm(RV_REG_T1, imm, ctx);
1213                         emit_sub(rd, rd, RV_REG_T1, ctx);
1214                 }
1215                 if (!is64 && !aux->verifier_zext)
1216                         emit_zext_32(rd, ctx);
1217                 break;
1218         case BPF_ALU | BPF_AND | BPF_K:
1219         case BPF_ALU64 | BPF_AND | BPF_K:
1220                 if (is_12b_int(imm)) {
1221                         emit_andi(rd, rd, imm, ctx);
1222                 } else {
1223                         emit_imm(RV_REG_T1, imm, ctx);
1224                         emit_and(rd, rd, RV_REG_T1, ctx);
1225                 }
1226                 if (!is64 && !aux->verifier_zext)
1227                         emit_zext_32(rd, ctx);
1228                 break;
1229         case BPF_ALU | BPF_OR | BPF_K:
1230         case BPF_ALU64 | BPF_OR | BPF_K:
1231                 if (is_12b_int(imm)) {
1232                         emit(rv_ori(rd, rd, imm), ctx);
1233                 } else {
1234                         emit_imm(RV_REG_T1, imm, ctx);
1235                         emit_or(rd, rd, RV_REG_T1, ctx);
1236                 }
1237                 if (!is64 && !aux->verifier_zext)
1238                         emit_zext_32(rd, ctx);
1239                 break;
1240         case BPF_ALU | BPF_XOR | BPF_K:
1241         case BPF_ALU64 | BPF_XOR | BPF_K:
1242                 if (is_12b_int(imm)) {
1243                         emit(rv_xori(rd, rd, imm), ctx);
1244                 } else {
1245                         emit_imm(RV_REG_T1, imm, ctx);
1246                         emit_xor(rd, rd, RV_REG_T1, ctx);
1247                 }
1248                 if (!is64 && !aux->verifier_zext)
1249                         emit_zext_32(rd, ctx);
1250                 break;
1251         case BPF_ALU | BPF_MUL | BPF_K:
1252         case BPF_ALU64 | BPF_MUL | BPF_K:
1253                 emit_imm(RV_REG_T1, imm, ctx);
1254                 emit(is64 ? rv_mul(rd, rd, RV_REG_T1) :
1255                      rv_mulw(rd, rd, RV_REG_T1), ctx);
1256                 if (!is64 && !aux->verifier_zext)
1257                         emit_zext_32(rd, ctx);
1258                 break;
1259         case BPF_ALU | BPF_DIV | BPF_K:
1260         case BPF_ALU64 | BPF_DIV | BPF_K:
1261                 emit_imm(RV_REG_T1, imm, ctx);
1262                 emit(is64 ? rv_divu(rd, rd, RV_REG_T1) :
1263                      rv_divuw(rd, rd, RV_REG_T1), ctx);
1264                 if (!is64 && !aux->verifier_zext)
1265                         emit_zext_32(rd, ctx);
1266                 break;
1267         case BPF_ALU | BPF_MOD | BPF_K:
1268         case BPF_ALU64 | BPF_MOD | BPF_K:
1269                 emit_imm(RV_REG_T1, imm, ctx);
1270                 emit(is64 ? rv_remu(rd, rd, RV_REG_T1) :
1271                      rv_remuw(rd, rd, RV_REG_T1), ctx);
1272                 if (!is64 && !aux->verifier_zext)
1273                         emit_zext_32(rd, ctx);
1274                 break;
1275         case BPF_ALU | BPF_LSH | BPF_K:
1276         case BPF_ALU64 | BPF_LSH | BPF_K:
1277                 emit_slli(rd, rd, imm, ctx);
1278
1279                 if (!is64 && !aux->verifier_zext)
1280                         emit_zext_32(rd, ctx);
1281                 break;
1282         case BPF_ALU | BPF_RSH | BPF_K:
1283         case BPF_ALU64 | BPF_RSH | BPF_K:
1284                 if (is64)
1285                         emit_srli(rd, rd, imm, ctx);
1286                 else
1287                         emit(rv_srliw(rd, rd, imm), ctx);
1288
1289                 if (!is64 && !aux->verifier_zext)
1290                         emit_zext_32(rd, ctx);
1291                 break;
1292         case BPF_ALU | BPF_ARSH | BPF_K:
1293         case BPF_ALU64 | BPF_ARSH | BPF_K:
1294                 if (is64)
1295                         emit_srai(rd, rd, imm, ctx);
1296                 else
1297                         emit(rv_sraiw(rd, rd, imm), ctx);
1298
1299                 if (!is64 && !aux->verifier_zext)
1300                         emit_zext_32(rd, ctx);
1301                 break;
1302
1303         /* JUMP off */
1304         case BPF_JMP | BPF_JA:
1305                 rvoff = rv_offset(i, off, ctx);
1306                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1307                 if (ret)
1308                         return ret;
1309                 break;
1310
1311         /* IF (dst COND src) JUMP off */
1312         case BPF_JMP | BPF_JEQ | BPF_X:
1313         case BPF_JMP32 | BPF_JEQ | BPF_X:
1314         case BPF_JMP | BPF_JGT | BPF_X:
1315         case BPF_JMP32 | BPF_JGT | BPF_X:
1316         case BPF_JMP | BPF_JLT | BPF_X:
1317         case BPF_JMP32 | BPF_JLT | BPF_X:
1318         case BPF_JMP | BPF_JGE | BPF_X:
1319         case BPF_JMP32 | BPF_JGE | BPF_X:
1320         case BPF_JMP | BPF_JLE | BPF_X:
1321         case BPF_JMP32 | BPF_JLE | BPF_X:
1322         case BPF_JMP | BPF_JNE | BPF_X:
1323         case BPF_JMP32 | BPF_JNE | BPF_X:
1324         case BPF_JMP | BPF_JSGT | BPF_X:
1325         case BPF_JMP32 | BPF_JSGT | BPF_X:
1326         case BPF_JMP | BPF_JSLT | BPF_X:
1327         case BPF_JMP32 | BPF_JSLT | BPF_X:
1328         case BPF_JMP | BPF_JSGE | BPF_X:
1329         case BPF_JMP32 | BPF_JSGE | BPF_X:
1330         case BPF_JMP | BPF_JSLE | BPF_X:
1331         case BPF_JMP32 | BPF_JSLE | BPF_X:
1332         case BPF_JMP | BPF_JSET | BPF_X:
1333         case BPF_JMP32 | BPF_JSET | BPF_X:
1334                 rvoff = rv_offset(i, off, ctx);
1335                 if (!is64) {
1336                         s = ctx->ninsns;
1337                         if (is_signed_bpf_cond(BPF_OP(code)))
1338                                 emit_sext_32_rd_rs(&rd, &rs, ctx);
1339                         else
1340                                 emit_zext_32_rd_rs(&rd, &rs, ctx);
1341                         e = ctx->ninsns;
1342
1343                         /* Adjust for extra insns */
1344                         rvoff -= ninsns_rvoff(e - s);
1345                 }
1346
1347                 if (BPF_OP(code) == BPF_JSET) {
1348                         /* Adjust for and */
1349                         rvoff -= 4;
1350                         emit_and(RV_REG_T1, rd, rs, ctx);
1351                         emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff,
1352                                     ctx);
1353                 } else {
1354                         emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1355                 }
1356                 break;
1357
1358         /* IF (dst COND imm) JUMP off */
1359         case BPF_JMP | BPF_JEQ | BPF_K:
1360         case BPF_JMP32 | BPF_JEQ | BPF_K:
1361         case BPF_JMP | BPF_JGT | BPF_K:
1362         case BPF_JMP32 | BPF_JGT | BPF_K:
1363         case BPF_JMP | BPF_JLT | BPF_K:
1364         case BPF_JMP32 | BPF_JLT | BPF_K:
1365         case BPF_JMP | BPF_JGE | BPF_K:
1366         case BPF_JMP32 | BPF_JGE | BPF_K:
1367         case BPF_JMP | BPF_JLE | BPF_K:
1368         case BPF_JMP32 | BPF_JLE | BPF_K:
1369         case BPF_JMP | BPF_JNE | BPF_K:
1370         case BPF_JMP32 | BPF_JNE | BPF_K:
1371         case BPF_JMP | BPF_JSGT | BPF_K:
1372         case BPF_JMP32 | BPF_JSGT | BPF_K:
1373         case BPF_JMP | BPF_JSLT | BPF_K:
1374         case BPF_JMP32 | BPF_JSLT | BPF_K:
1375         case BPF_JMP | BPF_JSGE | BPF_K:
1376         case BPF_JMP32 | BPF_JSGE | BPF_K:
1377         case BPF_JMP | BPF_JSLE | BPF_K:
1378         case BPF_JMP32 | BPF_JSLE | BPF_K:
1379                 rvoff = rv_offset(i, off, ctx);
1380                 s = ctx->ninsns;
1381                 if (imm) {
1382                         emit_imm(RV_REG_T1, imm, ctx);
1383                         rs = RV_REG_T1;
1384                 } else {
1385                         /* If imm is 0, simply use zero register. */
1386                         rs = RV_REG_ZERO;
1387                 }
1388                 if (!is64) {
1389                         if (is_signed_bpf_cond(BPF_OP(code)))
1390                                 emit_sext_32_rd(&rd, ctx);
1391                         else
1392                                 emit_zext_32_rd_t1(&rd, ctx);
1393                 }
1394                 e = ctx->ninsns;
1395
1396                 /* Adjust for extra insns */
1397                 rvoff -= ninsns_rvoff(e - s);
1398                 emit_branch(BPF_OP(code), rd, rs, rvoff, ctx);
1399                 break;
1400
1401         case BPF_JMP | BPF_JSET | BPF_K:
1402         case BPF_JMP32 | BPF_JSET | BPF_K:
1403                 rvoff = rv_offset(i, off, ctx);
1404                 s = ctx->ninsns;
1405                 if (is_12b_int(imm)) {
1406                         emit_andi(RV_REG_T1, rd, imm, ctx);
1407                 } else {
1408                         emit_imm(RV_REG_T1, imm, ctx);
1409                         emit_and(RV_REG_T1, rd, RV_REG_T1, ctx);
1410                 }
1411                 /* For jset32, we should clear the upper 32 bits of t1, but
1412                  * sign-extension is sufficient here and saves one instruction,
1413                  * as t1 is used only in comparison against zero.
1414                  */
1415                 if (!is64 && imm < 0)
1416                         emit_addiw(RV_REG_T1, RV_REG_T1, 0, ctx);
1417                 e = ctx->ninsns;
1418                 rvoff -= ninsns_rvoff(e - s);
1419                 emit_branch(BPF_JNE, RV_REG_T1, RV_REG_ZERO, rvoff, ctx);
1420                 break;
1421
1422         /* function call */
1423         case BPF_JMP | BPF_CALL:
1424         {
1425                 bool fixed_addr;
1426                 u64 addr;
1427
1428                 mark_call(ctx);
1429                 ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
1430                                             &addr, &fixed_addr);
1431                 if (ret < 0)
1432                         return ret;
1433
1434                 ret = emit_call(addr, fixed_addr, ctx);
1435                 if (ret)
1436                         return ret;
1437
1438                 emit_mv(bpf_to_rv_reg(BPF_REG_0, ctx), RV_REG_A0, ctx);
1439                 break;
1440         }
1441         /* tail call */
1442         case BPF_JMP | BPF_TAIL_CALL:
1443                 if (emit_bpf_tail_call(i, ctx))
1444                         return -1;
1445                 break;
1446
1447         /* function return */
1448         case BPF_JMP | BPF_EXIT:
1449                 if (i == ctx->prog->len - 1)
1450                         break;
1451
1452                 rvoff = epilogue_offset(ctx);
1453                 ret = emit_jump_and_link(RV_REG_ZERO, rvoff, true, ctx);
1454                 if (ret)
1455                         return ret;
1456                 break;
1457
1458         /* dst = imm64 */
1459         case BPF_LD | BPF_IMM | BPF_DW:
1460         {
1461                 struct bpf_insn insn1 = insn[1];
1462                 u64 imm64;
1463
1464                 imm64 = (u64)insn1.imm << 32 | (u32)imm;
1465                 if (bpf_pseudo_func(insn)) {
1466                         /* fixed-length insns for extra jit pass */
1467                         ret = emit_addr(rd, imm64, extra_pass, ctx);
1468                         if (ret)
1469                                 return ret;
1470                 } else {
1471                         emit_imm(rd, imm64, ctx);
1472                 }
1473
1474                 return 1;
1475         }
1476
1477         /* LDX: dst = *(size *)(src + off) */
1478         case BPF_LDX | BPF_MEM | BPF_B:
1479         case BPF_LDX | BPF_MEM | BPF_H:
1480         case BPF_LDX | BPF_MEM | BPF_W:
1481         case BPF_LDX | BPF_MEM | BPF_DW:
1482         case BPF_LDX | BPF_PROBE_MEM | BPF_B:
1483         case BPF_LDX | BPF_PROBE_MEM | BPF_H:
1484         case BPF_LDX | BPF_PROBE_MEM | BPF_W:
1485         case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
1486         {
1487                 int insn_len, insns_start;
1488
1489                 switch (BPF_SIZE(code)) {
1490                 case BPF_B:
1491                         if (is_12b_int(off)) {
1492                                 insns_start = ctx->ninsns;
1493                                 emit(rv_lbu(rd, off, rs), ctx);
1494                                 insn_len = ctx->ninsns - insns_start;
1495                                 break;
1496                         }
1497
1498                         emit_imm(RV_REG_T1, off, ctx);
1499                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1500                         insns_start = ctx->ninsns;
1501                         emit(rv_lbu(rd, 0, RV_REG_T1), ctx);
1502                         insn_len = ctx->ninsns - insns_start;
1503                         if (insn_is_zext(&insn[1]))
1504                                 return 1;
1505                         break;
1506                 case BPF_H:
1507                         if (is_12b_int(off)) {
1508                                 insns_start = ctx->ninsns;
1509                                 emit(rv_lhu(rd, off, rs), ctx);
1510                                 insn_len = ctx->ninsns - insns_start;
1511                                 break;
1512                         }
1513
1514                         emit_imm(RV_REG_T1, off, ctx);
1515                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1516                         insns_start = ctx->ninsns;
1517                         emit(rv_lhu(rd, 0, RV_REG_T1), ctx);
1518                         insn_len = ctx->ninsns - insns_start;
1519                         if (insn_is_zext(&insn[1]))
1520                                 return 1;
1521                         break;
1522                 case BPF_W:
1523                         if (is_12b_int(off)) {
1524                                 insns_start = ctx->ninsns;
1525                                 emit(rv_lwu(rd, off, rs), ctx);
1526                                 insn_len = ctx->ninsns - insns_start;
1527                                 break;
1528                         }
1529
1530                         emit_imm(RV_REG_T1, off, ctx);
1531                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1532                         insns_start = ctx->ninsns;
1533                         emit(rv_lwu(rd, 0, RV_REG_T1), ctx);
1534                         insn_len = ctx->ninsns - insns_start;
1535                         if (insn_is_zext(&insn[1]))
1536                                 return 1;
1537                         break;
1538                 case BPF_DW:
1539                         if (is_12b_int(off)) {
1540                                 insns_start = ctx->ninsns;
1541                                 emit_ld(rd, off, rs, ctx);
1542                                 insn_len = ctx->ninsns - insns_start;
1543                                 break;
1544                         }
1545
1546                         emit_imm(RV_REG_T1, off, ctx);
1547                         emit_add(RV_REG_T1, RV_REG_T1, rs, ctx);
1548                         insns_start = ctx->ninsns;
1549                         emit_ld(rd, 0, RV_REG_T1, ctx);
1550                         insn_len = ctx->ninsns - insns_start;
1551                         break;
1552                 }
1553
1554                 ret = add_exception_handler(insn, ctx, rd, insn_len);
1555                 if (ret)
1556                         return ret;
1557                 break;
1558         }
1559         /* speculation barrier */
1560         case BPF_ST | BPF_NOSPEC:
1561                 break;
1562
1563         /* ST: *(size *)(dst + off) = imm */
1564         case BPF_ST | BPF_MEM | BPF_B:
1565                 emit_imm(RV_REG_T1, imm, ctx);
1566                 if (is_12b_int(off)) {
1567                         emit(rv_sb(rd, off, RV_REG_T1), ctx);
1568                         break;
1569                 }
1570
1571                 emit_imm(RV_REG_T2, off, ctx);
1572                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1573                 emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
1574                 break;
1575
1576         case BPF_ST | BPF_MEM | BPF_H:
1577                 emit_imm(RV_REG_T1, imm, ctx);
1578                 if (is_12b_int(off)) {
1579                         emit(rv_sh(rd, off, RV_REG_T1), ctx);
1580                         break;
1581                 }
1582
1583                 emit_imm(RV_REG_T2, off, ctx);
1584                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1585                 emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
1586                 break;
1587         case BPF_ST | BPF_MEM | BPF_W:
1588                 emit_imm(RV_REG_T1, imm, ctx);
1589                 if (is_12b_int(off)) {
1590                         emit_sw(rd, off, RV_REG_T1, ctx);
1591                         break;
1592                 }
1593
1594                 emit_imm(RV_REG_T2, off, ctx);
1595                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1596                 emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
1597                 break;
1598         case BPF_ST | BPF_MEM | BPF_DW:
1599                 emit_imm(RV_REG_T1, imm, ctx);
1600                 if (is_12b_int(off)) {
1601                         emit_sd(rd, off, RV_REG_T1, ctx);
1602                         break;
1603                 }
1604
1605                 emit_imm(RV_REG_T2, off, ctx);
1606                 emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
1607                 emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
1608                 break;
1609
1610         /* STX: *(size *)(dst + off) = src */
1611         case BPF_STX | BPF_MEM | BPF_B:
1612                 if (is_12b_int(off)) {
1613                         emit(rv_sb(rd, off, rs), ctx);
1614                         break;
1615                 }
1616
1617                 emit_imm(RV_REG_T1, off, ctx);
1618                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1619                 emit(rv_sb(RV_REG_T1, 0, rs), ctx);
1620                 break;
1621         case BPF_STX | BPF_MEM | BPF_H:
1622                 if (is_12b_int(off)) {
1623                         emit(rv_sh(rd, off, rs), ctx);
1624                         break;
1625                 }
1626
1627                 emit_imm(RV_REG_T1, off, ctx);
1628                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1629                 emit(rv_sh(RV_REG_T1, 0, rs), ctx);
1630                 break;
1631         case BPF_STX | BPF_MEM | BPF_W:
1632                 if (is_12b_int(off)) {
1633                         emit_sw(rd, off, rs, ctx);
1634                         break;
1635                 }
1636
1637                 emit_imm(RV_REG_T1, off, ctx);
1638                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1639                 emit_sw(RV_REG_T1, 0, rs, ctx);
1640                 break;
1641         case BPF_STX | BPF_MEM | BPF_DW:
1642                 if (is_12b_int(off)) {
1643                         emit_sd(rd, off, rs, ctx);
1644                         break;
1645                 }
1646
1647                 emit_imm(RV_REG_T1, off, ctx);
1648                 emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
1649                 emit_sd(RV_REG_T1, 0, rs, ctx);
1650                 break;
1651         case BPF_STX | BPF_ATOMIC | BPF_W:
1652         case BPF_STX | BPF_ATOMIC | BPF_DW:
1653                 emit_atomic(rd, rs, off, imm,
1654                             BPF_SIZE(code) == BPF_DW, ctx);
1655                 break;
1656         default:
1657                 pr_err("bpf-jit: unknown opcode %02x\n", code);
1658                 return -EINVAL;
1659         }
1660
1661         return 0;
1662 }
1663
1664 void bpf_jit_build_prologue(struct rv_jit_context *ctx)
1665 {
1666         int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
1667
1668         bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
1669         if (bpf_stack_adjust)
1670                 mark_fp(ctx);
1671
1672         if (seen_reg(RV_REG_RA, ctx))
1673                 stack_adjust += 8;
1674         stack_adjust += 8; /* RV_REG_FP */
1675         if (seen_reg(RV_REG_S1, ctx))
1676                 stack_adjust += 8;
1677         if (seen_reg(RV_REG_S2, ctx))
1678                 stack_adjust += 8;
1679         if (seen_reg(RV_REG_S3, ctx))
1680                 stack_adjust += 8;
1681         if (seen_reg(RV_REG_S4, ctx))
1682                 stack_adjust += 8;
1683         if (seen_reg(RV_REG_S5, ctx))
1684                 stack_adjust += 8;
1685         if (seen_reg(RV_REG_S6, ctx))
1686                 stack_adjust += 8;
1687
1688         stack_adjust = round_up(stack_adjust, 16);
1689         stack_adjust += bpf_stack_adjust;
1690
1691         store_offset = stack_adjust - 8;
1692
1693         /* reserve 4 nop insns */
1694         for (i = 0; i < 4; i++)
1695                 emit(rv_nop(), ctx);
1696
1697         /* First instruction is always setting the tail-call-counter
1698          * (TCC) register. This instruction is skipped for tail calls.
1699          * Force using a 4-byte (non-compressed) instruction.
1700          */
1701         emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
1702
1703         emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
1704
1705         if (seen_reg(RV_REG_RA, ctx)) {
1706                 emit_sd(RV_REG_SP, store_offset, RV_REG_RA, ctx);
1707                 store_offset -= 8;
1708         }
1709         emit_sd(RV_REG_SP, store_offset, RV_REG_FP, ctx);
1710         store_offset -= 8;
1711         if (seen_reg(RV_REG_S1, ctx)) {
1712                 emit_sd(RV_REG_SP, store_offset, RV_REG_S1, ctx);
1713                 store_offset -= 8;
1714         }
1715         if (seen_reg(RV_REG_S2, ctx)) {
1716                 emit_sd(RV_REG_SP, store_offset, RV_REG_S2, ctx);
1717                 store_offset -= 8;
1718         }
1719         if (seen_reg(RV_REG_S3, ctx)) {
1720                 emit_sd(RV_REG_SP, store_offset, RV_REG_S3, ctx);
1721                 store_offset -= 8;
1722         }
1723         if (seen_reg(RV_REG_S4, ctx)) {
1724                 emit_sd(RV_REG_SP, store_offset, RV_REG_S4, ctx);
1725                 store_offset -= 8;
1726         }
1727         if (seen_reg(RV_REG_S5, ctx)) {
1728                 emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
1729                 store_offset -= 8;
1730         }
1731         if (seen_reg(RV_REG_S6, ctx)) {
1732                 emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
1733                 store_offset -= 8;
1734         }
1735
1736         emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
1737
1738         if (bpf_stack_adjust)
1739                 emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
1740
1741         /* Program contains calls and tail calls, so RV_REG_TCC need
1742          * to be saved across calls.
1743          */
1744         if (seen_tail_call(ctx) && seen_call(ctx))
1745                 emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
1746
1747         ctx->stack_size = stack_adjust;
1748 }
1749
1750 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
1751 {
1752         __build_epilogue(false, ctx);
1753 }