crypto: arm64/sha3-ce - Use API partial block handling
authorHerbert Xu <herbert@gondor.apana.org.au>
Fri, 18 Apr 2025 03:00:11 +0000 (11:00 +0800)
committerHerbert Xu <herbert@gondor.apana.org.au>
Wed, 23 Apr 2025 07:52:46 +0000 (15:52 +0800)
Use the Crypto API partial block handling.

Also remove the unnecessary SIMD fallback path.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm64/crypto/sha3-ce-glue.c
arch/s390/crypto/sha.h
include/crypto/sha3.h

index 5662c3ac49e91ccdbba250bc7a0835a84621c36e..b4f1001046c9a19dcc7498b81670157af1d85f70 100644 (file)
 #include <asm/hwcap.h>
 #include <asm/neon.h>
 #include <asm/simd.h>
-#include <linux/unaligned.h>
 #include <crypto/internal/hash.h>
-#include <crypto/internal/simd.h>
 #include <crypto/sha3.h>
 #include <linux/cpufeature.h>
-#include <linux/crypto.h>
+#include <linux/kernel.h>
 #include <linux/module.h>
+#include <linux/string.h>
+#include <linux/unaligned.h>
 
 MODULE_DESCRIPTION("SHA3 secure hash using ARMv8 Crypto Extensions");
 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
@@ -35,74 +35,55 @@ static int sha3_update(struct shash_desc *desc, const u8 *data,
                       unsigned int len)
 {
        struct sha3_state *sctx = shash_desc_ctx(desc);
-       unsigned int digest_size = crypto_shash_digestsize(desc->tfm);
-
-       if (!crypto_simd_usable())
-               return crypto_sha3_update(desc, data, len);
-
-       if ((sctx->partial + len) >= sctx->rsiz) {
-               int blocks;
-
-               if (sctx->partial) {
-                       int p = sctx->rsiz - sctx->partial;
-
-                       memcpy(sctx->buf + sctx->partial, data, p);
-                       kernel_neon_begin();
-                       sha3_ce_transform(sctx->st, sctx->buf, 1, digest_size);
-                       kernel_neon_end();
-
-                       data += p;
-                       len -= p;
-                       sctx->partial = 0;
-               }
-
-               blocks = len / sctx->rsiz;
-               len %= sctx->rsiz;
-
-               while (blocks) {
-                       int rem;
-
-                       kernel_neon_begin();
-                       rem = sha3_ce_transform(sctx->st, data, blocks,
-                                               digest_size);
-                       kernel_neon_end();
-                       data += (blocks - rem) * sctx->rsiz;
-                       blocks = rem;
-               }
-       }
-
-       if (len) {
-               memcpy(sctx->buf + sctx->partial, data, len);
-               sctx->partial += len;
-       }
-       return 0;
+       struct crypto_shash *tfm = desc->tfm;
+       unsigned int bs, ds;
+       int blocks;
+
+       ds = crypto_shash_digestsize(tfm);
+       bs = crypto_shash_blocksize(tfm);
+       blocks = len / bs;
+       len -= blocks * bs;
+       do {
+               int rem;
+
+               kernel_neon_begin();
+               rem = sha3_ce_transform(sctx->st, data, blocks, ds);
+               kernel_neon_end();
+               data += (blocks - rem) * bs;
+               blocks = rem;
+       } while (blocks);
+       return len;
 }
 
-static int sha3_final(struct shash_desc *desc, u8 *out)
+static int sha3_finup(struct shash_desc *desc, const u8 *src, unsigned int len,
+                     u8 *out)
 {
        struct sha3_state *sctx = shash_desc_ctx(desc);
-       unsigned int digest_size = crypto_shash_digestsize(desc->tfm);
+       struct crypto_shash *tfm = desc->tfm;
        __le64 *digest = (__le64 *)out;
+       u8 block[SHA3_224_BLOCK_SIZE];
+       unsigned int bs, ds;
        int i;
 
-       if (!crypto_simd_usable())
-               return crypto_sha3_final(desc, out);
+       ds = crypto_shash_digestsize(tfm);
+       bs = crypto_shash_blocksize(tfm);
+       memcpy(block, src, len);
 
-       sctx->buf[sctx->partial++] = 0x06;
-       memset(sctx->buf + sctx->partial, 0, sctx->rsiz - sctx->partial);
-       sctx->buf[sctx->rsiz - 1] |= 0x80;
+       block[len++] = 0x06;
+       memset(block + len, 0, bs - len);
+       block[bs - 1] |= 0x80;
 
        kernel_neon_begin();
-       sha3_ce_transform(sctx->st, sctx->buf, 1, digest_size);
+       sha3_ce_transform(sctx->st, block, 1, ds);
        kernel_neon_end();
+       memzero_explicit(block , sizeof(block));
 
-       for (i = 0; i < digest_size / 8; i++)
+       for (i = 0; i < ds / 8; i++)
                put_unaligned_le64(sctx->st[i], digest++);
 
-       if (digest_size & 4)
+       if (ds & 4)
                put_unaligned_le32(sctx->st[i], (__le32 *)digest);
 
-       memzero_explicit(sctx, sizeof(*sctx));
        return 0;
 }
 
