bpf: Refactor x86 JIT into helpers
arch/x86/net/bpf_jit_comp.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * bpf_jit_comp.c: BPF JIT compiler
 *
 * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
 * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
 */
#include <linux/netdevice.h>
#include <linux/filter.h>
#include <linux/if_vlan.h>
#include <linux/bpf.h>
#include <asm/extable.h>
#include <asm/set_memory.h>
#include <asm/nospec-branch.h>

static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
{
        if (len == 1)
                *ptr = bytes;
        else if (len == 2)
                *(u16 *)ptr = bytes;
        else {
                *(u32 *)ptr = bytes;
                barrier();
        }
        return ptr + len;
}

#define EMIT(bytes, len) \
        do { prog = emit_code(prog, bytes, len); cnt += len; } while (0)

#define EMIT1(b1)               EMIT(b1, 1)
#define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
#define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
#define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)

#define EMIT1_off32(b1, off) \
        do { EMIT1(b1); EMIT(off, 4); } while (0)
#define EMIT2_off32(b1, b2, off) \
        do { EMIT2(b1, b2); EMIT(off, 4); } while (0)
#define EMIT3_off32(b1, b2, b3, off) \
        do { EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
#define EMIT4_off32(b1, b2, b3, b4, off) \
        do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)

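/*
 * Note: x86 is little-endian, so the u16/u32 stores in emit_code() place
 * the low byte first and the EMITn() macro arguments above land in memory
 * in argument order. For example, EMIT3(0x48, 0x89, 0xE5) emits the byte
 * sequence 48 89 e5, i.e. "mov rbp, rsp".
 */
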
static bool is_imm8(int value)
{
        return value <= 127 && value >= -128;
}

static bool is_simm32(s64 value)
{
        return value == (s64)(s32)value;
}

static bool is_uimm32(u64 value)
{
        return value == (u64)(u32)value;
}

/* mov dst, src */
#define EMIT_mov(DST, SRC)                                                               \
        do {                                                                             \
                if (DST != SRC)                                                          \
                        EMIT3(add_2mod(0x48, DST, SRC), 0x89, add_2reg(0xC0, DST, SRC)); \
        } while (0)

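/*
 * Immediate width of a ST instruction for a given BPF size. BPF_DW
 * deliberately maps to 4 bytes: x86 "mov r/m64, imm32" takes at most a
 * 32-bit immediate, which the CPU sign-extends to 64 bits.
 */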
static int bpf_size_to_x86_bytes(int bpf_size)
{
        if (bpf_size == BPF_W)
                return 4;
        else if (bpf_size == BPF_H)
                return 2;
        else if (bpf_size == BPF_B)
                return 1;
        else if (bpf_size == BPF_DW)
                return 4; /* imm32 */
        else
                return 0;
}

/*
 * List of x86 conditional jump opcodes (. + s8)
 * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
 */
#define X86_JB  0x72
#define X86_JAE 0x73
#define X86_JE  0x74
#define X86_JNE 0x75
#define X86_JBE 0x76
#define X86_JA  0x77
#define X86_JL  0x7C
#define X86_JGE 0x7D
#define X86_JLE 0x7E
#define X86_JG  0x7F
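/* Example: JE is 0x74 rel8 in short form and 0x0f 0x84 rel32 in far form. */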

/* Pick a register outside of BPF range for JIT internal work */
#define AUX_REG (MAX_BPF_JIT_REG + 1)

/*
 * The following table maps BPF registers to x86-64 registers.
 *
 * x86-64 register R12 is unused, since if used as a base address
 * register in load/store instructions, it always needs an
 * extra byte of encoding and is callee saved.
 *
 * Also x86-64 register R9 is unused. x86-64 register R10 is
 * used for blinding (if enabled).
 */
static const int reg2hex[] = {
        [BPF_REG_0] = 0,  /* RAX */
        [BPF_REG_1] = 7,  /* RDI */
        [BPF_REG_2] = 6,  /* RSI */
        [BPF_REG_3] = 2,  /* RDX */
        [BPF_REG_4] = 1,  /* RCX */
        [BPF_REG_5] = 0,  /* R8  */
        [BPF_REG_6] = 3,  /* RBX callee saved */
        [BPF_REG_7] = 5,  /* R13 callee saved */
        [BPF_REG_8] = 6,  /* R14 callee saved */
        [BPF_REG_9] = 7,  /* R15 callee saved */
        [BPF_REG_FP] = 5, /* RBP readonly */
        [BPF_REG_AX] = 2, /* R10 temp register */
        [AUX_REG] = 3,    /* R11 temp register */
};

static const int reg2pt_regs[] = {
        [BPF_REG_0] = offsetof(struct pt_regs, ax),
        [BPF_REG_1] = offsetof(struct pt_regs, di),
        [BPF_REG_2] = offsetof(struct pt_regs, si),
        [BPF_REG_3] = offsetof(struct pt_regs, dx),
        [BPF_REG_4] = offsetof(struct pt_regs, cx),
        [BPF_REG_5] = offsetof(struct pt_regs, r8),
        [BPF_REG_6] = offsetof(struct pt_regs, bx),
        [BPF_REG_7] = offsetof(struct pt_regs, r13),
        [BPF_REG_8] = offsetof(struct pt_regs, r14),
        [BPF_REG_9] = offsetof(struct pt_regs, r15),
};

/*
 * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15,
 * which need an extra byte of encoding.
 * rax, rcx, ..., rbp have simpler encoding.
 */
static bool is_ereg(u32 reg)
{
        return (1 << reg) & (BIT(BPF_REG_5) |
                             BIT(AUX_REG) |
                             BIT(BPF_REG_7) |
                             BIT(BPF_REG_8) |
                             BIT(BPF_REG_9) |
                             BIT(BPF_REG_AX));
}

static bool is_axreg(u32 reg)
{
        return reg == BPF_REG_0;
}

/* Add modifiers if 'reg' maps to x86-64 registers R8..R15 */
static u8 add_1mod(u8 byte, u32 reg)
{
        if (is_ereg(reg))
                byte |= 1;
        return byte;
}

static u8 add_2mod(u8 byte, u32 r1, u32 r2)
{
        if (is_ereg(r1))
                byte |= 1;
        if (is_ereg(r2))
                byte |= 4;
        return byte;
}

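/*
 * The bytes above are REX prefix bits: the base 0100WRXB prefix is 0x40,
 * and 0x48 additionally sets REX.W (64-bit operand size). "byte |= 1"
 * sets REX.B, which extends the ModRM.rm register (r1 here), and
 * "byte |= 4" sets REX.R, which extends the ModRM.reg register (r2 here).
 */
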
/* Encode 'dst_reg' register into x86-64 opcode 'byte' */
static u8 add_1reg(u8 byte, u32 dst_reg)
{
        return byte + reg2hex[dst_reg];
}

/* Encode 'dst_reg' and 'src_reg' registers into x86-64 opcode 'byte' */
static u8 add_2reg(u8 byte, u32 dst_reg, u32 src_reg)
{
        return byte + reg2hex[dst_reg] + (reg2hex[src_reg] << 3);
}

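/*
 * These helpers build a ModRM byte: mod (bits 7-6), reg (bits 5-3),
 * rm (bits 2-0). The JIT uses three mod bases: 0xC0 for register-direct
 * operands, 0x40 for a memory operand with an 8-bit displacement and
 * 0x80 for a memory operand with a 32-bit displacement. For example,
 * add_2reg(0xC0, BPF_REG_0, BPF_REG_1) = 0xC0 | rax | (rdi << 3) = 0xF8.
 */
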
static void jit_fill_hole(void *area, unsigned int size)
{
        /* Fill whole space with INT3 instructions */
        memset(area, 0xcc, size);
}

struct jit_context {
        int cleanup_addr; /* Epilogue code offset */
};

/* Maximum number of bytes emitted while JITing one eBPF insn */
#define BPF_MAX_INSN_SIZE       128
#define BPF_INSN_SAFETY         64
/* Number of bytes emit_call() needs to generate a call instruction */
#define X86_CALL_SIZE           5

#define PROLOGUE_SIZE           20

/*
 * Emit x86-64 prologue code for the BPF program and check its size.
 * The bpf_tail_call() helper will skip it while jumping into another program.
 */
static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf)
{
        u8 *prog = *pprog;
        int cnt = 0;

        EMIT1(0x55);             /* push rbp */
        EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */
        /* sub rsp, rounded_stack_depth */
        EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8));
        EMIT1(0x53);             /* push rbx */
        EMIT2(0x41, 0x55);       /* push r13 */
        EMIT2(0x41, 0x56);       /* push r14 */
        EMIT2(0x41, 0x57);       /* push r15 */
        if (!ebpf_from_cbpf) {
                /* zero init tail_call_cnt */
                EMIT2(0x6a, 0x00);
                BUILD_BUG_ON(cnt != PROLOGUE_SIZE);
        }
        *pprog = prog;
}

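/*
 * The resulting 20 bytes of prologue are (for an eBPF-native program):
 *
 *   55                      push rbp
 *   48 89 e5                mov rbp, rsp
 *   48 81 ec xx xx xx xx    sub rsp, round_up(stack_depth, 8)
 *   53                      push rbx
 *   41 55                   push r13
 *   41 56                   push r14
 *   41 57                   push r15
 *   6a 00                   push 0        ; tail_call_cnt = 0
 *
 * which is exactly what PROLOGUE_SIZE asserts via BUILD_BUG_ON().
 */
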
/*
 * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *   if (index >= array->map.max_entries)
 *     goto out;
 *   if (tail_call_cnt > MAX_TAIL_CALL_CNT)
 *     goto out;
 *   prog = array->ptrs[index];
 *   if (prog == NULL)
 *     goto out;
 *   goto *(prog->bpf_func + prologue_size);
 * out:
 */
static void emit_bpf_tail_call(u8 **pprog)
{
        u8 *prog = *pprog;
        int label1, label2, label3;
        int cnt = 0;

        /*
         * rdi - pointer to ctx
         * rsi - pointer to bpf_array
         * rdx - index in bpf_array
         */

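        /*
         * The index arrives as a 64-bit value in rdx; "mov edx, edx" below
         * clears its upper 32 bits, so the unsigned comparison against the
         * u32 max_entries operates on a 32-bit index.
         */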
        /*
         * if (index >= array->map.max_entries)
         *      goto out;
         */
        EMIT2(0x89, 0xD2);                        /* mov edx, edx */
        EMIT3(0x39, 0x56,                         /* cmp dword ptr [rsi + 16], edx */
              offsetof(struct bpf_array, map.max_entries));
#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */
        EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
        label1 = cnt;

        /*
         * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
         *      goto out;
         */
        EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */
        EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE)
        EMIT2(X86_JA, OFFSET2);                   /* ja out */
        label2 = cnt;
        EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
        EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp - 548], eax */

        /* prog = array->ptrs[index]; */
        EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
                    offsetof(struct bpf_array, ptrs));

        /*
         * if (prog == NULL)
         *      goto out;
         */
        EMIT3(0x48, 0x85, 0xC0);                  /* test rax,rax */
#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE)
        EMIT2(X86_JE, OFFSET3);                   /* je out */
        label3 = cnt;

        /* goto *(prog->bpf_func + prologue_size); */
        EMIT4(0x48, 0x8B, 0x40,                   /* mov rax, qword ptr [rax + 32] */
              offsetof(struct bpf_prog, bpf_func));
        EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE);   /* add rax, prologue_size */

        /*
         * Now we're ready to jump into the next BPF program:
         * rdi == ctx (1st arg)
         * rax == prog->bpf_func + prologue_size
         */
        RETPOLINE_RAX_BPF_JIT();

        /* out: */
        BUILD_BUG_ON(cnt - label1 != OFFSET1);
        BUILD_BUG_ON(cnt - label2 != OFFSET2);
        BUILD_BUG_ON(cnt - label3 != OFFSET3);
        *pprog = prog;
}

static void emit_mov_imm32(u8 **pprog, bool sign_propagate,
                           u32 dst_reg, const u32 imm32)
{
        u8 *prog = *pprog;
        u8 b1, b2, b3;
        int cnt = 0;

        /*
         * Optimization: if imm32 is positive, use 'mov %eax, imm32'
         * (which zero-extends imm32) to save 2 bytes.
         */
        if (sign_propagate && (s32)imm32 < 0) {
                /* 'mov %rax, imm32' sign extends imm32 */
                b1 = add_1mod(0x48, dst_reg);
                b2 = 0xC7;
                b3 = 0xC0;
                EMIT3_off32(b1, b2, add_1reg(b3, dst_reg), imm32);
                goto done;
        }

        /*
         * Optimization: if imm32 is zero, use 'xor %eax, %eax'
         * to save 3 bytes.
         */
        if (imm32 == 0) {
                if (is_ereg(dst_reg))
                        EMIT1(add_2mod(0x40, dst_reg, dst_reg));
                b2 = 0x31; /* xor */
                b3 = 0xC0;
                EMIT2(b2, add_2reg(b3, dst_reg, dst_reg));
                goto done;
        }

        /* mov %eax, imm32 */
        if (is_ereg(dst_reg))
                EMIT1(add_1mod(0x40, dst_reg));
        EMIT1_off32(add_1reg(0xB8, dst_reg), imm32);
done:
        *pprog = prog;
}

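/*
 * Size comparison for dst_reg == BPF_REG_0 (rax): imm32 == 0 emits
 * "31 c0" (xor eax, eax; 2 bytes), a positive imm32 emits "b8 imm32"
 * (5 bytes) and a negative imm32 with sign_propagate set emits
 * "48 c7 c0 imm32" (7 bytes).
 */
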
static void emit_mov_imm64(u8 **pprog, u32 dst_reg,
                           const u32 imm32_hi, const u32 imm32_lo)
{
        u8 *prog = *pprog;
        int cnt = 0;

        if (is_uimm32(((u64)imm32_hi << 32) | (u32)imm32_lo)) {
                /*
                 * For emitting a plain u32 where the sign bit must not
                 * be propagated, LLVM tends to load imm64 over mov32
                 * directly, so save a couple of bytes by just doing
                 * 'mov %eax, imm32' instead.
                 */
                emit_mov_imm32(&prog, false, dst_reg, imm32_lo);
        } else {
                /* movabsq %rax, imm64 */
                EMIT2(add_1mod(0x48, dst_reg), add_1reg(0xB8, dst_reg));
                EMIT(imm32_lo, 4);
                EMIT(imm32_hi, 4);
        }

        *pprog = prog;
}

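/*
 * The movabsq form is "48 b8 imm64" for rax: 10 bytes total, twice the
 * size of the 5-byte "mov %eax, imm32", which is why the u32 shortcut
 * above is worth taking.
 */
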
static void emit_mov_reg(u8 **pprog, bool is64, u32 dst_reg, u32 src_reg)
{
        u8 *prog = *pprog;
        int cnt = 0;

        if (is64) {
                /* mov dst, src */
                EMIT_mov(dst_reg, src_reg);
        } else {
                /* mov32 dst, src */
                if (is_ereg(dst_reg) || is_ereg(src_reg))
                        EMIT1(add_2mod(0x40, dst_reg, src_reg));
                EMIT2(0x89, add_2reg(0xC0, dst_reg, src_reg));
        }

        *pprog = prog;
}

/* LDX: dst_reg = *(u8*)(src_reg + off) */
static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
{
        u8 *prog = *pprog;
        int cnt = 0;

        switch (size) {
        case BPF_B:
                /* Emit 'movzx rax, byte ptr [rax + off]' */
                EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
                break;
        case BPF_H:
                /* Emit 'movzx rax, word ptr [rax + off]' */
                EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
                break;
        case BPF_W:
                /* Emit 'mov eax, dword ptr [rax+0x14]' */
                if (is_ereg(dst_reg) || is_ereg(src_reg))
                        EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
                else
                        EMIT1(0x8B);
                break;
        case BPF_DW:
                /* Emit 'mov rax, qword ptr [rax+0x14]' */
                EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
                break;
        }
        /*
         * If insn->off == 0 we could save one extra byte, but the
         * special case of x86 R13, which always needs an offset,
         * is not worth the hassle.
         */
        if (is_imm8(off))
                EMIT2(add_2reg(0x40, src_reg, dst_reg), off);
        else
                EMIT1_off32(add_2reg(0x80, src_reg, dst_reg), off);
        *pprog = prog;
}

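/*
 * Worked example: BPF_W, dst_reg = BPF_REG_6 (rbx), src_reg = BPF_REG_0
 * (rax), off = 0x14. Neither register is "extended", so no REX prefix is
 * emitted; the bytes are 8b 58 14, i.e. "mov ebx, dword ptr [rax + 0x14]"
 * (ModRM 0x58 = mod 01, reg = ebx, rm = rax, followed by the disp8).
 */
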
/* STX: *(u8*)(dst_reg + off) = src_reg */
static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
{
        u8 *prog = *pprog;
        int cnt = 0;

        switch (size) {
        case BPF_B:
                /* Emit 'mov byte ptr [rax + off], al' */
                if (is_ereg(dst_reg) || is_ereg(src_reg) ||
                    /* We have to add an extra byte for x86 SIL, DIL regs */
                    src_reg == BPF_REG_1 || src_reg == BPF_REG_2)
                        EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x88);
                else
                        EMIT1(0x88);
                break;
        case BPF_H:
                if (is_ereg(dst_reg) || is_ereg(src_reg))
                        EMIT3(0x66, add_2mod(0x40, dst_reg, src_reg), 0x89);
                else
                        EMIT2(0x66, 0x89);
                break;
        case BPF_W:
                if (is_ereg(dst_reg) || is_ereg(src_reg))
                        EMIT2(add_2mod(0x40, dst_reg, src_reg), 0x89);
                else
                        EMIT1(0x89);
                break;
        case BPF_DW:
                EMIT2(add_2mod(0x48, dst_reg, src_reg), 0x89);
                break;
        }
        if (is_imm8(off))
                EMIT2(add_2reg(0x40, dst_reg, src_reg), off);
        else
                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg), off);
        *pprog = prog;
}

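/*
 * The SIL/DIL special case above exists because, without a REX prefix,
 * register codes 4-7 in a byte-sized operand select the legacy high-byte
 * registers AH/CH/DH/BH. BPF_REG_1/BPF_REG_2 live in rdi/rsi, so storing
 * their low byte requires an (otherwise empty) 0x40 REX prefix to select
 * DIL/SIL instead.
 */
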
static int emit_call(u8 **pprog, void *func, void *ip)
{
        u8 *prog = *pprog;
        int cnt = 0;
        s64 offset;

        offset = func - (ip + X86_CALL_SIZE);
        if (!is_simm32(offset)) {
                pr_err("Target call %p is out of range\n", func);
                return -EINVAL;
        }
        EMIT1_off32(0xE8, offset);
        *pprog = prog;
        return 0;
}

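/*
 * "call rel32" (opcode 0xE8) encodes its target relative to the end of
 * the 5-byte call instruction, hence the ip + X86_CALL_SIZE above.
 */
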
static bool ex_handler_bpf(const struct exception_table_entry *x,
                           struct pt_regs *regs, int trapnr,
                           unsigned long error_code, unsigned long fault_addr)
{
        u32 reg = x->fixup >> 8;

        /* jump over faulting load and clear dest register */
        *(unsigned long *)((void *)regs + reg) = 0;
        regs->ip += x->fixup & 0xff;
        return true;
}

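/*
 * The fixup word is assembled in do_jit() for BPF_PROBE_MEM loads as
 * (insn_length | (pt_regs offset of dst_reg << 8)): the low byte tells
 * the handler how many bytes of faulting load to skip, and the upper
 * bits select which saved register to zero.
 */
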
static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
{
        struct bpf_insn *insn = bpf_prog->insnsi;
        int insn_cnt = bpf_prog->len;
        bool seen_exit = false;
        u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
        int i, cnt = 0, excnt = 0;
        int proglen = 0;
        u8 *prog = temp;

        emit_prologue(&prog, bpf_prog->aux->stack_depth,
                      bpf_prog_was_classic(bpf_prog));
        addrs[0] = prog - temp;

        for (i = 1; i <= insn_cnt; i++, insn++) {
                const s32 imm32 = insn->imm;
                u32 dst_reg = insn->dst_reg;
                u32 src_reg = insn->src_reg;
                u8 b2 = 0, b3 = 0;
                s64 jmp_offset;
                u8 jmp_cond;
                int ilen;
                u8 *func;

                switch (insn->code) {
                        /* ALU */
                case BPF_ALU | BPF_ADD | BPF_X:
                case BPF_ALU | BPF_SUB | BPF_X:
                case BPF_ALU | BPF_AND | BPF_X:
                case BPF_ALU | BPF_OR | BPF_X:
                case BPF_ALU | BPF_XOR | BPF_X:
                case BPF_ALU64 | BPF_ADD | BPF_X:
                case BPF_ALU64 | BPF_SUB | BPF_X:
                case BPF_ALU64 | BPF_AND | BPF_X:
                case BPF_ALU64 | BPF_OR | BPF_X:
                case BPF_ALU64 | BPF_XOR | BPF_X:
                        switch (BPF_OP(insn->code)) {
                        case BPF_ADD: b2 = 0x01; break;
                        case BPF_SUB: b2 = 0x29; break;
                        case BPF_AND: b2 = 0x21; break;
                        case BPF_OR: b2 = 0x09; break;
                        case BPF_XOR: b2 = 0x31; break;
                        }
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
                        EMIT2(b2, add_2reg(0xC0, dst_reg, src_reg));
                        break;

                case BPF_ALU64 | BPF_MOV | BPF_X:
                case BPF_ALU | BPF_MOV | BPF_X:
                        emit_mov_reg(&prog,
                                     BPF_CLASS(insn->code) == BPF_ALU64,
                                     dst_reg, src_reg);
                        break;

                        /* neg dst */
                case BPF_ALU | BPF_NEG:
                case BPF_ALU64 | BPF_NEG:
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));
                        EMIT2(0xF7, add_1reg(0xD8, dst_reg));
                        break;

                case BPF_ALU | BPF_ADD | BPF_K:
                case BPF_ALU | BPF_SUB | BPF_K:
                case BPF_ALU | BPF_AND | BPF_K:
                case BPF_ALU | BPF_OR | BPF_K:
                case BPF_ALU | BPF_XOR | BPF_K:
                case BPF_ALU64 | BPF_ADD | BPF_K:
                case BPF_ALU64 | BPF_SUB | BPF_K:
                case BPF_ALU64 | BPF_AND | BPF_K:
                case BPF_ALU64 | BPF_OR | BPF_K:
                case BPF_ALU64 | BPF_XOR | BPF_K:
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        /*
                         * b3 holds the 'normal' opcode; the b2 short form
                         * is only valid when dst is eax/rax.
                         */
                        switch (BPF_OP(insn->code)) {
                        case BPF_ADD:
                                b3 = 0xC0;
                                b2 = 0x05;
                                break;
                        case BPF_SUB:
                                b3 = 0xE8;
                                b2 = 0x2D;
                                break;
                        case BPF_AND:
                                b3 = 0xE0;
                                b2 = 0x25;
                                break;
                        case BPF_OR:
                                b3 = 0xC8;
                                b2 = 0x0D;
                                break;
                        case BPF_XOR:
                                b3 = 0xF0;
                                b2 = 0x35;
                                break;
                        }

                        if (is_imm8(imm32))
                                EMIT3(0x83, add_1reg(b3, dst_reg), imm32);
                        else if (is_axreg(dst_reg))
                                EMIT1_off32(b2, imm32);
                        else
                                EMIT2_off32(0x81, add_1reg(b3, dst_reg), imm32);
                        break;

                case BPF_ALU64 | BPF_MOV | BPF_K:
                case BPF_ALU | BPF_MOV | BPF_K:
                        emit_mov_imm32(&prog, BPF_CLASS(insn->code) == BPF_ALU64,
                                       dst_reg, imm32);
                        break;

                case BPF_LD | BPF_IMM | BPF_DW:
                        emit_mov_imm64(&prog, dst_reg, insn[1].imm, insn[0].imm);
                        insn++;
                        i++;
                        break;

                        /* dst %= src, dst /= src, dst %= imm32, dst /= imm32 */
                case BPF_ALU | BPF_MOD | BPF_X:
                case BPF_ALU | BPF_DIV | BPF_X:
                case BPF_ALU | BPF_MOD | BPF_K:
                case BPF_ALU | BPF_DIV | BPF_K:
                case BPF_ALU64 | BPF_MOD | BPF_X:
                case BPF_ALU64 | BPF_DIV | BPF_X:
                case BPF_ALU64 | BPF_MOD | BPF_K:
                case BPF_ALU64 | BPF_DIV | BPF_K:
                        EMIT1(0x50); /* push rax */
                        EMIT1(0x52); /* push rdx */

                        if (BPF_SRC(insn->code) == BPF_X)
                                /* mov r11, src_reg */
                                EMIT_mov(AUX_REG, src_reg);
                        else
                                /* mov r11, imm32 */
                                EMIT3_off32(0x49, 0xC7, 0xC3, imm32);

                        /* mov rax, dst_reg */
                        EMIT_mov(BPF_REG_0, dst_reg);

                        /*
                         * xor edx, edx
                         * equivalent to 'xor rdx, rdx', but one byte less
                         */
                        EMIT2(0x31, 0xd2);

                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                /* div r11 */
                                EMIT3(0x49, 0xF7, 0xF3);
                        else
                                /* div r11d */
                                EMIT3(0x41, 0xF7, 0xF3);

                        if (BPF_OP(insn->code) == BPF_MOD)
                                /* mov r11, rdx */
                                EMIT3(0x49, 0x89, 0xD3);
                        else
                                /* mov r11, rax */
                                EMIT3(0x49, 0x89, 0xC3);

                        EMIT1(0x5A); /* pop rdx */
                        EMIT1(0x58); /* pop rax */

                        /* mov dst_reg, r11 */
                        EMIT_mov(dst_reg, AUX_REG);
                        break;

                case BPF_ALU | BPF_MUL | BPF_K:
                case BPF_ALU | BPF_MUL | BPF_X:
                case BPF_ALU64 | BPF_MUL | BPF_K:
                case BPF_ALU64 | BPF_MUL | BPF_X:
                {
                        bool is64 = BPF_CLASS(insn->code) == BPF_ALU64;

                        if (dst_reg != BPF_REG_0)
                                EMIT1(0x50); /* push rax */
                        if (dst_reg != BPF_REG_3)
                                EMIT1(0x52); /* push rdx */

                        /* mov r11, dst_reg */
                        EMIT_mov(AUX_REG, dst_reg);

                        if (BPF_SRC(insn->code) == BPF_X)
                                emit_mov_reg(&prog, is64, BPF_REG_0, src_reg);
                        else
                                emit_mov_imm32(&prog, is64, BPF_REG_0, imm32);

                        if (is64)
                                EMIT1(add_1mod(0x48, AUX_REG));
                        else if (is_ereg(AUX_REG))
                                EMIT1(add_1mod(0x40, AUX_REG));
                        /* mul(q) r11 */
                        EMIT2(0xF7, add_1reg(0xE0, AUX_REG));

                        if (dst_reg != BPF_REG_3)
                                EMIT1(0x5A); /* pop rdx */
                        if (dst_reg != BPF_REG_0) {
                                /* mov dst_reg, rax */
                                EMIT_mov(dst_reg, BPF_REG_0);
                                EMIT1(0x58); /* pop rax */
                        }
                        break;
                }
                        /* Shifts */
                case BPF_ALU | BPF_LSH | BPF_K:
                case BPF_ALU | BPF_RSH | BPF_K:
                case BPF_ALU | BPF_ARSH | BPF_K:
                case BPF_ALU64 | BPF_LSH | BPF_K:
                case BPF_ALU64 | BPF_RSH | BPF_K:
                case BPF_ALU64 | BPF_ARSH | BPF_K:
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        switch (BPF_OP(insn->code)) {
                        case BPF_LSH: b3 = 0xE0; break;
                        case BPF_RSH: b3 = 0xE8; break;
                        case BPF_ARSH: b3 = 0xF8; break;
                        }

                        if (imm32 == 1)
                                EMIT2(0xD1, add_1reg(b3, dst_reg));
                        else
                                EMIT3(0xC1, add_1reg(b3, dst_reg), imm32);
                        break;

                case BPF_ALU | BPF_LSH | BPF_X:
                case BPF_ALU | BPF_RSH | BPF_X:
                case BPF_ALU | BPF_ARSH | BPF_X:
                case BPF_ALU64 | BPF_LSH | BPF_X:
                case BPF_ALU64 | BPF_RSH | BPF_X:
                case BPF_ALU64 | BPF_ARSH | BPF_X:

                        /* Check for the bad case when dst_reg == rcx */
                        if (dst_reg == BPF_REG_4) {
                                /* mov r11, dst_reg */
                                EMIT_mov(AUX_REG, dst_reg);
                                dst_reg = AUX_REG;
                        }

                        if (src_reg != BPF_REG_4) { /* common case */
                                EMIT1(0x51); /* push rcx */

                                /* mov rcx, src_reg */
                                EMIT_mov(BPF_REG_4, src_reg);
                        }

                        /* shl %rax, %cl | shr %rax, %cl | sar %rax, %cl */
                        if (BPF_CLASS(insn->code) == BPF_ALU64)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        switch (BPF_OP(insn->code)) {
                        case BPF_LSH: b3 = 0xE0; break;
                        case BPF_RSH: b3 = 0xE8; break;
                        case BPF_ARSH: b3 = 0xF8; break;
                        }
                        EMIT2(0xD3, add_1reg(b3, dst_reg));

                        if (src_reg != BPF_REG_4)
                                EMIT1(0x59); /* pop rcx */

                        if (insn->dst_reg == BPF_REG_4)
                                /* mov dst_reg, r11 */
                                EMIT_mov(insn->dst_reg, AUX_REG);
                        break;

                case BPF_ALU | BPF_END | BPF_FROM_BE:
                        switch (imm32) {
                        case 16:
                                /* Emit 'ror %ax, 8' to swap lower 2 bytes */
                                EMIT1(0x66);
                                if (is_ereg(dst_reg))
                                        EMIT1(0x41);
                                EMIT3(0xC1, add_1reg(0xC8, dst_reg), 8);

                                /* Emit 'movzwl eax, ax' */
                                if (is_ereg(dst_reg))
                                        EMIT3(0x45, 0x0F, 0xB7);
                                else
                                        EMIT2(0x0F, 0xB7);
                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 32:
                                /* Emit 'bswap eax' to swap lower 4 bytes */
                                if (is_ereg(dst_reg))
                                        EMIT2(0x41, 0x0F);
                                else
                                        EMIT1(0x0F);
                                EMIT1(add_1reg(0xC8, dst_reg));
                                break;
                        case 64:
                                /* Emit 'bswap rax' to swap 8 bytes */
                                EMIT3(add_1mod(0x48, dst_reg), 0x0F,
                                      add_1reg(0xC8, dst_reg));
                                break;
                        }
                        break;

                case BPF_ALU | BPF_END | BPF_FROM_LE:
                        switch (imm32) {
                        case 16:
                                /*
                                 * Emit 'movzwl eax, ax' to zero extend 16-bit
                                 * into 64 bit
                                 */
                                if (is_ereg(dst_reg))
                                        EMIT3(0x45, 0x0F, 0xB7);
                                else
                                        EMIT2(0x0F, 0xB7);
                                EMIT1(add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 32:
                                /* Emit 'mov eax, eax' to clear upper 32-bits */
                                if (is_ereg(dst_reg))
                                        EMIT1(0x45);
                                EMIT2(0x89, add_2reg(0xC0, dst_reg, dst_reg));
                                break;
                        case 64:
                                /* nop */
                                break;
                        }
                        break;

                        /* ST: *(u8*)(dst_reg + off) = imm */
                case BPF_ST | BPF_MEM | BPF_B:
                        if (is_ereg(dst_reg))
                                EMIT2(0x41, 0xC6);
                        else
                                EMIT1(0xC6);
                        goto st;
                case BPF_ST | BPF_MEM | BPF_H:
                        if (is_ereg(dst_reg))
                                EMIT3(0x66, 0x41, 0xC7);
                        else
                                EMIT2(0x66, 0xC7);
                        goto st;
                case BPF_ST | BPF_MEM | BPF_W:
                        if (is_ereg(dst_reg))
                                EMIT2(0x41, 0xC7);
                        else
                                EMIT1(0xC7);
                        goto st;
                case BPF_ST | BPF_MEM | BPF_DW:
                        EMIT2(add_1mod(0x48, dst_reg), 0xC7);

st:                     if (is_imm8(insn->off))
                                EMIT2(add_1reg(0x40, dst_reg), insn->off);
                        else
                                EMIT1_off32(add_1reg(0x80, dst_reg), insn->off);

                        EMIT(imm32, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
                        break;

                        /* STX: *(u8*)(dst_reg + off) = src_reg */
                case BPF_STX | BPF_MEM | BPF_B:
                case BPF_STX | BPF_MEM | BPF_H:
                case BPF_STX | BPF_MEM | BPF_W:
                case BPF_STX | BPF_MEM | BPF_DW:
                        emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
                        break;

                        /* LDX: dst_reg = *(u8*)(src_reg + off) */
                case BPF_LDX | BPF_MEM | BPF_B:
                case BPF_LDX | BPF_PROBE_MEM | BPF_B:
                case BPF_LDX | BPF_MEM | BPF_H:
                case BPF_LDX | BPF_PROBE_MEM | BPF_H:
                case BPF_LDX | BPF_MEM | BPF_W:
                case BPF_LDX | BPF_PROBE_MEM | BPF_W:
                case BPF_LDX | BPF_MEM | BPF_DW:
                case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
                        emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn->off);
                        if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
                                struct exception_table_entry *ex;
                                u8 *_insn = image + proglen;
                                s64 delta;

                                if (!bpf_prog->aux->extable)
                                        break;

                                if (excnt >= bpf_prog->aux->num_exentries) {
                                        pr_err("ex gen bug\n");
                                        return -EFAULT;
                                }
                                ex = &bpf_prog->aux->extable[excnt++];

                                delta = _insn - (u8 *)&ex->insn;
                                if (!is_simm32(delta)) {
                                        pr_err("extable->insn doesn't fit into 32-bit\n");
                                        return -EFAULT;
                                }
                                ex->insn = delta;

                                delta = (u8 *)ex_handler_bpf - (u8 *)&ex->handler;
                                if (!is_simm32(delta)) {
                                        pr_err("extable->handler doesn't fit into 32-bit\n");
                                        return -EFAULT;
                                }
                                ex->handler = delta;

                                if (dst_reg > BPF_REG_9) {
                                        pr_err("verifier error\n");
                                        return -EFAULT;
                                }
                                /*
                                 * Compute the size of the x86 insn and its target
                                 * dest x86 register. ex_handler_bpf() will use the
                                 * lower 8 bits to adjust pt_regs->ip to jump over
                                 * this x86 instruction and the upper bits to figure
                                 * out which pt_regs register to zero out. End result:
                                 * an x86 insn like "mov rbx, qword ptr [rax+0x14]"
                                 * of 4 bytes will be skipped and rbx will be
                                 * zero-initialized.
                                 */
                                ex->fixup = (prog - temp) | (reg2pt_regs[dst_reg] << 8);
                        }
                        break;

                        /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
                case BPF_STX | BPF_XADD | BPF_W:
                        /* Emit 'lock add dword ptr [rax + off], eax' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT3(0xF0, add_2mod(0x40, dst_reg, src_reg), 0x01);
                        else
                                EMIT2(0xF0, 0x01);
                        goto xadd;
                case BPF_STX | BPF_XADD | BPF_DW:
                        EMIT3(0xF0, add_2mod(0x48, dst_reg, src_reg), 0x01);
xadd:                   if (is_imm8(insn->off))
                                EMIT2(add_2reg(0x40, dst_reg, src_reg), insn->off);
                        else
                                EMIT1_off32(add_2reg(0x80, dst_reg, src_reg),
                                            insn->off);
                        break;

                        /* call */
                case BPF_JMP | BPF_CALL:
                        func = (u8 *) __bpf_call_base + imm32;
                        if (!imm32 || emit_call(&prog, func, image + addrs[i - 1]))
                                return -EINVAL;
                        break;

                case BPF_JMP | BPF_TAIL_CALL:
                        emit_bpf_tail_call(&prog);
                        break;

                        /* cond jump */
                case BPF_JMP | BPF_JEQ | BPF_X:
                case BPF_JMP | BPF_JNE | BPF_X:
                case BPF_JMP | BPF_JGT | BPF_X:
                case BPF_JMP | BPF_JLT | BPF_X:
                case BPF_JMP | BPF_JGE | BPF_X:
                case BPF_JMP | BPF_JLE | BPF_X:
                case BPF_JMP | BPF_JSGT | BPF_X:
                case BPF_JMP | BPF_JSLT | BPF_X:
                case BPF_JMP | BPF_JSGE | BPF_X:
                case BPF_JMP | BPF_JSLE | BPF_X:
                case BPF_JMP32 | BPF_JEQ | BPF_X:
                case BPF_JMP32 | BPF_JNE | BPF_X:
                case BPF_JMP32 | BPF_JGT | BPF_X:
                case BPF_JMP32 | BPF_JLT | BPF_X:
                case BPF_JMP32 | BPF_JGE | BPF_X:
                case BPF_JMP32 | BPF_JLE | BPF_X:
                case BPF_JMP32 | BPF_JSGT | BPF_X:
                case BPF_JMP32 | BPF_JSLT | BPF_X:
                case BPF_JMP32 | BPF_JSGE | BPF_X:
                case BPF_JMP32 | BPF_JSLE | BPF_X:
                        /* cmp dst_reg, src_reg */
                        if (BPF_CLASS(insn->code) == BPF_JMP)
                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
                        EMIT2(0x39, add_2reg(0xC0, dst_reg, src_reg));
                        goto emit_cond_jmp;

                case BPF_JMP | BPF_JSET | BPF_X:
                case BPF_JMP32 | BPF_JSET | BPF_X:
                        /* test dst_reg, src_reg */
                        if (BPF_CLASS(insn->code) == BPF_JMP)
                                EMIT1(add_2mod(0x48, dst_reg, src_reg));
                        else if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT1(add_2mod(0x40, dst_reg, src_reg));
                        EMIT2(0x85, add_2reg(0xC0, dst_reg, src_reg));
                        goto emit_cond_jmp;

                case BPF_JMP | BPF_JSET | BPF_K:
                case BPF_JMP32 | BPF_JSET | BPF_K:
                        /* test dst_reg, imm32 */
                        if (BPF_CLASS(insn->code) == BPF_JMP)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));
                        EMIT2_off32(0xF7, add_1reg(0xC0, dst_reg), imm32);
                        goto emit_cond_jmp;

                case BPF_JMP | BPF_JEQ | BPF_K:
                case BPF_JMP | BPF_JNE | BPF_K:
                case BPF_JMP | BPF_JGT | BPF_K:
                case BPF_JMP | BPF_JLT | BPF_K:
                case BPF_JMP | BPF_JGE | BPF_K:
                case BPF_JMP | BPF_JLE | BPF_K:
                case BPF_JMP | BPF_JSGT | BPF_K:
                case BPF_JMP | BPF_JSLT | BPF_K:
                case BPF_JMP | BPF_JSGE | BPF_K:
                case BPF_JMP | BPF_JSLE | BPF_K:
                case BPF_JMP32 | BPF_JEQ | BPF_K:
                case BPF_JMP32 | BPF_JNE | BPF_K:
                case BPF_JMP32 | BPF_JGT | BPF_K:
                case BPF_JMP32 | BPF_JLT | BPF_K:
                case BPF_JMP32 | BPF_JGE | BPF_K:
                case BPF_JMP32 | BPF_JLE | BPF_K:
                case BPF_JMP32 | BPF_JSGT | BPF_K:
                case BPF_JMP32 | BPF_JSLT | BPF_K:
                case BPF_JMP32 | BPF_JSGE | BPF_K:
                case BPF_JMP32 | BPF_JSLE | BPF_K:
                        /* test dst_reg, dst_reg to save one extra byte */
                        if (imm32 == 0) {
                                if (BPF_CLASS(insn->code) == BPF_JMP)
                                        EMIT1(add_2mod(0x48, dst_reg, dst_reg));
                                else if (is_ereg(dst_reg))
                                        EMIT1(add_2mod(0x40, dst_reg, dst_reg));
                                EMIT2(0x85, add_2reg(0xC0, dst_reg, dst_reg));
                                goto emit_cond_jmp;
                        }

                        /* cmp dst_reg, imm8/32 */
                        if (BPF_CLASS(insn->code) == BPF_JMP)
                                EMIT1(add_1mod(0x48, dst_reg));
                        else if (is_ereg(dst_reg))
                                EMIT1(add_1mod(0x40, dst_reg));

                        if (is_imm8(imm32))
                                EMIT3(0x83, add_1reg(0xF8, dst_reg), imm32);
                        else
                                EMIT2_off32(0x81, add_1reg(0xF8, dst_reg), imm32);

emit_cond_jmp:          /* Convert BPF opcode to x86 */
                        switch (BPF_OP(insn->code)) {
                        case BPF_JEQ:
                                jmp_cond = X86_JE;
                                break;
                        case BPF_JSET:
                        case BPF_JNE:
                                jmp_cond = X86_JNE;
                                break;
                        case BPF_JGT:
                                /* GT is unsigned '>', JA in x86 */
                                jmp_cond = X86_JA;
                                break;
                        case BPF_JLT:
                                /* LT is unsigned '<', JB in x86 */
                                jmp_cond = X86_JB;
                                break;
                        case BPF_JGE:
                                /* GE is unsigned '>=', JAE in x86 */
                                jmp_cond = X86_JAE;
                                break;
                        case BPF_JLE:
                                /* LE is unsigned '<=', JBE in x86 */
                                jmp_cond = X86_JBE;
                                break;
                        case BPF_JSGT:
                                /* Signed '>', GT in x86 */
                                jmp_cond = X86_JG;
                                break;
                        case BPF_JSLT:
                                /* Signed '<', LT in x86 */
                                jmp_cond = X86_JL;
                                break;
                        case BPF_JSGE:
                                /* Signed '>=', GE in x86 */
                                jmp_cond = X86_JGE;
                                break;
                        case BPF_JSLE:
                                /* Signed '<=', LE in x86 */
                                jmp_cond = X86_JLE;
                                break;
                        default: /* to silence GCC warning */
                                return -EFAULT;
                        }
                        jmp_offset = addrs[i + insn->off] - addrs[i];
                        if (is_imm8(jmp_offset)) {
                                EMIT2(jmp_cond, jmp_offset);
                        } else if (is_simm32(jmp_offset)) {
                                EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
                        } else {
                                pr_err("cond_jmp gen bug %llx\n", jmp_offset);
                                return -EFAULT;
                        }

                        break;

                case BPF_JMP | BPF_JA:
                        if (insn->off == -1)
                                /* -1 jmp instructions will always jump
                                 * backwards two bytes. Explicitly handling
                                 * this case avoids wasting too many passes
                                 * when there are long sequences of replaced
                                 * dead code.
                                 */
                                jmp_offset = -2;
                        else
                                jmp_offset = addrs[i + insn->off] - addrs[i];

                        if (!jmp_offset)
                                /* Optimize out nop jumps */
                                break;
emit_jmp:
                        if (is_imm8(jmp_offset)) {
                                EMIT2(0xEB, jmp_offset);
                        } else if (is_simm32(jmp_offset)) {
                                EMIT1_off32(0xE9, jmp_offset);
                        } else {
                                pr_err("jmp gen bug %llx\n", jmp_offset);
                                return -EFAULT;
                        }
                        break;

                case BPF_JMP | BPF_EXIT:
                        if (seen_exit) {
                                jmp_offset = ctx->cleanup_addr - addrs[i];
                                goto emit_jmp;
                        }
                        seen_exit = true;
                        /* Update cleanup_addr */
                        ctx->cleanup_addr = proglen;
                        if (!bpf_prog_was_classic(bpf_prog))
                                EMIT1(0x5B); /* get rid of tail_call_cnt */
                        EMIT2(0x41, 0x5F);   /* pop r15 */
                        EMIT2(0x41, 0x5E);   /* pop r14 */
                        EMIT2(0x41, 0x5D);   /* pop r13 */
                        EMIT1(0x5B);         /* pop rbx */
                        EMIT1(0xC9);         /* leave */
                        EMIT1(0xC3);         /* ret */
                        break;

                default:
                        /*
                         * By design the x86-64 JIT should support all BPF
                         * instructions. This error will be seen if a new
                         * instruction was added to the interpreter, but not
                         * to the JIT, or if there is junk in bpf_prog.
                         */
                        pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
                        return -EINVAL;
                }

                ilen = prog - temp;
                if (ilen > BPF_MAX_INSN_SIZE) {
                        pr_err("bpf_jit: fatal insn size error\n");
                        return -EFAULT;
                }

                if (image) {
                        if (unlikely(proglen + ilen > oldproglen)) {
                                pr_err("bpf_jit: fatal error\n");
                                return -EFAULT;
                        }
                        memcpy(image + proglen, temp, ilen);
                }
                proglen += ilen;
                addrs[i] = proglen;
                prog = temp;
        }

        if (image && excnt != bpf_prog->aux->num_exentries) {
                pr_err("extable is not populated\n");
                return -EFAULT;
        }
        return proglen;
}

struct x64_jit_data {
        struct bpf_binary_header *header;
        int *addrs;
        u8 *image;
        int proglen;
        struct jit_context ctx;
};

struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
{
        struct bpf_binary_header *header = NULL;
        struct bpf_prog *tmp, *orig_prog = prog;
        struct x64_jit_data *jit_data;
        int proglen, oldproglen = 0;
        struct jit_context ctx = {};
        bool tmp_blinded = false;
        bool extra_pass = false;
        u8 *image = NULL;
        int *addrs;
        int pass;
        int i;

        if (!prog->jit_requested)
                return orig_prog;

        tmp = bpf_jit_blind_constants(prog);
        /*
         * If blinding was requested and we failed during blinding,
         * we must fall back to the interpreter.
         */
        if (IS_ERR(tmp))
                return orig_prog;
        if (tmp != prog) {
                tmp_blinded = true;
                prog = tmp;
        }

        jit_data = prog->aux->jit_data;
        if (!jit_data) {
                jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
                if (!jit_data) {
                        prog = orig_prog;
                        goto out;
                }
                prog->aux->jit_data = jit_data;
        }
        addrs = jit_data->addrs;
        if (addrs) {
                ctx = jit_data->ctx;
                oldproglen = jit_data->proglen;
                image = jit_data->image;
                header = jit_data->header;
                extra_pass = true;
                goto skip_init_addrs;
        }
        addrs = kmalloc_array(prog->len + 1, sizeof(*addrs), GFP_KERNEL);
        if (!addrs) {
                prog = orig_prog;
                goto out_addrs;
        }

        /*
         * Before the first pass, make a rough estimate of addrs[]:
         * each BPF instruction is translated to less than 64 bytes.
         */
        for (proglen = 0, i = 0; i <= prog->len; i++) {
                proglen += 64;
                addrs[i] = proglen;
        }
        ctx.cleanup_addr = proglen;
skip_init_addrs:

        /*
         * The JITed image shrinks with every pass and the loop iterates
         * until the image stops shrinking. Very large BPF programs
         * may converge only on the last pass. In such a case do one more
         * pass to emit the final image.
         */
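        /*
         * Example of why passes shrink the image: with the initial
         * 64-bytes-per-insn estimate, a forward jump may only fit a rel32
         * encoding (6 bytes for a conditional jump); once the estimates
         * tighten, the same jump may fit rel8 (2 bytes), which in turn
         * can shorten other jump offsets on the next pass.
         */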
        for (pass = 0; pass < 20 || image; pass++) {
                proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
                if (proglen <= 0) {
out_image:
                        image = NULL;
                        if (header)
                                bpf_jit_binary_free(header);
                        prog = orig_prog;
                        goto out_addrs;
                }
                if (image) {
                        if (proglen != oldproglen) {
                                pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
                                       proglen, oldproglen);
                                goto out_image;
                        }
                        break;
                }
                if (proglen == oldproglen) {
                        /*
                         * The number of entries in extable is the number of BPF_LDX
                         * insns that access kernel memory via "pointer to BTF type".
                         * The verifier changed their opcode from LDX|MEM|size
                         * to LDX|PROBE_MEM|size to make JITing easier.
                         */
                        u32 align = __alignof__(struct exception_table_entry);
                        u32 extable_size = prog->aux->num_exentries *
                                sizeof(struct exception_table_entry);

                        /* allocate module memory for x86 insns and extable */
                        header = bpf_jit_binary_alloc(roundup(proglen, align) + extable_size,
                                                      &image, align, jit_fill_hole);
                        if (!header) {
                                prog = orig_prog;
                                goto out_addrs;
                        }
                        prog->aux->extable = (void *) image + roundup(proglen, align);
                }
                oldproglen = proglen;
                cond_resched();
        }

        if (bpf_jit_enable > 1)
                bpf_jit_dump(prog->len, proglen, pass + 1, image);

        if (image) {
                if (!prog->is_func || extra_pass) {
                        bpf_jit_binary_lock_ro(header);
                } else {
                        jit_data->addrs = addrs;
                        jit_data->ctx = ctx;
                        jit_data->proglen = proglen;
                        jit_data->image = image;
                        jit_data->header = header;
                }
                prog->bpf_func = (void *)image;
                prog->jited = 1;
                prog->jited_len = proglen;
        } else {
                prog = orig_prog;
        }

        if (!image || !prog->is_func || extra_pass) {
                if (image)
                        bpf_prog_fill_jited_linfo(prog, addrs + 1);
out_addrs:
                kfree(addrs);
                kfree(jit_data);
                prog->aux->jit_data = NULL;
        }
out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?
                                           tmp : orig_prog);
        return prog;
}