crypto: riscv/chacha - implement library instead of skcipher
author     Eric Biggers <ebiggers@google.com>
           Sat, 5 Apr 2025 18:26:01 +0000 (11:26 -0700)
committer  Herbert Xu <herbert@gondor.apana.org.au>
           Mon, 7 Apr 2025 05:22:28 +0000 (13:22 +0800)
Currently the RISC-V optimized ChaCha20 is only wired up to the
crypto_skcipher API, which makes it unavailable to users of the library
API.  The crypto_skcipher API for ChaCha20 is going to change to be
implemented on top of the library API, so the library API needs to be
supported.  And of course it's needed anyway to serve the library users.

Therefore, change the RISC-V ChaCha20 code to implement the library API
instead of the crypto_skcipher API.

The library functions take the ChaCha state matrix directly (instead of
key and IV) and support both ChaCha20 and ChaCha12.  To make the RISC-V
code work properly for that, change the assembly code to take the state
matrix directly and add a nrounds parameter.
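
For reference, a library-API caller now supplies the state matrix and round
count itself.  A minimal sketch against the <crypto/chacha.h> interface of
this era (the function and buffer names below are illustrative, not part of
this patch):

    #include <crypto/chacha.h>

    /*
     * Illustrative only: ChaCha20-encrypt @len bytes via the library API.
     * @key is the 256-bit key as u32[8]; @iv is counter || nonce as u8[16].
     * Pass 12 instead of 20 as the nrounds argument for ChaCha12.
     */
    static void chacha20_example(const u32 key[8], const u8 iv[16],
                                 u8 *dst, const u8 *src, unsigned int len)
    {
            u32 state[16];

            chacha_init(state, key, iv);
            chacha_crypt(state, dst, src, len, 20);
    }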

Signed-off-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/riscv/crypto/Kconfig
arch/riscv/crypto/chacha-riscv64-glue.c
arch/riscv/crypto/chacha-riscv64-zvkb.S

diff --git a/arch/riscv/crypto/Kconfig b/arch/riscv/crypto/Kconfig
index c67095a3d66907ce53730463930fc8ba0c023d9b..6392e1e11bc9626dfdef12caa4e9a32919bf1aa8 100644
--- a/arch/riscv/crypto/Kconfig
+++ b/arch/riscv/crypto/Kconfig
@@ -19,14 +19,11 @@ config CRYPTO_AES_RISCV64
          - Zvkg vector crypto extension (XTS)
 
 config CRYPTO_CHACHA_RISCV64
-       tristate "Ciphers: ChaCha"
+       tristate
        depends on 64BIT && RISCV_ISA_V && TOOLCHAIN_HAS_VECTOR_CRYPTO
-       select CRYPTO_SKCIPHER
-       help
-         Length-preserving ciphers: ChaCha20 stream cipher algorithm
-
-         Architecture: riscv64 using:
-         - Zvkb vector crypto extension
+       select CRYPTO_ARCH_HAVE_LIB_CHACHA
+       select CRYPTO_LIB_CHACHA_GENERIC
+       default CRYPTO_LIB_CHACHA_INTERNAL
 
 config CRYPTO_GHASH_RISCV64
        tristate "Hash functions: GHASH"
diff --git a/arch/riscv/crypto/chacha-riscv64-glue.c b/arch/riscv/crypto/chacha-riscv64-glue.c
index 10b46f36375affaaf4336d2144fcd352b53f4cc7..68caef7a3d50b43924421aefedea5be945f33781 100644
--- a/arch/riscv/crypto/chacha-riscv64-glue.c
+++ b/arch/riscv/crypto/chacha-riscv64-glue.c
@@ -1,6 +1,6 @@
 // SPDX-License-Identifier: GPL-2.0-only
 /*
- * ChaCha20 using the RISC-V vector crypto extensions
+ * ChaCha stream cipher (RISC-V optimized)
  *
  * Copyright (C) 2023 SiFive, Inc.
  * Author: Jerry Shih <jerry.shih@sifive.com>
@@ -8,94 +8,56 @@
 
 #include <asm/simd.h>
 #include <asm/vector.h>
-#include <crypto/internal/chacha.h>
-#include <crypto/internal/skcipher.h>
+#include <crypto/chacha.h>
+#include <crypto/internal/simd.h>
 #include <linux/linkage.h>
 #include <linux/module.h>
 
-asmlinkage void chacha20_zvkb(const u32 key[8], const u8 *in, u8 *out,
-                             size_t len, const u32 iv[4]);
+static __ro_after_init DEFINE_STATIC_KEY_FALSE(use_zvkb);
 
-static int riscv64_chacha20_crypt(struct skcipher_request *req)
+asmlinkage void chacha_zvkb(u32 state[16], const u8 *in, u8 *out,
+                           size_t nblocks, int nrounds);
+
+void hchacha_block_arch(const u32 *state, u32 *out, int nrounds)
 {
-       u32 iv[CHACHA_IV_SIZE / sizeof(u32)];
-       u8 block_buffer[CHACHA_BLOCK_SIZE];
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       const struct chacha_ctx *ctx = crypto_skcipher_ctx(tfm);
-       struct skcipher_walk walk;
-       unsigned int nbytes;
-       unsigned int tail_bytes;
-       int err;
+       hchacha_block_generic(state, out, nrounds);
+}
+EXPORT_SYMBOL(hchacha_block_arch);
 
-       iv[0] = get_unaligned_le32(req->iv);
-       iv[1] = get_unaligned_le32(req->iv + 4);
-       iv[2] = get_unaligned_le32(req->iv + 8);
-       iv[3] = get_unaligned_le32(req->iv + 12);
+void chacha_crypt_arch(u32 *state, u8 *dst, const u8 *src, unsigned int bytes,
+                      int nrounds)
+{
+       u8 block_buffer[CHACHA_BLOCK_SIZE];
+       unsigned int full_blocks = bytes / CHACHA_BLOCK_SIZE;
+       unsigned int tail_bytes = bytes % CHACHA_BLOCK_SIZE;
 
-       err = skcipher_walk_virt(&walk, req, false);
-       while (walk.nbytes) {
-               nbytes = walk.nbytes & ~(CHACHA_BLOCK_SIZE - 1);
-               tail_bytes = walk.nbytes & (CHACHA_BLOCK_SIZE - 1);
-               kernel_vector_begin();
-               if (nbytes) {
-                       chacha20_zvkb(ctx->key, walk.src.virt.addr,
-                                     walk.dst.virt.addr, nbytes, iv);
-                       iv[0] += nbytes / CHACHA_BLOCK_SIZE;
-               }
-               if (walk.nbytes == walk.total && tail_bytes > 0) {
-                       memcpy(block_buffer, walk.src.virt.addr + nbytes,
-                              tail_bytes);
-                       chacha20_zvkb(ctx->key, block_buffer, block_buffer,
-                                     CHACHA_BLOCK_SIZE, iv);
-                       memcpy(walk.dst.virt.addr + nbytes, block_buffer,
-                              tail_bytes);
-                       tail_bytes = 0;
-               }
-               kernel_vector_end();
+       if (!static_branch_likely(&use_zvkb) || !crypto_simd_usable())
+               return chacha_crypt_generic(state, dst, src, bytes, nrounds);
 
-               err = skcipher_walk_done(&walk, tail_bytes);
+       kernel_vector_begin();
+       if (full_blocks) {
+               chacha_zvkb(state, src, dst, full_blocks, nrounds);
+               src += full_blocks * CHACHA_BLOCK_SIZE;
+               dst += full_blocks * CHACHA_BLOCK_SIZE;
        }
-
-       return err;
+       if (tail_bytes) {
+               memcpy(block_buffer, src, tail_bytes);
+               chacha_zvkb(state, block_buffer, block_buffer, 1, nrounds);
+               memcpy(dst, block_buffer, tail_bytes);
+       }
+       kernel_vector_end();
 }
-
-static struct skcipher_alg riscv64_chacha_alg = {
-       .setkey = chacha20_setkey,
-       .encrypt = riscv64_chacha20_crypt,
-       .decrypt = riscv64_chacha20_crypt,
-       .min_keysize = CHACHA_KEY_SIZE,
-       .max_keysize = CHACHA_KEY_SIZE,
-       .ivsize = CHACHA_IV_SIZE,
-       .chunksize = CHACHA_BLOCK_SIZE,
-       .walksize = 4 * CHACHA_BLOCK_SIZE,
-       .base = {
-               .cra_blocksize = 1,
-               .cra_ctxsize = sizeof(struct chacha_ctx),
-               .cra_priority = 300,
-               .cra_name = "chacha20",
-               .cra_driver_name = "chacha20-riscv64-zvkb",
-               .cra_module = THIS_MODULE,
-       },
-};
+EXPORT_SYMBOL(chacha_crypt_arch);
 
 static int __init riscv64_chacha_mod_init(void)
 {
        if (riscv_isa_extension_available(NULL, ZVKB) &&
            riscv_vector_vlen() >= 128)
-               return crypto_register_skcipher(&riscv64_chacha_alg);
-
-       return -ENODEV;
-}
-
-static void __exit riscv64_chacha_mod_exit(void)
-{
-       crypto_unregister_skcipher(&riscv64_chacha_alg);
+               static_branch_enable(&use_zvkb);
+       return 0;
 }
-
 module_init(riscv64_chacha_mod_init);
-module_exit(riscv64_chacha_mod_exit);
 
-MODULE_DESCRIPTION("ChaCha20 (RISC-V accelerated)");
+MODULE_DESCRIPTION("ChaCha stream cipher (RISC-V optimized)");
 MODULE_AUTHOR("Jerry Shih <jerry.shih@sifive.com>");
 MODULE_LICENSE("GPL");
-MODULE_ALIAS_CRYPTO("chacha20");
diff --git a/arch/riscv/crypto/chacha-riscv64-zvkb.S b/arch/riscv/crypto/chacha-riscv64-zvkb.S
index bf057737ac6935587a878d057fa823b545bc0529..ab4423b3880eaf970d222288be8920f065fbba67 100644
--- a/arch/riscv/crypto/chacha-riscv64-zvkb.S
+++ b/arch/riscv/crypto/chacha-riscv64-zvkb.S
 .text
 .option arch, +zvkb
 
-#define KEYP           a0
+#define STATEP         a0
 #define INP            a1
 #define OUTP           a2
-#define LEN            a3
-#define IVP            a4
+#define NBLOCKS                a3
+#define NROUNDS                a4
 
 #define CONSTS0                a5
 #define CONSTS1                a6
@@ -59,7 +59,7 @@
 #define TMP            t1
 #define VL             t2
 #define STRIDE         t3
-#define NROUNDS                t4
+#define ROUND_CTR      t4
 #define KEY0           s0
 #define KEY1           s1
 #define KEY2           s2
        vror.vi         \b3, \b3, 32 - 7
 .endm
 
-// void chacha20_zvkb(const u32 key[8], const u8 *in, u8 *out, size_t len,
-//                   const u32 iv[4]);
+// void chacha_zvkb(u32 state[16], const u8 *in, u8 *out, size_t nblocks,
+//                 int nrounds);
 //
-// |len| must be nonzero and a multiple of 64 (CHACHA_BLOCK_SIZE).
-// The counter is treated as 32-bit, following the RFC7539 convention.
-SYM_FUNC_START(chacha20_zvkb)
-       srli            LEN, LEN, 6     // Bytes to blocks
-
+// |nblocks| is the number of 64-byte blocks to process, and must be nonzero.
+//
+// |state| gives the ChaCha state matrix, including the 32-bit counter in
+// state[12] following the RFC7539 convention; note that this differs from the
+// original Salsa20 paper which uses a 64-bit counter in state[12..13].  The
+// updated 32-bit counter is written back to state[12] before returning.
+SYM_FUNC_START(chacha_zvkb)
        addi            sp, sp, -96
        sd              s0, 0(sp)
        sd              s1, 8(sp)
@@ -157,26 +159,26 @@ SYM_FUNC_START(chacha20_zvkb)
        li              STRIDE, 64
 
        // Set up the initial state matrix in scalar registers.
-       li              CONSTS0, 0x61707865     // "expa" little endian
-       li              CONSTS1, 0x3320646e     // "nd 3" little endian
-       li              CONSTS2, 0x79622d32     // "2-by" little endian
-       li              CONSTS3, 0x6b206574     // "te k" little endian
-       lw              KEY0, 0(KEYP)
-       lw              KEY1, 4(KEYP)
-       lw              KEY2, 8(KEYP)
-       lw              KEY3, 12(KEYP)
-       lw              KEY4, 16(KEYP)
-       lw              KEY5, 20(KEYP)
-       lw              KEY6, 24(KEYP)
-       lw              KEY7, 28(KEYP)
-       lw              COUNTER, 0(IVP)
-       lw              NONCE0, 4(IVP)
-       lw              NONCE1, 8(IVP)
-       lw              NONCE2, 12(IVP)
+       lw              CONSTS0, 0(STATEP)
+       lw              CONSTS1, 4(STATEP)
+       lw              CONSTS2, 8(STATEP)
+       lw              CONSTS3, 12(STATEP)
+       lw              KEY0, 16(STATEP)
+       lw              KEY1, 20(STATEP)
+       lw              KEY2, 24(STATEP)
+       lw              KEY3, 28(STATEP)
+       lw              KEY4, 32(STATEP)
+       lw              KEY5, 36(STATEP)
+       lw              KEY6, 40(STATEP)
+       lw              KEY7, 44(STATEP)
+       lw              COUNTER, 48(STATEP)
+       lw              NONCE0, 52(STATEP)
+       lw              NONCE1, 56(STATEP)
+       lw              NONCE2, 60(STATEP)
 
 .Lblock_loop:
        // Set vl to the number of blocks to process in this iteration.
-       vsetvli         VL, LEN, e32, m1, ta, ma
+       vsetvli         VL, NBLOCKS, e32, m1, ta, ma
 
        // Set up the initial state matrix for the next VL blocks in v0-v15.
        // v{i} holds the i'th 32-bit word of the state matrix for all blocks.
@@ -203,16 +205,16 @@ SYM_FUNC_START(chacha20_zvkb)
        // v{16+i} holds the i'th 32-bit word for all blocks.
        vlsseg8e32.v    v16, (INP), STRIDE
 
-       li              NROUNDS, 20
+       mv              ROUND_CTR, NROUNDS
 .Lnext_doubleround:
-       addi            NROUNDS, NROUNDS, -2
+       addi            ROUND_CTR, ROUND_CTR, -2
        // column round
        chacha_round    v0, v4, v8, v12, v1, v5, v9, v13, \
                        v2, v6, v10, v14, v3, v7, v11, v15
        // diagonal round
        chacha_round    v0, v5, v10, v15, v1, v6, v11, v12, \
                        v2, v7, v8, v13, v3, v4, v9, v14
-       bnez            NROUNDS, .Lnext_doubleround
+       bnez            ROUND_CTR, .Lnext_doubleround
 
        // Load the second half of the input data for each block into v24-v31.
        // v{24+i} holds the {8+i}'th 32-bit word for all blocks.
@@ -271,12 +273,13 @@ SYM_FUNC_START(chacha20_zvkb)
        // Update the counter, the remaining number of blocks, and the input and
        // output pointers according to the number of blocks processed (VL).
        add             COUNTER, COUNTER, VL
-       sub             LEN, LEN, VL
+       sub             NBLOCKS, NBLOCKS, VL
        slli            TMP, VL, 6
        add             OUTP, OUTP, TMP
        add             INP, INP, TMP
-       bnez            LEN, .Lblock_loop
+       bnez            NBLOCKS, .Lblock_loop
 
+       sw              COUNTER, 48(STATEP)
        ld              s0, 0(sp)
        ld              s1, 8(sp)
        ld              s2, 16(sp)
@@ -291,4 +294,4 @@ SYM_FUNC_START(chacha20_zvkb)
        ld              s11, 88(sp)
        addi            sp, sp, 96
        ret
-SYM_FUNC_END(chacha20_zvkb)
+SYM_FUNC_END(chacha_zvkb)
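
A note on the counter write-back added above: because chacha_zvkb() now
stores the updated counter to state[12] before returning, a library caller
can stream one message through multiple chacha_crypt() calls and the state
stays coherent, provided every chunk but the last is a whole number of
64-byte blocks.  A sketch of that pattern (illustrative names, same
<crypto/chacha.h> assumptions as the earlier example):

    #include <crypto/chacha.h>
    #include <linux/minmax.h>

    /*
     * Illustrative only: ChaCha20-encrypt @len bytes in place, 1024 bytes
     * (16 blocks) at a time.  Whole-block chunks keep the counter that
     * chacha_crypt() maintains in state[12] valid across calls; only the
     * final chunk may end in a partial block.
     */
    static void chacha20_stream_example(const u32 key[8], const u8 iv[16],
                                        u8 *buf, size_t len)
    {
            u32 state[16];

            chacha_init(state, key, iv);
            while (len) {
                    size_t n = min_t(size_t, len, 1024);

                    chacha_crypt(state, buf, buf, n, 20);
                    buf += n;
                    len -= n;
            }
    }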