LoongArch: BPF: Fix the tailcall hierarchy
authorHaoran Jiang <jianghaoran@kylinos.cn>
Tue, 5 Aug 2025 11:00:22 +0000 (19:00 +0800)
committerHuacai Chen <chenhuacai@loongson.cn>
Tue, 5 Aug 2025 11:00:22 +0000 (19:00 +0800)
In specific use cases combining tailcalls and BPF-to-BPF calls,
MAX_TAIL_CALL_CNT won't work because of missing tail_call_cnt
back-propagation from callee to caller. This patch fixes this
tailcall issue caused by abusing the tailcall in bpf2bpf feature
on LoongArch like the way of "bpf, x64: Fix tailcall hierarchy".

Push tail_call_cnt_ptr and tail_call_cnt into the stack,
tail_call_cnt_ptr is passed between tailcall and bpf2bpf,
uses tail_call_cnt_ptr to increment tail_call_cnt.

Fixes: bb035ef0cc91 ("LoongArch: BPF: Support mixing bpf2bpf and tailcalls")
Reviewed-by: Geliang Tang <geliang@kernel.org>
Reviewed-by: Hengqi Chen <hengqi.chen@gmail.com>
Signed-off-by: Haoran Jiang <jianghaoran@kylinos.cn>
Signed-off-by: Huacai Chen <chenhuacai@loongson.cn>
arch/loongarch/net/bpf_jit.c

index f4f12ed16d2f21e62bef32985bf040ba7434d5f7..4ea8ae4cf0ca0a523b952c4275d9eff215cdcb49 100644 (file)
 #define LOONGARCH_BPF_FENTRY_NBYTES (LOONGARCH_LONG_JUMP_NINSNS * 4)
 
 #define REG_TCC                LOONGARCH_GPR_A6
-#define TCC_SAVED      LOONGARCH_GPR_S5
-
-#define SAVE_RA                BIT(0)
-#define SAVE_TCC       BIT(1)
+#define BPF_TAIL_CALL_CNT_PTR_STACK_OFF(stack) (round_up(stack, 16) - 80)
 
 static const int regmap[] = {
        /* return value from in-kernel function, and exit value for eBPF program */
@@ -42,32 +39,57 @@ static const int regmap[] = {
        [BPF_REG_AX] = LOONGARCH_GPR_T0,
 };
 
-static void mark_call(struct jit_ctx *ctx)
+static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx, int *store_offset)
 {
-       ctx->flags |= SAVE_RA;
-}
+       const struct bpf_prog *prog = ctx->prog;
+       const bool is_main_prog = !bpf_is_subprog(prog);
 
-static void mark_tail_call(struct jit_ctx *ctx)
-{
-       ctx->flags |= SAVE_TCC;
-}
+       if (is_main_prog) {
+               /*
+                * LOONGARCH_GPR_T3 = MAX_TAIL_CALL_CNT
+                * if (REG_TCC > T3 )
+                *      std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+                * else
+                *      std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+                *      REG_TCC = LOONGARCH_GPR_SP + store_offset
+                *
+                * std REG_TCC -> LOONGARCH_GPR_SP + store_offset
+                *
+                * The purpose of this code is to first push the TCC into stack,
+                * and then push the address of TCC into stack.
+                * In cases where bpf2bpf and tailcall are used in combination,
+                * the value in REG_TCC may be a count or an address,
+                * these two cases need to be judged and handled separately.
+                */
+               emit_insn(ctx, addid, LOONGARCH_GPR_T3, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+               *store_offset -= sizeof(long);
 
-static bool seen_call(struct jit_ctx *ctx)
-{
-       return (ctx->flags & SAVE_RA);
-}
+               emit_cond_jmp(ctx, BPF_JGT, REG_TCC, LOONGARCH_GPR_T3, 4);
 
-static bool seen_tail_call(struct jit_ctx *ctx)
-{
-       return (ctx->flags & SAVE_TCC);
-}
+               /*
+                * If REG_TCC < MAX_TAIL_CALL_CNT, the value in REG_TCC is a count,
+                * push tcc into stack
+                */
+               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
 
-static u8 tail_call_reg(struct jit_ctx *ctx)
-{
-       if (seen_call(ctx))
-               return TCC_SAVED;
+               /* Push the address of TCC into the REG_TCC */
+               emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
 
-       return REG_TCC;
+               emit_uncond_jmp(ctx, 2);
+
+               /*
+                * If REG_TCC > MAX_TAIL_CALL_CNT, the value in REG_TCC is an address,
+                * push tcc_ptr into stack
+                */
+               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+       } else {
+               *store_offset -= sizeof(long);
+               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
+       }
+
+       /* Push tcc_ptr into stack */
+       *store_offset -= sizeof(long);
+       emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_SP, *store_offset);
 }
 
 /*
@@ -90,6 +112,10 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
  *                            |           $s4           |
  *                            +-------------------------+
  *                            |           $s5           |
+ *                            +-------------------------+
+ *                            |           tcc           |
+ *                            +-------------------------+
+ *                            |           tcc_ptr       |
  *                            +-------------------------+ <--BPF_REG_FP
  *                            |  prog->aux->stack_depth |
  *                            |        (optional)       |
@@ -99,12 +125,17 @@ static u8 tail_call_reg(struct jit_ctx *ctx)
 static void build_prologue(struct jit_ctx *ctx)
 {
        int i, stack_adjust = 0, store_offset, bpf_stack_adjust;
+       const struct bpf_prog *prog = ctx->prog;
+       const bool is_main_prog = !bpf_is_subprog(prog);
 
        bpf_stack_adjust = round_up(ctx->prog->aux->stack_depth, 16);
 
-       /* To store ra, fp, s0, s1, s2, s3, s4 and s5. */
+       /* To store ra, fp, s0, s1, s2, s3, s4, s5 */
        stack_adjust += sizeof(long) * 8;
 
