crypto: remove cipher routines from public crypto API
[linux-block.git] / arch/arm/crypto/aes-neonbs-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/cipher.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)-all");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

MODULE_IMPORT_NS(CRYPTO_INTERNAL);

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], int);

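/*
 * rk[] holds the expanded key in the bit-sliced representation produced by
 * aesbs_convert_key(); it is sized for the largest (256-bit) key schedule.
 */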
struct aesbs_ctx {
        int     rounds;
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        struct crypto_skcipher  *enc_tfm;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        struct crypto_cipher    *cts_tfm;
        struct crypto_cipher    *tweak_tfm;
};

struct aesbs_ctr_ctx {
        struct aesbs_ctx        key;            /* must be first member */
        struct crypto_aes_ctx   fallback;
};

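/*
 * Expand the user-supplied key with the AES library code and convert the
 * encryption round keys into the bit-sliced form used by the NEON routines.
 */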
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

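/*
 * ECB helper: hand as many full blocks as possible to the NEON routine,
 * rounding intermediate chunks down to the 8-block walk stride.
 */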
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

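/*
 * CBC encryption is delegated to a fallback cbc(aes) skcipher, so the key
 * has to be installed both in bit-sliced form and in the fallback transform.
 */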
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();
        memzero_explicit(&rk, sizeof(rk));

        return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
}

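/* Encryption (which cannot be parallelised) is handed off to the fallback. */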
static int cbc_encrypt(struct skcipher_request *req)
{
        struct skcipher_request *subreq = skcipher_request_ctx(req);
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        skcipher_request_set_tfm(subreq, ctx->enc_tfm);
        skcipher_request_set_callback(subreq,
                                      skcipher_request_flags(req),
                                      NULL, NULL);
        skcipher_request_set_crypt(subreq, req->src, req->dst,
                                   req->cryptlen, req->iv);

        return crypto_skcipher_encrypt(subreq);
}

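/* Decrypt full blocks, up to eight at a time, with the bit-sliced NEON code. */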
static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

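/*
 * Allocate the fallback cbc(aes) transform and reserve room for its
 * subrequest in our request context.
 */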
static int cbc_init(struct crypto_skcipher *tfm)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned int reqsize;

        ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
                                             CRYPTO_ALG_NEED_FALLBACK);
        if (IS_ERR(ctx->enc_tfm))
                return PTR_ERR(ctx->enc_tfm);

        reqsize = sizeof(struct skcipher_request);
        reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
        crypto_skcipher_set_reqsize(tfm, reqsize);

        return 0;
}

static void cbc_exit(struct crypto_skcipher *tfm)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_skcipher(ctx->enc_tfm);
}

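/*
 * The sync CTR variant also keeps the ordinary key schedule around so the
 * scalar AES library code can be used when the NEON unit is not usable.
 */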
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
                                 unsigned int key_len)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->fallback, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

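/*
 * CTR mode: full blocks go through the NEON routine; for a trailing partial
 * block the keystream is returned in 'buf' and XORed into the output here.
 */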
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                kernel_neon_begin();
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);
                kernel_neon_end();

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

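/* Scalar single-block fallback used by crypto_ctr_encrypt_walk(). */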
static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else.
         */
        local_irq_save(flags);
        aes_encrypt(&ctx->fallback, dst, src);
        local_irq_restore(flags);
}

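/* Only take the NEON path when the SIMD unit is usable in this context. */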
static int ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

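/*
 * XTS: the first half of the key drives the bulk NEON code and the cipher
 * used for ciphertext stealing; the second half keys the tweak cipher.
 */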
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
        if (err)
                return err;
        err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
        if (err)
                return err;

        return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_skcipher *tfm)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

        ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
        if (IS_ERR(ctx->cts_tfm))
                return PTR_ERR(ctx->cts_tfm);

        ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
        if (IS_ERR(ctx->tweak_tfm))
                crypto_free_cipher(ctx->cts_tfm);

        return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_skcipher *tfm)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_cipher(ctx->tweak_tfm);
        crypto_free_cipher(ctx->cts_tfm);
}

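/*
 * Encrypt the IV with the tweak key, push the bulk of the data through the
 * NEON routine, then handle any ciphertext-stealing tail with the plain
 * AES cipher.
 */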
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], int))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct skcipher_request subreq;
        u8 buf[2 * AES_BLOCK_SIZE];
        struct skcipher_walk walk;
        int err;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        if (unlikely(tail)) {
                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           req->cryptlen - tail, req->iv);
                req = &subreq;
        }

        err = skcipher_walk_virt(&walk, req, true);
        if (err)
                return err;

        crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                int reorder_last_tweak = !encrypt && tail > 0;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        reorder_last_tweak = 0;
                }

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
                   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
                                 AES_BLOCK_SIZE, 0);
        memcpy(buf + AES_BLOCK_SIZE, buf, tail);
        scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

        if (encrypt)
                crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
        else
                crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

        crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

        scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
                                 AES_BLOCK_SIZE + tail, 1);
        return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

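/*
 * The __-prefixed algorithms are internal; they are exposed to users through
 * the SIMD helper wrappers created in aes_init().
 */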
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL |
                                  CRYPTO_ALG_NEED_FALLBACK,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
        .init                   = cbc_init,
        .exit                   = cbc_exit,
}, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs-sync",
        .base.cra_priority      = 250 - 1,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_ctr_setkey_sync,
        .encrypt                = ctr_encrypt_sync,
        .decrypt                = ctr_encrypt_sync,
}, {
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
        .init                   = xts_init,
        .exit                   = xts_exit,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

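/*
 * Register the algorithms and wrap each internal one in a SIMD helper that
 * handles callers in contexts where the NEON unit cannot be used directly.
 */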
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!(elf_hwcap & HWCAP_NEON))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

late_initcall(aes_init);
module_exit(aes_exit);