@@ -110,10 +91,11 @@ static struct shash_alg algs[] = { {
        .digestsize             = SHA3_224_DIGEST_SIZE,
        .init                   = crypto_sha3_init,
        .update                 = sha3_update,
-       .final                  = sha3_final,
-       .descsize               = sizeof(struct sha3_state),
+       .finup                  = sha3_finup,
+       .descsize               = SHA3_STATE_SIZE,
        .base.cra_name          = "sha3-224",
        .base.cra_driver_name   = "sha3-224-ce",
+       .base.cra_flags         = CRYPTO_AHASH_ALG_BLOCK_ONLY,
        .base.cra_blocksize     = SHA3_224_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,
        .base.cra_priority      = 200,
@@ -121,10 +103,11 @@ static struct shash_alg algs[] = { {
        .digestsize             = SHA3_256_DIGEST_SIZE,
        .init                   = crypto_sha3_init,
        .update                 = sha3_update,
-       .final                  = sha3_final,
-       .descsize               = sizeof(struct sha3_state),
+       .finup                  = sha3_finup,
+       .descsize               = SHA3_STATE_SIZE,
        .base.cra_name          = "sha3-256",
        .base.cra_driver_name   = "sha3-256-ce",
+       .base.cra_flags         = CRYPTO_AHASH_ALG_BLOCK_ONLY,
        .base.cra_blocksize     = SHA3_256_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,
        .base.cra_priority      = 200,
@@ -132,10 +115,11 @@ static struct shash_alg algs[] = { {
        .digestsize             = SHA3_384_DIGEST_SIZE,
        .init                   = crypto_sha3_init,
        .update                 = sha3_update,
-       .final                  = sha3_final,
-       .descsize               = sizeof(struct sha3_state),
+       .finup                  = sha3_finup,
+       .descsize               = SHA3_STATE_SIZE,
        .base.cra_name          = "sha3-384",
        .base.cra_driver_name   = "sha3-384-ce",
+       .base.cra_flags         = CRYPTO_AHASH_ALG_BLOCK_ONLY,
        .base.cra_blocksize     = SHA3_384_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,
        .base.cra_priority      = 200,
@@ -143,10 +127,11 @@ static struct shash_alg algs[] = { {
        .digestsize             = SHA3_512_DIGEST_SIZE,
        .init                   = crypto_sha3_init,
        .update                 = sha3_update,
-       .final                  = sha3_final,
-       .descsize               = sizeof(struct sha3_state),
+       .finup                  = sha3_finup,
+       .descsize               = SHA3_STATE_SIZE,
        .base.cra_name          = "sha3-512",
        .base.cra_driver_name   = "sha3-512-ce",
+       .base.cra_flags         = CRYPTO_AHASH_ALG_BLOCK_ONLY,
        .base.cra_blocksize     = SHA3_512_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,
        .base.cra_priority      = 200,
index b8aeb51b2f3d1de229fe9e5b1c04245965f4feaf..d95437ebe1caef43b3a9d37b711952cc8c8402c1 100644 (file)
@@ -14,7 +14,6 @@
 #include <linux/types.h>
 
 /* must be big enough for the largest SHA variant */
-#define SHA3_STATE_SIZE                        200
 #define CPACF_MAX_PARMBLOCK_SIZE       SHA3_STATE_SIZE
 #define SHA_MAX_BLOCK_SIZE             SHA3_224_BLOCK_SIZE
 #define S390_SHA_CTX_SIZE              offsetof(struct s390_sha_ctx, buf)
index 080f60c2e6b16012bd47d931b5179e263859609c..661f196193cf7490a6b0285c78f8f2d7d46d432d 100644 (file)
 #define SHA3_512_DIGEST_SIZE   (512 / 8)
 #define SHA3_512_BLOCK_SIZE    (200 - 2 * SHA3_512_DIGEST_SIZE)
 
+#define SHA3_STATE_SIZE                200
+
 struct sha3_state {
-       u64             st[25];
+       u64             st[SHA3_STATE_SIZE / 8];
        unsigned int    rsiz;
        unsigned int    rsizw;