[linux-2.6-block.git] arch/x86/net/bpf_jit_comp.c
1 /* bpf_jit_comp.c : BPF JIT compiler
2  *
3  * Copyright (C) 2011-2013 Eric Dumazet (eric.dumazet@gmail.com)
4  * Internal BPF Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; version 2
9  * of the License.
10  */
11 #include <linux/moduleloader.h>
12 #include <asm/cacheflush.h>
13 #include <linux/netdevice.h>
14 #include <linux/filter.h>
15 #include <linux/if_vlan.h>
16 #include <linux/random.h>
17
18 int bpf_jit_enable __read_mostly;
19
20 /*
21  * assembly code in arch/x86/net/bpf_jit.S
22  */
23 extern u8 sk_load_word[], sk_load_half[], sk_load_byte[];
24 extern u8 sk_load_word_positive_offset[], sk_load_half_positive_offset[];
25 extern u8 sk_load_byte_positive_offset[];
26 extern u8 sk_load_word_negative_offset[], sk_load_half_negative_offset[];
27 extern u8 sk_load_byte_negative_offset[];
28
29 static inline u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
30 {
31         if (len == 1)
32                 *ptr = bytes;
33         else if (len == 2)
34                 *(u16 *)ptr = bytes;
35         else {
36                 *(u32 *)ptr = bytes;
37                 barrier();
38         }
39         return ptr + len;
40 }
41
42 #define EMIT(bytes, len)        do { prog = emit_code(prog, bytes, len); } while (0)
43
44 #define EMIT1(b1)               EMIT(b1, 1)
45 #define EMIT2(b1, b2)           EMIT((b1) + ((b2) << 8), 2)
46 #define EMIT3(b1, b2, b3)       EMIT((b1) + ((b2) << 8) + ((b3) << 16), 3)
47 #define EMIT4(b1, b2, b3, b4)   EMIT((b1) + ((b2) << 8) + ((b3) << 16) + ((b4) << 24), 4)
48 #define EMIT1_off32(b1, off) \
49         do {EMIT1(b1); EMIT(off, 4); } while (0)
50 #define EMIT2_off32(b1, b2, off) \
51         do {EMIT2(b1, b2); EMIT(off, 4); } while (0)
52 #define EMIT3_off32(b1, b2, b3, off) \
53         do {EMIT3(b1, b2, b3); EMIT(off, 4); } while (0)
54 #define EMIT4_off32(b1, b2, b3, b4, off) \
55         do {EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0)
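/* A worked example of the packing above: EMIT3(0x48, 0x89, 0xE5) passes
 * bytes = 0xE58948 to emit_code(), which stores it little-endian, so the
 * bytes land in memory as 48 89 E5 ("mov rbp,rsp").  For len == 3 the
 * store is still 4 bytes wide, but the pointer only advances by 3, so the
 * spare byte is overwritten by the next EMIT.
 */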
56
57 static inline bool is_imm8(int value)
58 {
59         return value <= 127 && value >= -128;
60 }
61
62 static inline bool is_simm32(s64 value)
63 {
64         return value == (s64) (s32) value;
65 }
66
67 /* mov A, X */
68 #define EMIT_mov(A, X) \
69         do {if (A != X) \
70                 EMIT3(add_2mod(0x48, A, X), 0x89, add_2reg(0xC0, A, X)); \
71         } while (0)
72
73 static int bpf_size_to_x86_bytes(int bpf_size)
74 {
75         if (bpf_size == BPF_W)
76                 return 4;
77         else if (bpf_size == BPF_H)
78                 return 2;
79         else if (bpf_size == BPF_B)
80                 return 1;
81         else if (bpf_size == BPF_DW)
82                 return 4; /* imm32 */
83         else
84                 return 0;
85 }
86
87 /* list of x86 conditional jump opcodes (. + s8)
88  * Add 0x10 (and an extra 0x0f) to generate far jumps (. + s32)
89  */
90 #define X86_JB  0x72
91 #define X86_JAE 0x73
92 #define X86_JE  0x74
93 #define X86_JNE 0x75
94 #define X86_JBE 0x76
95 #define X86_JA  0x77
96 #define X86_JGE 0x7D
97 #define X86_JG  0x7F
98
99 static inline void bpf_flush_icache(void *start, void *end)
100 {
101         mm_segment_t old_fs = get_fs();
102
103         set_fs(KERNEL_DS);
104         smp_wmb();
105         flush_icache_range((unsigned long)start, (unsigned long)end);
106         set_fs(old_fs);
107 }
108
109 #define CHOOSE_LOAD_FUNC(K, func) \
110         ((int)K < 0 ? ((int)K >= SKF_LL_OFF ? func##_negative_offset : func) : func##_positive_offset)
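/* K >= 0 is a plain packet offset, so the *_positive_offset helper can be
 * called directly.  A negative K in the SKF_*_OFF range selects the
 * *_negative_offset helper; any other value goes through the generic entry
 * point, which tests the sign of the final offset at run time (see bpf_jit.S).
 */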
111
112 struct bpf_binary_header {
113         unsigned int    pages;
114         /* Note: for security reasons, the BPF code is preceded by a randomly
115          * sized block of int3 instructions
116          */
117         u8              image[];
118 };
119
120 static struct bpf_binary_header *bpf_alloc_binary(unsigned int proglen,
121                                                   u8 **image_ptr)
122 {
123         unsigned int sz, hole;
124         struct bpf_binary_header *header;
125
126         /* Most BPF filters are really small,
127          * but if one of them fills a whole page, still allow at least
128          * 128 extra bytes for the random int3 padding
129          */
130         sz = round_up(proglen + sizeof(*header) + 128, PAGE_SIZE);
131         header = module_alloc(sz);
132         if (!header)
133                 return NULL;
134
135         memset(header, 0xcc, sz); /* fill whole space with int3 instructions */
136
137         header->pages = sz / PAGE_SIZE;
138         hole = min(sz - (proglen + sizeof(*header)), PAGE_SIZE - sizeof(*header));
139
140         /* insert a random number of int3 instructions before BPF code */
141         *image_ptr = &header->image[prandom_u32() % hole];
142         return header;
143 }
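/* Note: the int3 (0xcc) fill means a stray jump into the padding traps
 * immediately, and the randomized start offset makes the image location
 * harder to predict (a defense against JIT spraying).
 */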
144
145 /* pick a register outside of BPF range for JIT internal work */
146 #define AUX_REG (MAX_BPF_REG + 1)
147
148 /* the following table maps BPF registers to x64 registers.
149  * x64 register r12 is unused, since if used as base address register
150  * in load/store instructions, it always needs an extra byte of encoding
151  */
152 static const int reg2hex[] = {
153         [BPF_REG_0] = 0,  /* rax */
154         [BPF_REG_1] = 7,  /* rdi */
155         [BPF_REG_2] = 6,  /* rsi */
156         [BPF_REG_3] = 2,  /* rdx */
157         [BPF_REG_4] = 1,  /* rcx */
158         [BPF_REG_5] = 0,  /* r8 */
159         [BPF_REG_6] = 3,  /* rbx callee saved */
160         [BPF_REG_7] = 5,  /* r13 callee saved */
161         [BPF_REG_8] = 6,  /* r14 callee saved */
162         [BPF_REG_9] = 7,  /* r15 callee saved */
163         [BPF_REG_FP] = 5, /* rbp readonly */
164         [AUX_REG] = 3,    /* r11 temp register */
165 };
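/* Example: BPF_REG_FP and BPF_REG_7 both encode as 5; is_ereg(BPF_REG_7)
 * below adds the REX extension bit, which is what turns rbp's encoding
 * into r13.
 */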
166
167 /* is_ereg() == true if BPF register 'reg' maps to x64 r8..r15
168  * which need an extra byte of encoding.
169  * rax,rcx,...,rbp have a simpler encoding
170  */
171 static inline bool is_ereg(u32 reg)
172 {
173         if (reg == BPF_REG_5 || reg == AUX_REG ||
174             (reg >= BPF_REG_7 && reg <= BPF_REG_9))
175                 return true;
176         else
177                 return false;
178 }
179
180 /* add modifiers if 'reg' maps to x64 registers r8..r15 */
181 static inline u8 add_1mod(u8 byte, u32 reg)
182 {
183         if (is_ereg(reg))
184                 byte |= 1;
185         return byte;
186 }
187
188 static inline u8 add_2mod(u8 byte, u32 r1, u32 r2)
189 {
190         if (is_ereg(r1))
191                 byte |= 1;
192         if (is_ereg(r2))
193                 byte |= 4;
194         return byte;
195 }
196
197 /* encode dest register 'a_reg' into x64 opcode 'byte' */
198 static inline u8 add_1reg(u8 byte, u32 a_reg)
199 {
200         return byte + reg2hex[a_reg];
201 }
202
203 /* encode dest 'a_reg' and src 'x_reg' registers into x64 opcode 'byte' */
204 static inline u8 add_2reg(u8 byte, u32 a_reg, u32 x_reg)
205 {
206         return byte + reg2hex[a_reg] + (reg2hex[x_reg] << 3);
207 }
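/* Putting the helpers together: EMIT_mov(BPF_REG_0, BPF_REG_1) expands to
 * EMIT3(add_2mod(0x48, R0, R1), 0x89, add_2reg(0xC0, R0, R1)), i.e. the
 * bytes 48 89 F8, which is "mov rax, rdi" (ModRM 0xF8 = mod 11, reg = rdi,
 * r/m = rax).
 */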
208
209 struct jit_context {
210         unsigned int cleanup_addr; /* epilogue code offset */
211         bool seen_ld_abs;
212 };
213
214 static int do_jit(struct sk_filter *bpf_prog, int *addrs, u8 *image,
215                   int oldproglen, struct jit_context *ctx)
216 {
217         struct sock_filter_int *insn = bpf_prog->insnsi;
218         int insn_cnt = bpf_prog->len;
219         u8 temp[64];
220         int i;
221         int proglen = 0;
222         u8 *prog = temp;
223         int stacksize = MAX_BPF_STACK +
224                 32 /* space for rbx, r13, r14, r15 */ +
225                 8 /* space for skb_copy_bits() buffer */;
226
227         EMIT1(0x55); /* push rbp */
228         EMIT3(0x48, 0x89, 0xE5); /* mov rbp,rsp */
229
230         /* sub rsp, stacksize */
231         EMIT3_off32(0x48, 0x81, 0xEC, stacksize);
232
233         /* all classic BPF filters use R6(rbx), so save it */
234
235         /* mov qword ptr [rbp-X],rbx */
236         EMIT3_off32(0x48, 0x89, 0x9D, -stacksize);
237
238         /* sk_convert_filter() maps classic BPF register X to R7 and uses R8
239          * as temporary, so all tcpdump filters need to spill/fill R7(r13) and
240          * R8(r14). R9(r15) spill could be made conditional, but there is only
241          * one 'bpf_error' return path out of the helper functions inside bpf_jit.S.
242          * The overhead of the extra spill is negligible for any filter other
243          * than synthetic ones. Therefore not worth adding complexity.
244          */
245
246         /* mov qword ptr [rbp-X],r13 */
247         EMIT3_off32(0x4C, 0x89, 0xAD, -stacksize + 8);
248         /* mov qword ptr [rbp-X],r14 */
249         EMIT3_off32(0x4C, 0x89, 0xB5, -stacksize + 16);
250         /* mov qword ptr [rbp-X],r15 */
251         EMIT3_off32(0x4C, 0x89, 0xBD, -stacksize + 24);
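            /* Frame picture at this point:
             *   [rbp - MAX_BPF_STACK, rbp)               the BPF program's own stack,
             *                                            addressed via BPF_REG_FP (rbp)
             *   [rbp - stacksize, rbp - stacksize + 32)  spilled rbx, r13, r14, r15
             * plus 8 spare bytes reserved as skb_copy_bits() scratch for the
             * bpf_jit.S helpers (see the stacksize computation above).
             */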
252
253         /* clear A and X registers */
254         EMIT2(0x31, 0xc0); /* xor eax, eax */
255         EMIT3(0x4D, 0x31, 0xED); /* xor r13, r13 */
256
257         if (ctx->seen_ld_abs) {
258                 /* r9d : skb->len - skb->data_len (headlen)
259                  * r10 : skb->data
260                  */
261                 if (is_imm8(offsetof(struct sk_buff, len)))
262                         /* mov %r9d, off8(%rdi) */
263                         EMIT4(0x44, 0x8b, 0x4f,
264                               offsetof(struct sk_buff, len));
265                 else
266                         /* mov %r9d, off32(%rdi) */
267                         EMIT3_off32(0x44, 0x8b, 0x8f,
268                                     offsetof(struct sk_buff, len));
269
270                 if (is_imm8(offsetof(struct sk_buff, data_len)))
271                         /* sub %r9d, off8(%rdi) */
272                         EMIT4(0x44, 0x2b, 0x4f,
273                               offsetof(struct sk_buff, data_len));
274                 else
275                         EMIT3_off32(0x44, 0x2b, 0x8f,
276                                     offsetof(struct sk_buff, data_len));
277
278                 if (is_imm8(offsetof(struct sk_buff, data)))
279                         /* mov %r10, off8(%rdi) */
280                         EMIT4(0x4c, 0x8b, 0x57,
281                               offsetof(struct sk_buff, data));
282                 else
283                         /* mov %r10, off32(%rdi) */
284                         EMIT3_off32(0x4c, 0x8b, 0x97,
285                                     offsetof(struct sk_buff, data));
286         }
287
288         for (i = 0; i < insn_cnt; i++, insn++) {
289                 const s32 K = insn->imm;
290                 u32 a_reg = insn->a_reg;
291                 u32 x_reg = insn->x_reg;
292                 u8 b1 = 0, b2 = 0, b3 = 0;
293                 s64 jmp_offset;
294                 u8 jmp_cond;
295                 int ilen;
296                 u8 *func;
297
298                 switch (insn->code) {
299                         /* ALU */
300                 case BPF_ALU | BPF_ADD | BPF_X:
301                 case BPF_ALU | BPF_SUB | BPF_X:
302                 case BPF_ALU | BPF_AND | BPF_X:
303                 case BPF_ALU | BPF_OR | BPF_X:
304                 case BPF_ALU | BPF_XOR | BPF_X:
305                 case BPF_ALU64 | BPF_ADD | BPF_X:
306                 case BPF_ALU64 | BPF_SUB | BPF_X:
307                 case BPF_ALU64 | BPF_AND | BPF_X:
308                 case BPF_ALU64 | BPF_OR | BPF_X:
309                 case BPF_ALU64 | BPF_XOR | BPF_X:
310                         switch (BPF_OP(insn->code)) {
311                         case BPF_ADD: b2 = 0x01; break;
312                         case BPF_SUB: b2 = 0x29; break;
313                         case BPF_AND: b2 = 0x21; break;
314                         case BPF_OR: b2 = 0x09; break;
315                         case BPF_XOR: b2 = 0x31; break;
316                         }
317                         if (BPF_CLASS(insn->code) == BPF_ALU64)
318                                 EMIT1(add_2mod(0x48, a_reg, x_reg));
319                         else if (is_ereg(a_reg) || is_ereg(x_reg))
320                                 EMIT1(add_2mod(0x40, a_reg, x_reg));
321                         EMIT2(b2, add_2reg(0xC0, a_reg, x_reg));
322                         break;
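                            /* e.g. BPF_ALU64 | BPF_ADD | BPF_X with A = R1, X = R2
                             * emits 48 01 F7, i.e. "add rdi, rsi"
                             */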
323
324                         /* mov A, X */
325                 case BPF_ALU64 | BPF_MOV | BPF_X:
326                         EMIT_mov(a_reg, x_reg);
327                         break;
328
329                         /* mov32 A, X */
330                 case BPF_ALU | BPF_MOV | BPF_X:
331                         if (is_ereg(a_reg) || is_ereg(x_reg))
332                                 EMIT1(add_2mod(0x40, a_reg, x_reg));
333                         EMIT2(0x89, add_2reg(0xC0, a_reg, x_reg));
334                         break;
335
336                         /* neg A */
337                 case BPF_ALU | BPF_NEG:
338                 case BPF_ALU64 | BPF_NEG:
339                         if (BPF_CLASS(insn->code) == BPF_ALU64)
340                                 EMIT1(add_1mod(0x48, a_reg));
341                         else if (is_ereg(a_reg))
342                                 EMIT1(add_1mod(0x40, a_reg));
343                         EMIT2(0xF7, add_1reg(0xD8, a_reg));
344                         break;
345
346                 case BPF_ALU | BPF_ADD | BPF_K:
347                 case BPF_ALU | BPF_SUB | BPF_K:
348                 case BPF_ALU | BPF_AND | BPF_K:
349                 case BPF_ALU | BPF_OR | BPF_K:
350                 case BPF_ALU | BPF_XOR | BPF_K:
351                 case BPF_ALU64 | BPF_ADD | BPF_K:
352                 case BPF_ALU64 | BPF_SUB | BPF_K:
353                 case BPF_ALU64 | BPF_AND | BPF_K:
354                 case BPF_ALU64 | BPF_OR | BPF_K:
355                 case BPF_ALU64 | BPF_XOR | BPF_K:
356                         if (BPF_CLASS(insn->code) == BPF_ALU64)
357                                 EMIT1(add_1mod(0x48, a_reg));
358                         else if (is_ereg(a_reg))
359                                 EMIT1(add_1mod(0x40, a_reg));
360
361                         switch (BPF_OP(insn->code)) {
362                         case BPF_ADD: b3 = 0xC0; break;
363                         case BPF_SUB: b3 = 0xE8; break;
364                         case BPF_AND: b3 = 0xE0; break;
365                         case BPF_OR: b3 = 0xC8; break;
366                         case BPF_XOR: b3 = 0xF0; break;
367                         }
368
369                         if (is_imm8(K))
370                                 EMIT3(0x83, add_1reg(b3, a_reg), K);
371                         else
372                                 EMIT2_off32(0x81, add_1reg(b3, a_reg), K);
373                         break;
374
375                 case BPF_ALU64 | BPF_MOV | BPF_K:
376                         /* optimization: if imm32 is positive,
377                          * use 'mov eax, imm32' (which zero-extends imm32)
378                          * to save 2 bytes
379                          */
380                         if (K < 0) {
381                                 /* 'mov rax, imm32' sign extends imm32 */
382                                 b1 = add_1mod(0x48, a_reg);
383                                 b2 = 0xC7;
384                                 b3 = 0xC0;
385                                 EMIT3_off32(b1, b2, add_1reg(b3, a_reg), K);
386                                 break;
387                         }
388
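                        /* K >= 0: fall through and use the shorter 32-bit mov,
                         * which zero-extends into the upper 32 bits anyway
                         */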
389                 case BPF_ALU | BPF_MOV | BPF_K:
390                         /* mov %eax, imm32 */
391                         if (is_ereg(a_reg))
392                                 EMIT1(add_1mod(0x40, a_reg));
393                         EMIT1_off32(add_1reg(0xB8, a_reg), K);
394                         break;
395
396                         /* A %= X, A /= X, A %= K, A /= K */
397                 case BPF_ALU | BPF_MOD | BPF_X:
398                 case BPF_ALU | BPF_DIV | BPF_X:
399                 case BPF_ALU | BPF_MOD | BPF_K:
400                 case BPF_ALU | BPF_DIV | BPF_K:
401                 case BPF_ALU64 | BPF_MOD | BPF_X:
402                 case BPF_ALU64 | BPF_DIV | BPF_X:
403                 case BPF_ALU64 | BPF_MOD | BPF_K:
404                 case BPF_ALU64 | BPF_DIV | BPF_K:
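                            /* x86 div implicitly uses rdx:rax (quotient -> rax,
                             * remainder -> rdx) and clobbers both, so save rax/rdx
                             * around the operation and do the arithmetic via r11
                             */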
405                         EMIT1(0x50); /* push rax */
406                         EMIT1(0x52); /* push rdx */
407
408                         if (BPF_SRC(insn->code) == BPF_X)
409                                 /* mov r11, X */
410                                 EMIT_mov(AUX_REG, x_reg);
411                         else
412                                 /* mov r11, K */
413                                 EMIT3_off32(0x49, 0xC7, 0xC3, K);
414
415                         /* mov rax, A */
416                         EMIT_mov(BPF_REG_0, a_reg);
417
418                         /* xor edx, edx
419                          * equivalent to 'xor rdx, rdx', but one byte less
420                          */
421                         EMIT2(0x31, 0xd2);
422
423                         if (BPF_SRC(insn->code) == BPF_X) {
424                                 /* if (X == 0) return 0 */
425
426                                 /* cmp r11, 0 */
427                                 EMIT4(0x49, 0x83, 0xFB, 0x00);
428
429                                 /* jne .+9 (skip over pop, pop, xor and jmp) */
430                                 EMIT2(X86_JNE, 1 + 1 + 2 + 5);
431                                 EMIT1(0x5A); /* pop rdx */
432                                 EMIT1(0x58); /* pop rax */
433                                 EMIT2(0x31, 0xc0); /* xor eax, eax */
434
435                                 /* jmp cleanup_addr
436                                  * addrs[i] - 11, because there are 11 bytes
437                                  * after this insn: div, mov, pop, pop, mov
438                                  */
439                                 jmp_offset = ctx->cleanup_addr - (addrs[i] - 11);
440                                 EMIT1_off32(0xE9, jmp_offset);
441                         }
442
443                         if (BPF_CLASS(insn->code) == BPF_ALU64)
444                                 /* div r11 */
445                                 EMIT3(0x49, 0xF7, 0xF3);
446                         else
447                                 /* div r11d */
448                                 EMIT3(0x41, 0xF7, 0xF3);
449
450                         if (BPF_OP(insn->code) == BPF_MOD)
451                                 /* mov r11, rdx */
452                                 EMIT3(0x49, 0x89, 0xD3);
453                         else
454                                 /* mov r11, rax */
455                                 EMIT3(0x49, 0x89, 0xC3);
456
457                         EMIT1(0x5A); /* pop rdx */
458                         EMIT1(0x58); /* pop rax */
459
460                         /* mov A, r11 */
461                         EMIT_mov(a_reg, AUX_REG);
462                         break;
463
464                 case BPF_ALU | BPF_MUL | BPF_K:
465                 case BPF_ALU | BPF_MUL | BPF_X:
466                 case BPF_ALU64 | BPF_MUL | BPF_K:
467                 case BPF_ALU64 | BPF_MUL | BPF_X:
468                         EMIT1(0x50); /* push rax */
469                         EMIT1(0x52); /* push rdx */
470
471                         /* mov r11, A */
472                         EMIT_mov(AUX_REG, a_reg);
473
474                         if (BPF_SRC(insn->code) == BPF_X)
475                                 /* mov rax, X */
476                                 EMIT_mov(BPF_REG_0, x_reg);
477                         else
478                                 /* mov rax, K */
479                                 EMIT3_off32(0x48, 0xC7, 0xC0, K);
480
481                         if (BPF_CLASS(insn->code) == BPF_ALU64)
482                                 EMIT1(add_1mod(0x48, AUX_REG));
483                         else if (is_ereg(AUX_REG))
484                                 EMIT1(add_1mod(0x40, AUX_REG));
485                         /* mul(q) r11 */
486                         EMIT2(0xF7, add_1reg(0xE0, AUX_REG));
487
488                         /* mov r11, rax */
489                         EMIT_mov(AUX_REG, BPF_REG_0);
490
491                         EMIT1(0x5A); /* pop rdx */
492                         EMIT1(0x58); /* pop rax */
493
494                         /* mov A, r11 */
495                         EMIT_mov(a_reg, AUX_REG);
496                         break;
497
498                         /* shifts */
499                 case BPF_ALU | BPF_LSH | BPF_K:
500                 case BPF_ALU | BPF_RSH | BPF_K:
501                 case BPF_ALU | BPF_ARSH | BPF_K:
502                 case BPF_ALU64 | BPF_LSH | BPF_K:
503                 case BPF_ALU64 | BPF_RSH | BPF_K:
504                 case BPF_ALU64 | BPF_ARSH | BPF_K:
505                         if (BPF_CLASS(insn->code) == BPF_ALU64)
506                                 EMIT1(add_1mod(0x48, a_reg));
507                         else if (is_ereg(a_reg))
508                                 EMIT1(add_1mod(0x40, a_reg));
509
510                         switch (BPF_OP(insn->code)) {
511                         case BPF_LSH: b3 = 0xE0; break;
512                         case BPF_RSH: b3 = 0xE8; break;
513                         case BPF_ARSH: b3 = 0xF8; break;
514                         }
515                         EMIT3(0xC1, add_1reg(b3, a_reg), K);
516                         break;
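                            /* e.g. BPF_ALU64 | BPF_LSH | BPF_K with A = R0, K = 3
                             * emits 48 C1 E0 03, i.e. "shl rax, 3"
                             */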
517
518                 case BPF_ALU | BPF_END | BPF_FROM_BE:
519                         switch (K) {
520                         case 16:
521                                 /* emit 'ror %ax, 8' to swap lower 2 bytes */
522                                 EMIT1(0x66);
523                                 if (is_ereg(a_reg))
524                                         EMIT1(0x41);
525                                 EMIT3(0xC1, add_1reg(0xC8, a_reg), 8);
526                                 break;
527                         case 32:
528                                 /* emit 'bswap eax' to swap lower 4 bytes */
529                                 if (is_ereg(a_reg))
530                                         EMIT2(0x41, 0x0F);
531                                 else
532                                         EMIT1(0x0F);
533                                 EMIT1(add_1reg(0xC8, a_reg));
534                                 break;
535                         case 64:
536                                 /* emit 'bswap rax' to swap 8 bytes */
537                                 EMIT3(add_1mod(0x48, a_reg), 0x0F,
538                                       add_1reg(0xC8, a_reg));
539                                 break;
540                         }
541                         break;
542
543                 case BPF_ALU | BPF_END | BPF_FROM_LE:
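                            /* x86 is little-endian, so conversion to/from LE
                             * needs no code at all
                             */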
544                         break;
545
546                         /* ST: *(u8*)(a_reg + off) = imm */
547                 case BPF_ST | BPF_MEM | BPF_B:
548                         if (is_ereg(a_reg))
549                                 EMIT2(0x41, 0xC6);
550                         else
551                                 EMIT1(0xC6);
552                         goto st;
553                 case BPF_ST | BPF_MEM | BPF_H:
554                         if (is_ereg(a_reg))
555                                 EMIT3(0x66, 0x41, 0xC7);
556                         else
557                                 EMIT2(0x66, 0xC7);
558                         goto st;
559                 case BPF_ST | BPF_MEM | BPF_W:
560                         if (is_ereg(a_reg))
561                                 EMIT2(0x41, 0xC7);
562                         else
563                                 EMIT1(0xC7);
564                         goto st;
565                 case BPF_ST | BPF_MEM | BPF_DW:
566                         EMIT2(add_1mod(0x48, a_reg), 0xC7);
567
568 st:                     if (is_imm8(insn->off))
569                                 EMIT2(add_1reg(0x40, a_reg), insn->off);
570                         else
571                                 EMIT1_off32(add_1reg(0x80, a_reg), insn->off);
572
573                         EMIT(K, bpf_size_to_x86_bytes(BPF_SIZE(insn->code)));
574                         break;
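                            /* e.g. BPF_ST | BPF_MEM | BPF_W of imm 5 at [R1 + 8]
                             * (R1 = rdi) emits C7 47 08 05 00 00 00,
                             * i.e. "mov dword ptr [rdi+8], 5"
                             */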
575
576                         /* STX: *(u8*)(a_reg + off) = x_reg */
577                 case BPF_STX | BPF_MEM | BPF_B:
578                         /* emit 'mov byte ptr [rax + off], al' */
579                         if (is_ereg(a_reg) || is_ereg(x_reg) ||
580                             /* need a REX prefix byte to address the x86 SIL, DIL regs */
581                             x_reg == BPF_REG_1 || x_reg == BPF_REG_2)
582                                 EMIT2(add_2mod(0x40, a_reg, x_reg), 0x88);
583                         else
584                                 EMIT1(0x88);
585                         goto stx;
586                 case BPF_STX | BPF_MEM | BPF_H:
587                         if (is_ereg(a_reg) || is_ereg(x_reg))
588                                 EMIT3(0x66, add_2mod(0x40, a_reg, x_reg), 0x89);
589                         else
590                                 EMIT2(0x66, 0x89);
591                         goto stx;
592                 case BPF_STX | BPF_MEM | BPF_W:
593                         if (is_ereg(a_reg) || is_ereg(x_reg))
594                                 EMIT2(add_2mod(0x40, a_reg, x_reg), 0x89);
595                         else
596                                 EMIT1(0x89);
597                         goto stx;
598                 case BPF_STX | BPF_MEM | BPF_DW:
599                         EMIT2(add_2mod(0x48, a_reg, x_reg), 0x89);
600 stx:                    if (is_imm8(insn->off))
601                                 EMIT2(add_2reg(0x40, a_reg, x_reg), insn->off);
602                         else
603                                 EMIT1_off32(add_2reg(0x80, a_reg, x_reg),
604                                             insn->off);
605                         break;
606
607                         /* LDX: a_reg = *(u8*)(x_reg + off) */
608                 case BPF_LDX | BPF_MEM | BPF_B:
609                         /* emit 'movzx rax, byte ptr [rax + off]' */
610                         EMIT3(add_2mod(0x48, x_reg, a_reg), 0x0F, 0xB6);
611                         goto ldx;
612                 case BPF_LDX | BPF_MEM | BPF_H:
613                         /* emit 'movzx rax, word ptr [rax + off]' */
614                         EMIT3(add_2mod(0x48, x_reg, a_reg), 0x0F, 0xB7);
615                         goto ldx;
616                 case BPF_LDX | BPF_MEM | BPF_W:
617                         /* emit 'mov eax, dword ptr [rax+0x14]' */
618                         if (is_ereg(a_reg) || is_ereg(x_reg))
619                                 EMIT2(add_2mod(0x40, x_reg, a_reg), 0x8B);
620                         else
621                                 EMIT1(0x8B);
622                         goto ldx;
623                 case BPF_LDX | BPF_MEM | BPF_DW:
624                         /* emit 'mov rax, qword ptr [rax+0x14]' */
625                         EMIT2(add_2mod(0x48, x_reg, a_reg), 0x8B);
626 ldx:                    /* if insn->off == 0 we could save one byte, but the
627                          * special case of x86 r13, which always needs an offset,
628                          * makes it not worth the hassle
629                          */
630                         if (is_imm8(insn->off))
631                                 EMIT2(add_2reg(0x40, x_reg, a_reg), insn->off);
632                         else
633                                 EMIT1_off32(add_2reg(0x80, x_reg, a_reg),
634                                             insn->off);
635                         break;
636
637                         /* STX XADD: lock *(u32*)(a_reg + off) += x_reg */
638                 case BPF_STX | BPF_XADD | BPF_W:
639                         /* emit 'lock add dword ptr [rax + off], eax' */
640                         if (is_ereg(a_reg) || is_ereg(x_reg))
641                                 EMIT3(0xF0, add_2mod(0x40, a_reg, x_reg), 0x01);
642                         else
643                                 EMIT2(0xF0, 0x01);
644                         goto xadd;
645                 case BPF_STX | BPF_XADD | BPF_DW:
646                         EMIT3(0xF0, add_2mod(0x48, a_reg, x_reg), 0x01);
647 xadd:                   if (is_imm8(insn->off))
648                                 EMIT2(add_2reg(0x40, a_reg, x_reg), insn->off);
649                         else
650                                 EMIT1_off32(add_2reg(0x80, a_reg, x_reg),
651                                             insn->off);
652                         break;
653
654                         /* call */
655                 case BPF_JMP | BPF_CALL:
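                            /* K holds the helper's offset from __bpf_call_base, so
                             * the absolute address can be recovered and emitted as
                             * a rel32 call below
                             */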
656                         func = (u8 *) __bpf_call_base + K;
657                         jmp_offset = func - (image + addrs[i]);
658                         if (ctx->seen_ld_abs) {
659                                 EMIT2(0x41, 0x52); /* push %r10 */
660                                 EMIT2(0x41, 0x51); /* push %r9 */
661                                 /* need to adjust jmp offset, since
662                                  * pop %r9, pop %r10 take 4 bytes after call insn
663                                  */
664                                 jmp_offset += 4;
665                         }
666                         if (!K || !is_simm32(jmp_offset)) {
667                                 pr_err("unsupported bpf func %d addr %p image %p\n",
668                                        K, func, image);
669                                 return -EINVAL;
670                         }
671                         EMIT1_off32(0xE8, jmp_offset);
672                         if (ctx->seen_ld_abs) {
673                                 EMIT2(0x41, 0x59); /* pop %r9 */
674                                 EMIT2(0x41, 0x5A); /* pop %r10 */
675                         }
676                         break;
677
678                         /* cond jump */
679                 case BPF_JMP | BPF_JEQ | BPF_X:
680                 case BPF_JMP | BPF_JNE | BPF_X:
681                 case BPF_JMP | BPF_JGT | BPF_X:
682                 case BPF_JMP | BPF_JGE | BPF_X:
683                 case BPF_JMP | BPF_JSGT | BPF_X:
684                 case BPF_JMP | BPF_JSGE | BPF_X:
685                         /* cmp a_reg, x_reg */
686                         EMIT3(add_2mod(0x48, a_reg, x_reg), 0x39,
687                               add_2reg(0xC0, a_reg, x_reg));
688                         goto emit_cond_jmp;
689
690                 case BPF_JMP | BPF_JSET | BPF_X:
691                         /* test a_reg, x_reg */
692                         EMIT3(add_2mod(0x48, a_reg, x_reg), 0x85,
693                               add_2reg(0xC0, a_reg, x_reg));
694                         goto emit_cond_jmp;
695
696                 case BPF_JMP | BPF_JSET | BPF_K:
697                         /* test a_reg, imm32 */
698                         EMIT1(add_1mod(0x48, a_reg));
699                         EMIT2_off32(0xF7, add_1reg(0xC0, a_reg), K);
700                         goto emit_cond_jmp;
701
702                 case BPF_JMP | BPF_JEQ | BPF_K:
703                 case BPF_JMP | BPF_JNE | BPF_K:
704                 case BPF_JMP | BPF_JGT | BPF_K:
705                 case BPF_JMP | BPF_JGE | BPF_K:
706                 case BPF_JMP | BPF_JSGT | BPF_K:
707                 case BPF_JMP | BPF_JSGE | BPF_K:
708                         /* cmp a_reg, imm8/32 */
709                         EMIT1(add_1mod(0x48, a_reg));
710
711                         if (is_imm8(K))
712                                 EMIT3(0x83, add_1reg(0xF8, a_reg), K);
713                         else
714                                 EMIT2_off32(0x81, add_1reg(0xF8, a_reg), K);
715
716 emit_cond_jmp:          /* convert BPF opcode to x86 */
717                         switch (BPF_OP(insn->code)) {
718                         case BPF_JEQ:
719                                 jmp_cond = X86_JE;
720                                 break;
721                         case BPF_JSET:
722                         case BPF_JNE:
723                                 jmp_cond = X86_JNE;
724                                 break;
725                         case BPF_JGT:
726                                 /* GT is unsigned '>', JA in x86 */
727                                 jmp_cond = X86_JA;
728                                 break;
729                         case BPF_JGE:
730                                 /* GE is unsigned '>=', JAE in x86 */
731                                 jmp_cond = X86_JAE;
732                                 break;
733                         case BPF_JSGT:
734                                 /* signed '>', GT in x86 */
735                                 jmp_cond = X86_JG;
736                                 break;
737                         case BPF_JSGE:
738                                 /* signed '>=', GE in x86 */
739                                 jmp_cond = X86_JGE;
740                                 break;
741                         default: /* to silence gcc warning */
742                                 return -EFAULT;
743                         }
744                         jmp_offset = addrs[i + insn->off] - addrs[i];
745                         if (is_imm8(jmp_offset)) {
746                                 EMIT2(jmp_cond, jmp_offset);
747                         } else if (is_simm32(jmp_offset)) {
748                                 EMIT2_off32(0x0F, jmp_cond + 0x10, jmp_offset);
749                         } else {
750                                 pr_err("cond_jmp gen bug %llx\n", jmp_offset);
751                                 return -EFAULT;
752                         }
753
754                         break;
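                            /* e.g. a short forward BPF_JGT becomes "77 xx" (ja .+xx);
                             * the far form adds the 0x0F escape and 0x10 to the
                             * opcode, giving "0F 87 imm32"
                             */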
755
756                 case BPF_JMP | BPF_JA:
757                         jmp_offset = addrs[i + insn->off] - addrs[i];
758                         if (!jmp_offset)
759                                 /* optimize out nop jumps */
760                                 break;
761 emit_jmp:
762                         if (is_imm8(jmp_offset)) {
763                                 EMIT2(0xEB, jmp_offset);
764                         } else if (is_simm32(jmp_offset)) {
765                                 EMIT1_off32(0xE9, jmp_offset);
766                         } else {
767                                 pr_err("jmp gen bug %llx\n", jmp_offset);
768                                 return -EFAULT;
769                         }
770                         break;
771
772                 case BPF_LD | BPF_IND | BPF_W:
773                         func = sk_load_word;
774                         goto common_load;
775                 case BPF_LD | BPF_ABS | BPF_W:
776                         func = CHOOSE_LOAD_FUNC(K, sk_load_word);
777 common_load:            ctx->seen_ld_abs = true;
778                         jmp_offset = func - (image + addrs[i]);
779                         if (!func || !is_simm32(jmp_offset)) {
780                                 pr_err("unsupported bpf func %d addr %p image %p\n",
781                                        K, func, image);
782                                 return -EINVAL;
783                         }
784                         if (BPF_MODE(insn->code) == BPF_ABS) {
785                                 /* mov %esi, imm32 */
786                                 EMIT1_off32(0xBE, K);
787                         } else {
788                                 /* mov %rsi, x_reg */
789                                 EMIT_mov(BPF_REG_2, x_reg);
790                                 if (K) {
791                                         if (is_imm8(K))
792                                                 /* add %esi, imm8 */
793                                                 EMIT3(0x83, 0xC6, K);
794                                         else
795                                                 /* add %esi, imm32 */
796                                                 EMIT2_off32(0x81, 0xC6, K);
797                                 }
798                         }
799                         /* skb pointer is in R6 (%rbx); it will be copied into
800                          * %rdi if skb_copy_bits() call is necessary.
801                          * sk_load_* helpers also use %r10 and %r9d.
802                          * See bpf_jit.S
803                          */
804                         EMIT1_off32(0xE8, jmp_offset); /* call */
805                         break;
806
807                 case BPF_LD | BPF_IND | BPF_H:
808                         func = sk_load_half;
809                         goto common_load;
810                 case BPF_LD | BPF_ABS | BPF_H:
811                         func = CHOOSE_LOAD_FUNC(K, sk_load_half);
812                         goto common_load;
813                 case BPF_LD | BPF_IND | BPF_B:
814                         func = sk_load_byte;
815                         goto common_load;
816                 case BPF_LD | BPF_ABS | BPF_B:
817                         func = CHOOSE_LOAD_FUNC(K, sk_load_byte);
818                         goto common_load;
819
820                 case BPF_JMP | BPF_EXIT:
821                         if (i != insn_cnt - 1) {
822                                 jmp_offset = ctx->cleanup_addr - addrs[i];
823                                 goto emit_jmp;
824                         }
825                         /* record the epilogue offset so earlier BPF_EXIT insns can jump here */
826                         ctx->cleanup_addr = proglen;
827                         /* mov rbx, qword ptr [rbp-X] */
828                         EMIT3_off32(0x48, 0x8B, 0x9D, -stacksize);
829                         /* mov r13, qword ptr [rbp-X] */
830                         EMIT3_off32(0x4C, 0x8B, 0xAD, -stacksize + 8);
831                         /* mov r14, qword ptr [rbp-X] */
832                         EMIT3_off32(0x4C, 0x8B, 0xB5, -stacksize + 16);
833                         /* mov r15, qword ptr [rbp-X] */
834                         EMIT3_off32(0x4C, 0x8B, 0xBD, -stacksize + 24);
835
836                         EMIT1(0xC9); /* leave */
837                         EMIT1(0xC3); /* ret */
838                         break;
839
840                 default:
841                         /* By design the x64 JIT should support all BPF instructions.
842                          * This error will be seen if a new instruction was added
843                          * to the interpreter but not to the JIT,
844                          * or if there is junk in the sk_filter
845                          */
846                         pr_err("bpf_jit: unknown opcode %02x\n", insn->code);
847                         return -EINVAL;
848                 }
849
850                 ilen = prog - temp;
851                 if (image) {
852                         if (unlikely(proglen + ilen > oldproglen)) {
853                                 pr_err("bpf_jit_compile fatal error\n");
854                                 return -EFAULT;
855                         }
856                         memcpy(image + proglen, temp, ilen);
857                 }
858                 proglen += ilen;
859                 addrs[i] = proglen;
860                 prog = temp;
861         }
862         return proglen;
863 }
864
865 void bpf_jit_compile(struct sk_filter *prog)
866 {
867 }
868
869 void bpf_int_jit_compile(struct sk_filter *prog)
870 {
871         struct bpf_binary_header *header = NULL;
872         int proglen, oldproglen = 0;
873         struct jit_context ctx = {};
874         u8 *image = NULL;
875         int *addrs;
876         int pass;
877         int i;
878
879         if (!bpf_jit_enable)
880                 return;
881
882         if (!prog || !prog->len)
883                 return;
884
885         addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
886         if (!addrs)
887                 return;
888
889         /* Before the first pass, make a rough estimate of addrs[]:
890          * each BPF instruction is translated to less than 64 bytes
891          */
892         for (proglen = 0, i = 0; i < prog->len; i++) {
893                 proglen += 64;
894                 addrs[i] = proglen;
895         }
896         ctx.cleanup_addr = proglen;
897
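            /* The JIT runs in several passes: while image is NULL, do_jit() only
             * computes instruction offsets in addrs[]; sizes shrink as jump
             * offsets become known.  Once proglen stops changing, allocate the
             * image and do one final pass that actually writes into it.
             */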
898         for (pass = 0; pass < 10; pass++) {
899                 proglen = do_jit(prog, addrs, image, oldproglen, &ctx);
900                 if (proglen <= 0) {
901                         image = NULL;
902                         if (header)
903                                 module_free(NULL, header);
904                         goto out;
905                 }
906                 if (image) {
907                         if (proglen != oldproglen)
908                                 pr_err("bpf_jit: proglen=%d != oldproglen=%d\n",
909                                        proglen, oldproglen);
910                         break;
911                 }
912                 if (proglen == oldproglen) {
913                         header = bpf_alloc_binary(proglen, &image);
914                         if (!header)
915                                 goto out;
916                 }
917                 oldproglen = proglen;
918         }
919
920         if (bpf_jit_enable > 1)
921                 bpf_jit_dump(prog->len, proglen, 0, image);
922
923         if (image) {
924                 bpf_flush_icache(header, image + proglen);
925                 set_memory_ro((unsigned long)header, header->pages);
926                 prog->bpf_func = (void *)image;
927                 prog->jited = 1;
928         }
929 out:
930         kfree(addrs);
931 }
932
933 static void bpf_jit_free_deferred(struct work_struct *work)
934 {
935         struct sk_filter *fp = container_of(work, struct sk_filter, work);
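            /* fp->bpf_func points into the middle of the image; rounding down to a
             * page boundary recovers the header, since module_alloc() memory is
             * page-aligned and the randomized image start lies within the first page.
             */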
936         unsigned long addr = (unsigned long)fp->bpf_func & PAGE_MASK;
937         struct bpf_binary_header *header = (void *)addr;
938
939         set_memory_rw(addr, header->pages);
940         module_free(NULL, header);
941         kfree(fp);
942 }
943
944 void bpf_jit_free(struct sk_filter *fp)
945 {
946         if (fp->jited) {
947                 INIT_WORK(&fp->work, bpf_jit_free_deferred);
948                 schedule_work(&fp->work);
949         } else {
950                 kfree(fp);
951         }
952 }