+       /* To store tcc and tcc_ptr */
+       stack_adjust += sizeof(long) * 2;
+
        stack_adjust = round_up(stack_adjust, 16);
        stack_adjust += bpf_stack_adjust;
 
@@ -113,11 +144,12 @@ static void build_prologue(struct jit_ctx *ctx)
                emit_insn(ctx, nop);
 
        /*
-        * First instruction initializes the tail call count (TCC).
-        * On tail call we skip this instruction, and the TCC is
-        * passed in REG_TCC from the caller.
+        * First instruction initializes the tail call count (TCC)
+        * register to zero. On tail call we skip this instruction,
+        * and the TCC is passed in REG_TCC from the caller.
         */
-       emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+       if (is_main_prog)
+               emit_insn(ctx, addid, REG_TCC, LOONGARCH_GPR_ZERO, 0);
 
        emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, -stack_adjust);
 
@@ -145,20 +177,13 @@ static void build_prologue(struct jit_ctx *ctx)
        store_offset -= sizeof(long);
        emit_insn(ctx, std, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, store_offset);
 
+       prepare_bpf_tail_call_cnt(ctx, &store_offset);
+
        emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_adjust);
 
        if (bpf_stack_adjust)
                emit_insn(ctx, addid, regmap[BPF_REG_FP], LOONGARCH_GPR_SP, bpf_stack_adjust);
 
-       /*
-        * Program contains calls and tail calls, so REG_TCC need
-        * to be saved across calls.
-        */
-       if (seen_tail_call(ctx) && seen_call(ctx))
-               move_reg(ctx, TCC_SAVED, REG_TCC);
-       else
-               emit_insn(ctx, nop);
-
        ctx->stack_size = stack_adjust;
 }
 
@@ -191,6 +216,16 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
        load_offset -= sizeof(long);
        emit_insn(ctx, ldd, LOONGARCH_GPR_S5, LOONGARCH_GPR_SP, load_offset);
 
+       /*
+        * When push into the stack, follow the order of tcc then tcc_ptr.
+        * When pop from the stack, first pop tcc_ptr then followed by tcc.
+        */
+       load_offset -= 2 * sizeof(long);
+       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
+       load_offset += sizeof(long);
+       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, load_offset);
+
        emit_insn(ctx, addid, LOONGARCH_GPR_SP, LOONGARCH_GPR_SP, stack_adjust);
 
        if (!is_tail_call) {
@@ -203,7 +238,7 @@ static void __build_epilogue(struct jit_ctx *ctx, bool is_tail_call)
                 * Call the next bpf prog and skip the first instruction
                 * of TCC initialization.
                 */
-               emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 1);
+               emit_insn(ctx, jirl, LOONGARCH_GPR_ZERO, LOONGARCH_GPR_T3, 6);
        }
 }
 
@@ -225,7 +260,7 @@ bool bpf_jit_supports_far_kfunc_call(void)
 static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
 {
        int off, tc_ninsn = 0;
-       u8 tcc = tail_call_reg(ctx);
+       int tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
        u8 a1 = LOONGARCH_GPR_A1;
        u8 a2 = LOONGARCH_GPR_A2;
        u8 t1 = LOONGARCH_GPR_T1;
@@ -252,11 +287,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx, int insn)
                goto toofar;
 
        /*
-        * if (--TCC < 0)
-        *       goto out;
+        * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT)
+        *      goto out;
         */
