// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include "aes-cipher.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

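/*
 * Low-level routines implemented in the accompanying NEON assembler
 * (presumably aes-neonbs-core.S, following the usual naming convention
 * for these glue/core file pairs).
 */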
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int bytes, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], int);

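/*
 * rk[] holds the expanded key in the bit-sliced layout emitted by
 * aesbs_convert_key(); the size appears to cover up to 13 inner round
 * keys of 8 x AES_BLOCK_SIZE bytes each, with 32 bytes left over for
 * the first and last round keys in regular form.
 */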
struct aesbs_ctx {
        int     rounds;
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        struct crypto_aes_ctx   fallback;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;            /* must be first member */
        struct crypto_aes_ctx   fallback;
        struct crypto_aes_ctx   tweak_key;
};

struct aesbs_ctr_ctx {
        struct aesbs_ctx        key;            /* must be first member */
        struct crypto_aes_ctx   fallback;
};

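/*
 * Expand the key with the generic AES library code, then convert the
 * encryption round keys into the bit-sliced layout consumed by the
 * NEON routines. kernel_neon_begin()/end() guard the NEON register use.
 */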
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

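/*
 * Process as many full blocks as each walk step yields; while more
 * data is outstanding, round the block count down to the walk stride
 * (8 blocks) so that later walk steps stay stride-aligned.
 */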
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->fallback, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

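/*
 * CBC encryption is inherently serial: each block's input is chained
 * to the previous ciphertext block, so the 8-way bit-sliced NEON code
 * cannot be used. Encrypt one block at a time with the scalar cipher
 * instead.
 */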
static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int nbytes;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                u8 *prev = walk.iv;

                do {
                        crypto_xor_cpy(dst, src, prev, AES_BLOCK_SIZE);
                        __aes_arm_encrypt(ctx->fallback.key_enc,
                                          ctx->key.rounds, dst, dst);
                        prev = dst;
                        src += AES_BLOCK_SIZE;
                        dst += AES_BLOCK_SIZE;
                        nbytes -= AES_BLOCK_SIZE;
                } while (nbytes >= AES_BLOCK_SIZE);
                memcpy(walk.iv, prev, AES_BLOCK_SIZE);
                err = skcipher_walk_done(&walk, nbytes);
        }
        return err;
}

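/*
 * CBC decryption, unlike encryption, is parallelizable: every
 * ciphertext block needed for the chaining XOR is already available,
 * so the bit-sliced NEON code can decrypt 8 blocks at a time.
 */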
static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

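/*
 * Like aesbs_cbc_setkey(), but the regular expanded key is also kept
 * around so the scalar cipher can serve as a fallback when the NEON
 * unit is unavailable (see ctr_encrypt_sync()).
 */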
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
                                 unsigned int key_len)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->fallback, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

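/*
 * A final sub-block tail is staged in (and copied back out of) a
 * block-sized stack buffer, aligned to the buffer's end; while more
 * data remains, full chunks are rounded down to a multiple of 8 blocks
 * to keep the bit-sliced code on its fast path.
 */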
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                int bytes = walk.nbytes;

                if (unlikely(bytes < AES_BLOCK_SIZE))
                        src = dst = memcpy(buf + sizeof(buf) - bytes,
                                           src, bytes);
                else if (walk.nbytes < walk.total)
                        bytes &= ~(8 * AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
                kernel_neon_end();

                if (unlikely(bytes < AES_BLOCK_SIZE))
                        memcpy(walk.dst.virt.addr,
                               buf + sizeof(buf) - bytes, bytes);

                err = skcipher_walk_done(&walk, walk.nbytes - bytes);
        }

        return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

        __aes_arm_encrypt(ctx->fallback.key_enc, ctx->key.rounds, src, dst);
}

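/*
 * Synchronous ctr(aes) entry point: when SIMD is not usable in the
 * current context, walk the request with the scalar cipher; otherwise
 * use the NEON implementation directly.
 */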
static int ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = aes_expandkey(&ctx->fallback, in_key, key_len);
        if (err)
                return err;
        err = aes_expandkey(&ctx->tweak_key, in_key + key_len, key_len);
        if (err)
                return err;

        return aesbs_setkey(tfm, in_key, key_len);
}

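/*
 * XTS with ciphertext stealing: when the length is not a multiple of
 * the block size, all complete blocks are handled via a subrequest
 * first, and the final partial block is then processed with the scalar
 * cipher. For decryption, the last two tweaks must be applied in
 * reverse order, which is what the reorder_last_tweak flag conveys to
 * the asm code.
 */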
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], int))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        const int rounds = ctx->key.rounds;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct skcipher_request subreq;
        u8 buf[2 * AES_BLOCK_SIZE];
        struct skcipher_walk walk;
        int err;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        if (unlikely(tail)) {
                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           req->cryptlen - tail, req->iv);
                req = &subreq;
        }

        err = skcipher_walk_virt(&walk, req, true);
        if (err)
                return err;

        __aes_arm_encrypt(ctx->tweak_key.key_enc, rounds, walk.iv, walk.iv);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                int reorder_last_tweak = !encrypt && tail > 0;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        reorder_last_tweak = 0;
                }

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
                   rounds, blocks, walk.iv, reorder_last_tweak);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
                                 AES_BLOCK_SIZE, 0);
        memcpy(buf + AES_BLOCK_SIZE, buf, tail);
        scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

        if (encrypt)
                __aes_arm_encrypt(ctx->fallback.key_enc, rounds, buf, buf);
        else
                __aes_arm_decrypt(ctx->fallback.key_dec, rounds, buf, buf);

        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
                                 AES_BLOCK_SIZE + tail, 1);
        return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

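/*
 * The "__"-prefixed variants are CRYPTO_ALG_INTERNAL and only reachable
 * through the simd wrappers created in aes_init(); the "-sync" ctr
 * variant carries no internal flag and can be used from any context.
 */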
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs-sync",
        .base.cra_priority      = 250 - 1,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_ctr_setkey_sync,
        .encrypt                = ctr_encrypt_sync,
        .decrypt                = ctr_encrypt_sync,
}, {
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

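/*
 * Require NEON, register the algorithms, then create a simd wrapper
 * around each internal one so that requests arriving in non-SIMD
 * contexts can be deferred to cryptd.
 */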
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!(elf_hwcap & HWCAP_NEON))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(aes_algs + i, algname,
                                                   drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

late_initcall(aes_init);
module_exit(aes_exit);