bpf: x64: add JIT support for multi-function programs
[linux-block.git] / arch / x86 / net / bpf_jit_comp.c
index 68859b58ab8415e6d13d66dc25a33dcdd8bf08a7..87f214fbe66ec163d24b12b6defc7edab612ecc9 100644 (file)
@@ -1109,13 +1109,23 @@ common_load:
        return proglen;
 }
 
+struct x64_jit_data {
+       struct bpf_binary_header *header;
+       int *addrs;
+       u8 *image;
+       int proglen;
+       struct jit_context ctx;
+};
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
        struct bpf_binary_header *header = NULL;
        struct bpf_prog *tmp, *orig_prog = prog;
+       struct x64_jit_data *jit_data;
        int proglen, oldproglen = 0;
        struct jit_context ctx = {};
        bool tmp_blinded = false;
+       bool extra_pass = false;
        u8 *image = NULL;
        int *addrs;
        int pass;
@@ -1135,10 +1145,28 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                prog = tmp;
        }
 
+       jit_data = prog->aux->jit_data;
+       if (!jit_data) {
+               jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+               if (!jit_data) {
+                       prog = orig_prog;
+                       goto out;
+               }
+               prog->aux->jit_data = jit_data;
+       }
+       addrs = jit_data->addrs;
+       if (addrs) {
+               ctx = jit_data->ctx;
+               oldproglen = jit_data->proglen;
+               image = jit_data->image;
+               header = jit_data->header;
+               extra_pass = true;
+               goto skip_init_addrs;
+       }
        addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
        if (!addrs) {
                prog = orig_prog;
-               goto out;
+               goto out_addrs;
        }
 
        /* Before first pass, make a rough estimation of addrs[]
@@ -1149,6 +1177,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                addrs[i] = proglen;
        }
        ctx.cleanup_addr = proglen;
+skip_init_addrs:
 
        /* JITed image shrinks with every pass and the loop iterates
         * until the image stops shrinking. Very large bpf programs
@@ -1189,7 +1218,15 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 
        if (image) {
                bpf_flush_icache(header, image + proglen);
-               bpf_jit_binary_lock_ro(header);
+               if (!prog->is_func || extra_pass) {
+                       bpf_jit_binary_lock_ro(header);
+               } else {
+                       jit_data->addrs = addrs;
+                       jit_data->ctx = ctx;
+                       jit_data->proglen = proglen;
+                       jit_data->image = image;
+                       jit_data->header = header;
+               }
                prog->bpf_func = (void *)image;
                prog->jited = 1;
                prog->jited_len = proglen;
@@ -1197,8 +1234,12 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                prog = orig_prog;
        }
 
+       if (!prog->is_func || extra_pass) {
 out_addrs:
-       kfree(addrs);
+               kfree(addrs);
+               kfree(jit_data);
+               prog->aux->jit_data = NULL;
+       }
 out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?