-       emit_insn(ctx, addid, REG_TCC, tcc, -1);
-       if (emit_tailcall_jmp(ctx, BPF_JSLT, REG_TCC, LOONGARCH_GPR_ZERO, jmp_offset) < 0)
+       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+       emit_insn(ctx, ldd, t3, REG_TCC, 0);
+       emit_insn(ctx, addid, t3, t3, 1);
+       emit_insn(ctx, std, t3, REG_TCC, 0);
+       emit_insn(ctx, addid, t2, LOONGARCH_GPR_ZERO, MAX_TAIL_CALL_CNT);
+       if (emit_tailcall_jmp(ctx, BPF_JSGT, t3, t2, jmp_offset) < 0)
                goto toofar;
 
        /*
@@ -467,7 +506,7 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
        u64 func_addr;
        bool func_addr_fixed, sign_extend;
        int i = insn - ctx->prog->insnsi;
-       int ret, jmp_offset;
+       int ret, jmp_offset, tcc_ptr_off;
        const u8 code = insn->code;
        const u8 cond = BPF_OP(code);
        const u8 t1 = LOONGARCH_GPR_T1;
@@ -903,12 +942,16 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
 
        /* function call */
        case BPF_JMP | BPF_CALL:
-               mark_call(ctx);
                ret = bpf_jit_get_func_addr(ctx->prog, insn, extra_pass,
                                            &func_addr, &func_addr_fixed);
                if (ret < 0)
                        return ret;
 
+               if (insn->src_reg == BPF_PSEUDO_CALL) {
+                       tcc_ptr_off = BPF_TAIL_CALL_CNT_PTR_STACK_OFF(ctx->stack_size);
+                       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_SP, tcc_ptr_off);
+               }
+
                move_addr(ctx, t1, func_addr);
                emit_insn(ctx, jirl, LOONGARCH_GPR_RA, t1, 0);
 
@@ -919,7 +962,6 @@ static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx, bool ext
 
        /* tail call */
        case BPF_JMP | BPF_TAIL_CALL:
-               mark_tail_call(ctx);
                if (emit_bpf_tail_call(ctx, i) < 0)
                        return -EINVAL;
                break;
@@ -1412,7 +1454,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 {
        int i, ret, save_ret;
        int stack_size = 0, nargs = 0;
-       int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off;
+       int retval_off, args_off, nargs_off, ip_off, run_ctx_off, sreg_off, tcc_ptr_off;
        bool is_struct_ops = flags & BPF_TRAMP_F_INDIRECT;
        void *orig_call = func_addr;
        struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
@@ -1447,6 +1489,7 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
         *
         * FP - sreg_off    [ callee saved reg  ]
         *
+        * FP - tcc_ptr_off [ tail_call_cnt_ptr ]
         */
 
        if (m->nr_args > LOONGARCH_MAX_REG_ARGS)
@@ -1489,6 +1532,12 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
        stack_size += 8;
        sreg_off = stack_size;
 
+       /* Room of trampoline frame to store tail_call_cnt_ptr */
+       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
+               stack_size += 8;
+               tcc_ptr_off = stack_size;
+       }
+
        stack_size = round_up(stack_size, 16);
 
        if (is_struct_ops) {
@@ -1519,6 +1568,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
                emit_insn(ctx, addid, LOONGARCH_GPR_FP, LOONGARCH_GPR_SP, stack_size);
        }
 
+       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+               emit_insn(ctx, std, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
        /* callee saved register S1 to pass start time */
        emit_insn(ctx, std, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
 
@@ -1565,6 +1617,10 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 
        if (flags & BPF_TRAMP_F_CALL_ORIG) {
                restore_args(ctx, m->nr_args, args_off);
+
+               if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+                       emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
                ret = emit_call(ctx, (const u64)orig_call);
                if (ret)
                        goto out;
@@ -1605,6 +1661,9 @@ static int __arch_prepare_bpf_trampoline(struct jit_ctx *ctx, struct bpf_tramp_i
 
        emit_insn(ctx, ldd, LOONGARCH_GPR_S1, LOONGARCH_GPR_FP, -sreg_off);
 
+       if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+               emit_insn(ctx, ldd, REG_TCC, LOONGARCH_GPR_FP, -tcc_ptr_off);
+
        if (is_struct_ops) {
                /* trampoline called directly */
                emit_insn(ctx, ldd, LOONGARCH_GPR_RA, LOONGARCH_GPR_SP, stack_size - 8);