Merge branch 'address-masking'
[linux-block.git] / crypto / lskcipher.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Linear symmetric key cipher operations.
4  *
5  * Generic encrypt/decrypt wrapper for ciphers.
6  *
7  * Copyright (c) 2023 Herbert Xu <herbert@gondor.apana.org.au>
8  */
9
10 #include <linux/cryptouser.h>
11 #include <linux/err.h>
12 #include <linux/export.h>
13 #include <linux/kernel.h>
14 #include <linux/seq_file.h>
15 #include <linux/slab.h>
16 #include <linux/string.h>
17 #include <net/netlink.h>
18 #include "skcipher.h"
19
20 static inline struct crypto_lskcipher *__crypto_lskcipher_cast(
21         struct crypto_tfm *tfm)
22 {
23         return container_of(tfm, struct crypto_lskcipher, base);
24 }
25
26 static inline struct lskcipher_alg *__crypto_lskcipher_alg(
27         struct crypto_alg *alg)
28 {
29         return container_of(alg, struct lskcipher_alg, co.base);
30 }
31
32 static int lskcipher_setkey_unaligned(struct crypto_lskcipher *tfm,
33                                       const u8 *key, unsigned int keylen)
34 {
35         unsigned long alignmask = crypto_lskcipher_alignmask(tfm);
36         struct lskcipher_alg *cipher = crypto_lskcipher_alg(tfm);
37         u8 *buffer, *alignbuffer;
38         unsigned long absize;
39         int ret;
40
41         absize = keylen + alignmask;
42         buffer = kmalloc(absize, GFP_ATOMIC);
43         if (!buffer)
44                 return -ENOMEM;
45
46         alignbuffer = (u8 *)ALIGN((unsigned long)buffer, alignmask + 1);
47         memcpy(alignbuffer, key, keylen);
48         ret = cipher->setkey(tfm, alignbuffer, keylen);
49         kfree_sensitive(buffer);
50         return ret;
51 }
52
53 int crypto_lskcipher_setkey(struct crypto_lskcipher *tfm, const u8 *key,
54                             unsigned int keylen)
55 {
56         unsigned long alignmask = crypto_lskcipher_alignmask(tfm);
57         struct lskcipher_alg *cipher = crypto_lskcipher_alg(tfm);
58
59         if (keylen < cipher->co.min_keysize || keylen > cipher->co.max_keysize)
60                 return -EINVAL;
61
62         if ((unsigned long)key & alignmask)
63                 return lskcipher_setkey_unaligned(tfm, key, keylen);
64         else
65                 return cipher->setkey(tfm, key, keylen);
66 }
67 EXPORT_SYMBOL_GPL(crypto_lskcipher_setkey);
68
69 static int crypto_lskcipher_crypt_unaligned(
70         struct crypto_lskcipher *tfm, const u8 *src, u8 *dst, unsigned len,
71         u8 *iv, int (*crypt)(struct crypto_lskcipher *tfm, const u8 *src,
72                              u8 *dst, unsigned len, u8 *iv, u32 flags))
73 {
74         unsigned statesize = crypto_lskcipher_statesize(tfm);
75         unsigned ivsize = crypto_lskcipher_ivsize(tfm);
76         unsigned bs = crypto_lskcipher_blocksize(tfm);
77         unsigned cs = crypto_lskcipher_chunksize(tfm);
78         int err;
79         u8 *tiv;
80         u8 *p;
81
82         BUILD_BUG_ON(MAX_CIPHER_BLOCKSIZE > PAGE_SIZE ||
83                      MAX_CIPHER_ALIGNMASK >= PAGE_SIZE);
84
85         tiv = kmalloc(PAGE_SIZE, GFP_ATOMIC);
86         if (!tiv)
87                 return -ENOMEM;
88
89         memcpy(tiv, iv, ivsize + statesize);
90
91         p = kmalloc(PAGE_SIZE, GFP_ATOMIC);
92         err = -ENOMEM;
93         if (!p)
94                 goto out;
95
96         while (len >= bs) {
97                 unsigned chunk = min((unsigned)PAGE_SIZE, len);
98                 int err;
99
100                 if (chunk > cs)
101                         chunk &= ~(cs - 1);
102
103                 memcpy(p, src, chunk);
104                 err = crypt(tfm, p, p, chunk, tiv, CRYPTO_LSKCIPHER_FLAG_FINAL);
105                 if (err)
106                         goto out;
107
108                 memcpy(dst, p, chunk);
109                 src += chunk;
110                 dst += chunk;
111                 len -= chunk;
112         }
113
114         err = len ? -EINVAL : 0;
115
116 out:
117         memcpy(iv, tiv, ivsize + statesize);
118         kfree_sensitive(p);
119         kfree_sensitive(tiv);
120         return err;
121 }
122
123 static int crypto_lskcipher_crypt(struct crypto_lskcipher *tfm, const u8 *src,
124                                   u8 *dst, unsigned len, u8 *iv,
125                                   int (*crypt)(struct crypto_lskcipher *tfm,
126                                                const u8 *src, u8 *dst,
127                                                unsigned len, u8 *iv,
128                                                u32 flags))
129 {
130         unsigned long alignmask = crypto_lskcipher_alignmask(tfm);
131
132         if (((unsigned long)src | (unsigned long)dst | (unsigned long)iv) &
133             alignmask)
134                 return crypto_lskcipher_crypt_unaligned(tfm, src, dst, len, iv,
135                                                         crypt);
136
137         return crypt(tfm, src, dst, len, iv, CRYPTO_LSKCIPHER_FLAG_FINAL);
138 }
139
140 int crypto_lskcipher_encrypt(struct crypto_lskcipher *tfm, const u8 *src,
141                              u8 *dst, unsigned len, u8 *iv)
142 {
143         struct lskcipher_alg *alg = crypto_lskcipher_alg(tfm);
144
145         return crypto_lskcipher_crypt(tfm, src, dst, len, iv, alg->encrypt);
146 }
147 EXPORT_SYMBOL_GPL(crypto_lskcipher_encrypt);
148
149 int crypto_lskcipher_decrypt(struct crypto_lskcipher *tfm, const u8 *src,
150                              u8 *dst, unsigned len, u8 *iv)
151 {
152         struct lskcipher_alg *alg = crypto_lskcipher_alg(tfm);
153
154         return crypto_lskcipher_crypt(tfm, src, dst, len, iv, alg->decrypt);
155 }
156 EXPORT_SYMBOL_GPL(crypto_lskcipher_decrypt);
157
158 static int crypto_lskcipher_crypt_sg(struct skcipher_request *req,
159                                      int (*crypt)(struct crypto_lskcipher *tfm,
160                                                   const u8 *src, u8 *dst,
161                                                   unsigned len, u8 *ivs,
162                                                   u32 flags))
163 {
164         struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
165         struct crypto_lskcipher **ctx = crypto_skcipher_ctx(skcipher);
166         u8 *ivs = skcipher_request_ctx(req);
167         struct crypto_lskcipher *tfm = *ctx;
168         struct skcipher_walk walk;
169         unsigned ivsize;
170         u32 flags;
171         int err;
172
173         ivsize = crypto_lskcipher_ivsize(tfm);
174         ivs = PTR_ALIGN(ivs, crypto_skcipher_alignmask(skcipher) + 1);
175         memcpy(ivs, req->iv, ivsize);
176
177         flags = req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP;
178
179         if (req->base.flags & CRYPTO_SKCIPHER_REQ_CONT)
180                 flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
181
182         if (!(req->base.flags & CRYPTO_SKCIPHER_REQ_NOTFINAL))
183                 flags |= CRYPTO_LSKCIPHER_FLAG_FINAL;
184
185         err = skcipher_walk_virt(&walk, req, false);
186
187         while (walk.nbytes) {
188                 err = crypt(tfm, walk.src.virt.addr, walk.dst.virt.addr,
189                             walk.nbytes, ivs,
190                             flags & ~(walk.nbytes == walk.total ?
191                             0 : CRYPTO_LSKCIPHER_FLAG_FINAL));
192                 err = skcipher_walk_done(&walk, err);
193                 flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
194         }
195
196         memcpy(req->iv, ivs, ivsize);
197
198         return err;
199 }
200
201 int crypto_lskcipher_encrypt_sg(struct skcipher_request *req)
202 {
203         struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
204         struct crypto_lskcipher **ctx = crypto_skcipher_ctx(skcipher);
205         struct lskcipher_alg *alg = crypto_lskcipher_alg(*ctx);
206
207         return crypto_lskcipher_crypt_sg(req, alg->encrypt);
208 }
209
210 int crypto_lskcipher_decrypt_sg(struct skcipher_request *req)
211 {
212         struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
213         struct crypto_lskcipher **ctx = crypto_skcipher_ctx(skcipher);
214         struct lskcipher_alg *alg = crypto_lskcipher_alg(*ctx);
215
216         return crypto_lskcipher_crypt_sg(req, alg->decrypt);
217 }
218
219 static void crypto_lskcipher_exit_tfm(struct crypto_tfm *tfm)
220 {
221         struct crypto_lskcipher *skcipher = __crypto_lskcipher_cast(tfm);
222         struct lskcipher_alg *alg = crypto_lskcipher_alg(skcipher);
223
224         alg->exit(skcipher);
225 }
226
227 static int crypto_lskcipher_init_tfm(struct crypto_tfm *tfm)
228 {
229         struct crypto_lskcipher *skcipher = __crypto_lskcipher_cast(tfm);
230         struct lskcipher_alg *alg = crypto_lskcipher_alg(skcipher);
231
232         if (alg->exit)
233                 skcipher->base.exit = crypto_lskcipher_exit_tfm;
234
235         if (alg->init)
236                 return alg->init(skcipher);
237
238         return 0;
239 }
240
241 static void crypto_lskcipher_free_instance(struct crypto_instance *inst)
242 {
243         struct lskcipher_instance *skcipher =
244                 container_of(inst, struct lskcipher_instance, s.base);
245
246         skcipher->free(skcipher);
247 }
248
249 static void __maybe_unused crypto_lskcipher_show(
250         struct seq_file *m, struct crypto_alg *alg)
251 {
252         struct lskcipher_alg *skcipher = __crypto_lskcipher_alg(alg);
253
254         seq_printf(m, "type         : lskcipher\n");
255         seq_printf(m, "blocksize    : %u\n", alg->cra_blocksize);
256         seq_printf(m, "min keysize  : %u\n", skcipher->co.min_keysize);
257         seq_printf(m, "max keysize  : %u\n", skcipher->co.max_keysize);
258         seq_printf(m, "ivsize       : %u\n", skcipher->co.ivsize);
259         seq_printf(m, "chunksize    : %u\n", skcipher->co.chunksize);
260         seq_printf(m, "statesize    : %u\n", skcipher->co.statesize);
261 }
262
263 static int __maybe_unused crypto_lskcipher_report(
264         struct sk_buff *skb, struct crypto_alg *alg)
265 {
266         struct lskcipher_alg *skcipher = __crypto_lskcipher_alg(alg);
267         struct crypto_report_blkcipher rblkcipher;
268
269         memset(&rblkcipher, 0, sizeof(rblkcipher));
270
271         strscpy(rblkcipher.type, "lskcipher", sizeof(rblkcipher.type));
272         strscpy(rblkcipher.geniv, "<none>", sizeof(rblkcipher.geniv));
273
274         rblkcipher.blocksize = alg->cra_blocksize;
275         rblkcipher.min_keysize = skcipher->co.min_keysize;
276         rblkcipher.max_keysize = skcipher->co.max_keysize;
277         rblkcipher.ivsize = skcipher->co.ivsize;
278
279         return nla_put(skb, CRYPTOCFGA_REPORT_BLKCIPHER,
280                        sizeof(rblkcipher), &rblkcipher);
281 }
282
283 static const struct crypto_type crypto_lskcipher_type = {
284         .extsize = crypto_alg_extsize,
285         .init_tfm = crypto_lskcipher_init_tfm,
286         .free = crypto_lskcipher_free_instance,
287 #ifdef CONFIG_PROC_FS
288         .show = crypto_lskcipher_show,
289 #endif
290 #if IS_ENABLED(CONFIG_CRYPTO_USER)
291         .report = crypto_lskcipher_report,
292 #endif
293         .maskclear = ~CRYPTO_ALG_TYPE_MASK,
294         .maskset = CRYPTO_ALG_TYPE_MASK,
295         .type = CRYPTO_ALG_TYPE_LSKCIPHER,
296         .tfmsize = offsetof(struct crypto_lskcipher, base),
297 };
298
299 static void crypto_lskcipher_exit_tfm_sg(struct crypto_tfm *tfm)
300 {
301         struct crypto_lskcipher **ctx = crypto_tfm_ctx(tfm);
302
303         crypto_free_lskcipher(*ctx);
304 }
305
306 int crypto_init_lskcipher_ops_sg(struct crypto_tfm *tfm)
307 {
308         struct crypto_lskcipher **ctx = crypto_tfm_ctx(tfm);
309         struct crypto_alg *calg = tfm->__crt_alg;
310         struct crypto_lskcipher *skcipher;
311
312         if (!crypto_mod_get(calg))
313                 return -EAGAIN;
314
315         skcipher = crypto_create_tfm(calg, &crypto_lskcipher_type);
316         if (IS_ERR(skcipher)) {
317                 crypto_mod_put(calg);
318                 return PTR_ERR(skcipher);
319         }
320
321         *ctx = skcipher;
322         tfm->exit = crypto_lskcipher_exit_tfm_sg;
323
324         return 0;
325 }
326
327 int crypto_grab_lskcipher(struct crypto_lskcipher_spawn *spawn,
328                           struct crypto_instance *inst,
329                           const char *name, u32 type, u32 mask)
330 {
331         spawn->base.frontend = &crypto_lskcipher_type;
332         return crypto_grab_spawn(&spawn->base, inst, name, type, mask);
333 }
334 EXPORT_SYMBOL_GPL(crypto_grab_lskcipher);
335
336 struct crypto_lskcipher *crypto_alloc_lskcipher(const char *alg_name,
337                                                 u32 type, u32 mask)
338 {
339         return crypto_alloc_tfm(alg_name, &crypto_lskcipher_type, type, mask);
340 }
341 EXPORT_SYMBOL_GPL(crypto_alloc_lskcipher);
342
343 static int lskcipher_prepare_alg(struct lskcipher_alg *alg)
344 {
345         struct crypto_alg *base = &alg->co.base;
346         int err;
347
348         err = skcipher_prepare_alg_common(&alg->co);
349         if (err)
350                 return err;
351
352         if (alg->co.chunksize & (alg->co.chunksize - 1))
353                 return -EINVAL;
354
355         base->cra_type = &crypto_lskcipher_type;
356         base->cra_flags |= CRYPTO_ALG_TYPE_LSKCIPHER;
357
358         return 0;
359 }
360
361 int crypto_register_lskcipher(struct lskcipher_alg *alg)
362 {
363         struct crypto_alg *base = &alg->co.base;
364         int err;
365
366         err = lskcipher_prepare_alg(alg);
367         if (err)
368                 return err;
369
370         return crypto_register_alg(base);
371 }
372 EXPORT_SYMBOL_GPL(crypto_register_lskcipher);
373
374 void crypto_unregister_lskcipher(struct lskcipher_alg *alg)
375 {
376         crypto_unregister_alg(&alg->co.base);
377 }
378 EXPORT_SYMBOL_GPL(crypto_unregister_lskcipher);
379
380 int crypto_register_lskciphers(struct lskcipher_alg *algs, int count)
381 {
382         int i, ret;
383
384         for (i = 0; i < count; i++) {
385                 ret = crypto_register_lskcipher(&algs[i]);
386                 if (ret)
387                         goto err;
388         }
389
390         return 0;
391
392 err:
393         for (--i; i >= 0; --i)
394                 crypto_unregister_lskcipher(&algs[i]);
395
396         return ret;
397 }
398 EXPORT_SYMBOL_GPL(crypto_register_lskciphers);
399
400 void crypto_unregister_lskciphers(struct lskcipher_alg *algs, int count)
401 {
402         int i;
403
404         for (i = count - 1; i >= 0; --i)
405                 crypto_unregister_lskcipher(&algs[i]);
406 }
407 EXPORT_SYMBOL_GPL(crypto_unregister_lskciphers);
408
409 int lskcipher_register_instance(struct crypto_template *tmpl,
410                                 struct lskcipher_instance *inst)
411 {
412         int err;
413
414         if (WARN_ON(!inst->free))
415                 return -EINVAL;
416
417         err = lskcipher_prepare_alg(&inst->alg);
418         if (err)
419                 return err;
420
421         return crypto_register_instance(tmpl, lskcipher_crypto_instance(inst));
422 }
423 EXPORT_SYMBOL_GPL(lskcipher_register_instance);
424
425 static int lskcipher_setkey_simple(struct crypto_lskcipher *tfm, const u8 *key,
426                                    unsigned int keylen)
427 {
428         struct crypto_lskcipher *cipher = lskcipher_cipher_simple(tfm);
429
430         crypto_lskcipher_clear_flags(cipher, CRYPTO_TFM_REQ_MASK);
431         crypto_lskcipher_set_flags(cipher, crypto_lskcipher_get_flags(tfm) &
432                                    CRYPTO_TFM_REQ_MASK);
433         return crypto_lskcipher_setkey(cipher, key, keylen);
434 }
435
436 static int lskcipher_init_tfm_simple(struct crypto_lskcipher *tfm)
437 {
438         struct lskcipher_instance *inst = lskcipher_alg_instance(tfm);
439         struct crypto_lskcipher **ctx = crypto_lskcipher_ctx(tfm);
440         struct crypto_lskcipher_spawn *spawn;
441         struct crypto_lskcipher *cipher;
442
443         spawn = lskcipher_instance_ctx(inst);
444         cipher = crypto_spawn_lskcipher(spawn);
445         if (IS_ERR(cipher))
446                 return PTR_ERR(cipher);
447
448         *ctx = cipher;
449         return 0;
450 }
451
452 static void lskcipher_exit_tfm_simple(struct crypto_lskcipher *tfm)
453 {
454         struct crypto_lskcipher **ctx = crypto_lskcipher_ctx(tfm);
455
456         crypto_free_lskcipher(*ctx);
457 }
458
459 static void lskcipher_free_instance_simple(struct lskcipher_instance *inst)
460 {
461         crypto_drop_lskcipher(lskcipher_instance_ctx(inst));
462         kfree(inst);
463 }
464
465 /**
466  * lskcipher_alloc_instance_simple - allocate instance of simple block cipher
467  *
468  * Allocate an lskcipher_instance for a simple block cipher mode of operation,
469  * e.g. cbc or ecb.  The instance context will have just a single crypto_spawn,
470  * that for the underlying cipher.  The {min,max}_keysize, ivsize, blocksize,
471  * alignmask, and priority are set from the underlying cipher but can be
472  * overridden if needed.  The tfm context defaults to
473  * struct crypto_lskcipher *, and default ->setkey(), ->init(), and
474  * ->exit() methods are installed.
475  *
476  * @tmpl: the template being instantiated
477  * @tb: the template parameters
478  *
479  * Return: a pointer to the new instance, or an ERR_PTR().  The caller still
480  *         needs to register the instance.
481  */
482 struct lskcipher_instance *lskcipher_alloc_instance_simple(
483         struct crypto_template *tmpl, struct rtattr **tb)
484 {
485         u32 mask;
486         struct lskcipher_instance *inst;
487         struct crypto_lskcipher_spawn *spawn;
488         char ecb_name[CRYPTO_MAX_ALG_NAME];
489         struct lskcipher_alg *cipher_alg;
490         const char *cipher_name;
491         int err;
492
493         err = crypto_check_attr_type(tb, CRYPTO_ALG_TYPE_LSKCIPHER, &mask);
494         if (err)
495                 return ERR_PTR(err);
496
497         cipher_name = crypto_attr_alg_name(tb[1]);
498         if (IS_ERR(cipher_name))
499                 return ERR_CAST(cipher_name);
500
501         inst = kzalloc(sizeof(*inst) + sizeof(*spawn), GFP_KERNEL);
502         if (!inst)
503                 return ERR_PTR(-ENOMEM);
504
505         spawn = lskcipher_instance_ctx(inst);
506         err = crypto_grab_lskcipher(spawn,
507                                     lskcipher_crypto_instance(inst),
508                                     cipher_name, 0, mask);
509
510         ecb_name[0] = 0;
511         if (err == -ENOENT && !!memcmp(tmpl->name, "ecb", 4)) {
512                 err = -ENAMETOOLONG;
513                 if (snprintf(ecb_name, CRYPTO_MAX_ALG_NAME, "ecb(%s)",
514                              cipher_name) >= CRYPTO_MAX_ALG_NAME)
515                         goto err_free_inst;
516
517                 err = crypto_grab_lskcipher(spawn,
518                                             lskcipher_crypto_instance(inst),
519                                             ecb_name, 0, mask);
520         }
521
522         if (err)
523                 goto err_free_inst;
524
525         cipher_alg = crypto_lskcipher_spawn_alg(spawn);
526
527         err = crypto_inst_setname(lskcipher_crypto_instance(inst), tmpl->name,
528                                   &cipher_alg->co.base);
529         if (err)
530                 goto err_free_inst;
531
532         if (ecb_name[0]) {
533                 int len;
534
535                 err = -EINVAL;
536                 len = strscpy(ecb_name, &cipher_alg->co.base.cra_name[4],
537                               sizeof(ecb_name));
538                 if (len < 2)
539                         goto err_free_inst;
540
541                 if (ecb_name[len - 1] != ')')
542                         goto err_free_inst;
543
544                 ecb_name[len - 1] = 0;
545
546                 err = -ENAMETOOLONG;
547                 if (snprintf(inst->alg.co.base.cra_name, CRYPTO_MAX_ALG_NAME,
548                              "%s(%s)", tmpl->name, ecb_name) >=
549                     CRYPTO_MAX_ALG_NAME)
550                         goto err_free_inst;
551
552                 if (strcmp(ecb_name, cipher_name) &&
553                     snprintf(inst->alg.co.base.cra_driver_name,
554                              CRYPTO_MAX_ALG_NAME,
555                              "%s(%s)", tmpl->name, cipher_name) >=
556                     CRYPTO_MAX_ALG_NAME)
557                         goto err_free_inst;
558         } else {
559                 /* Don't allow nesting. */
560                 err = -ELOOP;
561                 if ((cipher_alg->co.base.cra_flags & CRYPTO_ALG_INSTANCE))
562                         goto err_free_inst;
563         }
564
565         err = -EINVAL;
566         if (cipher_alg->co.ivsize)
567                 goto err_free_inst;
568
569         inst->free = lskcipher_free_instance_simple;
570
571         /* Default algorithm properties, can be overridden */
572         inst->alg.co.base.cra_blocksize = cipher_alg->co.base.cra_blocksize;
573         inst->alg.co.base.cra_alignmask = cipher_alg->co.base.cra_alignmask;
574         inst->alg.co.base.cra_priority = cipher_alg->co.base.cra_priority;
575         inst->alg.co.min_keysize = cipher_alg->co.min_keysize;
576         inst->alg.co.max_keysize = cipher_alg->co.max_keysize;
577         inst->alg.co.ivsize = cipher_alg->co.base.cra_blocksize;
578         inst->alg.co.statesize = cipher_alg->co.statesize;
579
580         /* Use struct crypto_lskcipher * by default, can be overridden */
581         inst->alg.co.base.cra_ctxsize = sizeof(struct crypto_lskcipher *);
582         inst->alg.setkey = lskcipher_setkey_simple;
583         inst->alg.init = lskcipher_init_tfm_simple;
584         inst->alg.exit = lskcipher_exit_tfm_simple;
585
586         return inst;
587
588 err_free_inst:
589         lskcipher_free_instance_simple(inst);
590         return ERR_PTR(err);
591 }
592 EXPORT_SYMBOL_GPL(lskcipher_alloc_instance_simple);