crypto: x86/sha256 - Use API partial block handling
authorHerbert Xu <herbert@gondor.apana.org.au>
Fri, 18 Apr 2025 02:59:41 +0000 (10:59 +0800)
committerHerbert Xu <herbert@gondor.apana.org.au>
Wed, 23 Apr 2025 07:52:36 +0000 (15:52 +0800)
Use the Crypto API partial block handling.

Also remove the unnecessary SIMD fallback path.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/x86/crypto/sha256_ssse3_glue.c
include/crypto/sha2.h
include/crypto/sha256_base.h

index 429a3cefbab4e23c53c8b79e3ef5718728896782..367ce4830fa4567b1997233500e7b0345613ff95 100644 (file)
 
 #define pr_fmt(fmt)    KBUILD_MODNAME ": " fmt
 
+#include <asm/cpu_device_id.h>
+#include <asm/fpu/api.h>
 #include <crypto/internal/hash.h>
-#include <crypto/internal/simd.h>
-#include <linux/init.h>
-#include <linux/module.h>
-#include <linux/mm.h>
-#include <linux/types.h>
 #include <crypto/sha2.h>
 #include <crypto/sha256_base.h>
-#include <linux/string.h>
-#include <asm/cpu_device_id.h>
-#include <asm/simd.h>
+#include <linux/kernel.h>
+#include <linux/module.h>
 
-asmlinkage void sha256_transform_ssse3(struct sha256_state *state,
+asmlinkage void sha256_transform_ssse3(struct crypto_sha256_state *state,
                                       const u8 *data, int blocks);
 
 static const struct x86_cpu_id module_cpu_ids[] = {
@@ -54,37 +50,29 @@ static const struct x86_cpu_id module_cpu_ids[] = {
 MODULE_DEVICE_TABLE(x86cpu, module_cpu_ids);
 
 static int _sha256_update(struct shash_desc *desc, const u8 *data,
-                         unsigned int len, sha256_block_fn *sha256_xform)
+                         unsigned int len,
+                         crypto_sha256_block_fn *sha256_xform)
 {
-       struct sha256_state *sctx = shash_desc_ctx(desc);
-
-       if (!crypto_simd_usable() ||
-           (sctx->count % SHA256_BLOCK_SIZE) + len < SHA256_BLOCK_SIZE)
-               return crypto_sha256_update(desc, data, len);
+       int remain;
 
        /*
-        * Make sure struct sha256_state begins directly with the SHA256
+        * Make sure struct crypto_sha256_state begins directly with the SHA256
         * 256-bit internal state, as this is what the asm functions expect.
         */
-       BUILD_BUG_ON(offsetof(struct sha256_state, state) != 0);
+       BUILD_BUG_ON(offsetof(struct crypto_sha256_state, state) != 0);
 
        kernel_fpu_begin();
-       sha256_base_do_update(desc, data, len, sha256_xform);
+       remain = sha256_base_do_update_blocks(desc, data, len, sha256_xform);
        kernel_fpu_end();
 
-       return 0;
+       return remain;
 }
 
 static int sha256_finup(struct shash_desc *desc, const u8 *data,
-             unsigned int len, u8 *out, sha256_block_fn *sha256_xform)
+             unsigned int len, u8 *out, crypto_sha256_block_fn *sha256_xform)
 {
-       if (!crypto_simd_usable())
-               return crypto_sha256_finup(desc, data, len, out);
-
        kernel_fpu_begin();
-       if (len)
-               sha256_base_do_update(desc, data, len, sha256_xform);
-       sha256_base_do_finalize(desc, sha256_xform);
+       sha256_base_do_finup(desc, data, len, sha256_xform);
        kernel_fpu_end();
 
        return sha256_base_finish(desc, out);
@@ -102,12 +90,6 @@ static int sha256_ssse3_finup(struct shash_desc *desc, const u8 *data,
        return sha256_finup(desc, data, len, out, sha256_transform_ssse3);
 }
 
-/* Add padding and return the message digest. */
-static int sha256_ssse3_final(struct shash_desc *desc, u8 *out)
-{
-       return sha256_ssse3_finup(desc, NULL, 0, out);
-}
-
 static int sha256_ssse3_digest(struct shash_desc *desc, const u8 *data,
              unsigned int len, u8 *out)
 {
@@ -119,14 +101,15 @@ static struct shash_alg sha256_ssse3_algs[] = { {
        .digestsize     =       SHA256_DIGEST_SIZE,
        .init           =       sha256_base_init,
        .update         =       sha256_ssse3_update,
-       .final          =       sha256_ssse3_final,
        .finup          =       sha256_ssse3_finup,
        .digest         =       sha256_ssse3_digest,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha256",
                .cra_driver_name =      "sha256-ssse3",
                .cra_priority   =       150,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA256_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -134,13 +117,14 @@ static struct shash_alg sha256_ssse3_algs[] = { {
        .digestsize     =       SHA224_DIGEST_SIZE,
        .init           =       sha224_base_init,
        .update         =       sha256_ssse3_update,
-       .final          =       sha256_ssse3_final,
        .finup          =       sha256_ssse3_finup,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha224",
                .cra_driver_name =      "sha224-ssse3",
                .cra_priority   =       150,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA224_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -161,7 +145,7 @@ static void unregister_sha256_ssse3(void)
                                ARRAY_SIZE(sha256_ssse3_algs));
 }
 
-asmlinkage void sha256_transform_avx(struct sha256_state *state,
+asmlinkage void sha256_transform_avx(struct crypto_sha256_state *state,
                                     const u8 *data, int blocks);
 
 static int sha256_avx_update(struct shash_desc *desc, const u8 *data,
@@ -176,11 +160,6 @@ static int sha256_avx_finup(struct shash_desc *desc, const u8 *data,
        return sha256_finup(desc, data, len, out, sha256_transform_avx);
 }
 
-static int sha256_avx_final(struct shash_desc *desc, u8 *out)
-{
-       return sha256_avx_finup(desc, NULL, 0, out);
-}
-
 static int sha256_avx_digest(struct shash_desc *desc, const u8 *data,
                      unsigned int len, u8 *out)
 {
@@ -192,14 +171,15 @@ static struct shash_alg sha256_avx_algs[] = { {
        .digestsize     =       SHA256_DIGEST_SIZE,
        .init           =       sha256_base_init,
        .update         =       sha256_avx_update,
-       .final          =       sha256_avx_final,
        .finup          =       sha256_avx_finup,
        .digest         =       sha256_avx_digest,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha256",
                .cra_driver_name =      "sha256-avx",
                .cra_priority   =       160,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA256_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -207,13 +187,14 @@ static struct shash_alg sha256_avx_algs[] = { {
        .digestsize     =       SHA224_DIGEST_SIZE,
        .init           =       sha224_base_init,
        .update         =       sha256_avx_update,
-       .final          =       sha256_avx_final,
        .finup          =       sha256_avx_finup,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha224",
                .cra_driver_name =      "sha224-avx",
                .cra_priority   =       160,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA224_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -245,7 +226,7 @@ static void unregister_sha256_avx(void)
                                ARRAY_SIZE(sha256_avx_algs));
 }
 
-asmlinkage void sha256_transform_rorx(struct sha256_state *state,
+asmlinkage void sha256_transform_rorx(struct crypto_sha256_state *state,
                                      const u8 *data, int blocks);
 
 static int sha256_avx2_update(struct shash_desc *desc, const u8 *data,
@@ -260,11 +241,6 @@ static int sha256_avx2_finup(struct shash_desc *desc, const u8 *data,
        return sha256_finup(desc, data, len, out, sha256_transform_rorx);
 }
 
-static int sha256_avx2_final(struct shash_desc *desc, u8 *out)
-{
-       return sha256_avx2_finup(desc, NULL, 0, out);
-}
-
 static int sha256_avx2_digest(struct shash_desc *desc, const u8 *data,
                      unsigned int len, u8 *out)
 {
@@ -276,14 +252,15 @@ static struct shash_alg sha256_avx2_algs[] = { {
        .digestsize     =       SHA256_DIGEST_SIZE,
        .init           =       sha256_base_init,
        .update         =       sha256_avx2_update,
-       .final          =       sha256_avx2_final,
        .finup          =       sha256_avx2_finup,
        .digest         =       sha256_avx2_digest,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha256",
                .cra_driver_name =      "sha256-avx2",
                .cra_priority   =       170,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA256_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -291,13 +268,14 @@ static struct shash_alg sha256_avx2_algs[] = { {
        .digestsize     =       SHA224_DIGEST_SIZE,
        .init           =       sha224_base_init,
        .update         =       sha256_avx2_update,
-       .final          =       sha256_avx2_final,
        .finup          =       sha256_avx2_finup,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha224",
                .cra_driver_name =      "sha224-avx2",
                .cra_priority   =       170,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA224_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -327,7 +305,7 @@ static void unregister_sha256_avx2(void)
                                ARRAY_SIZE(sha256_avx2_algs));
 }
 
-asmlinkage void sha256_ni_transform(struct sha256_state *digest,
+asmlinkage void sha256_ni_transform(struct crypto_sha256_state *digest,
                                    const u8 *data, int rounds);
 
 static int sha256_ni_update(struct shash_desc *desc, const u8 *data,
@@ -342,11 +320,6 @@ static int sha256_ni_finup(struct shash_desc *desc, const u8 *data,
        return sha256_finup(desc, data, len, out, sha256_ni_transform);
 }
 
-static int sha256_ni_final(struct shash_desc *desc, u8 *out)
-{
-       return sha256_ni_finup(desc, NULL, 0, out);
-}
-
 static int sha256_ni_digest(struct shash_desc *desc, const u8 *data,
                      unsigned int len, u8 *out)
 {
@@ -358,14 +331,15 @@ static struct shash_alg sha256_ni_algs[] = { {
        .digestsize     =       SHA256_DIGEST_SIZE,
        .init           =       sha256_base_init,
        .update         =       sha256_ni_update,
-       .final          =       sha256_ni_final,
        .finup          =       sha256_ni_finup,
        .digest         =       sha256_ni_digest,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha256",
                .cra_driver_name =      "sha256-ni",
                .cra_priority   =       250,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA256_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
@@ -373,13 +347,14 @@ static struct shash_alg sha256_ni_algs[] = { {
        .digestsize     =       SHA224_DIGEST_SIZE,
        .init           =       sha224_base_init,
        .update         =       sha256_ni_update,
-       .final          =       sha256_ni_final,
        .finup          =       sha256_ni_finup,
-       .descsize       =       sizeof(struct sha256_state),
+       .descsize       =       sizeof(struct crypto_sha256_state),
        .base           =       {
                .cra_name       =       "sha224",
                .cra_driver_name =      "sha224-ni",
                .cra_priority   =       250,
+               .cra_flags      =       CRYPTO_AHASH_ALG_BLOCK_ONLY |
+                                       CRYPTO_AHASH_ALG_FINUP_MAX,
                .cra_blocksize  =       SHA224_BLOCK_SIZE,
                .cra_module     =       THIS_MODULE,
        }
index b9e9281d76c94bd07120eca3c13c179197b01cd1..d9b1b99323938519ce17aa7390acaedcdc40d180 100644 (file)
@@ -64,6 +64,11 @@ extern const u8 sha384_zero_message_hash[SHA384_DIGEST_SIZE];
 
 extern const u8 sha512_zero_message_hash[SHA512_DIGEST_SIZE];
 
+struct crypto_sha256_state {
+       u32 state[SHA256_DIGEST_SIZE / 4];
+       u64 count;
+};
+
 struct sha256_state {
        u32 state[SHA256_DIGEST_SIZE / 4];
        u64 count;
index e0418818d63c84bcb6deb3347c1f63dde058963a..727a1b63e1e976c0ada22a4b90df398da1541a66 100644 (file)
@@ -8,15 +8,17 @@
 #ifndef _CRYPTO_SHA256_BASE_H
 #define _CRYPTO_SHA256_BASE_H
 
-#include <asm/byteorder.h>
-#include <linux/unaligned.h>
 #include <crypto/internal/hash.h>
 #include <crypto/sha2.h>
+#include <linux/math.h>
 #include <linux/string.h>
 #include <linux/types.h>
+#include <linux/unaligned.h>
 
 typedef void (sha256_block_fn)(struct sha256_state *sst, u8 const *src,
                               int blocks);
+typedef void (crypto_sha256_block_fn)(struct crypto_sha256_state *sst,
+                                     u8 const *src, int blocks);
 
 static inline int sha224_base_init(struct shash_desc *desc)
 {
@@ -81,6 +83,64 @@ static inline int sha256_base_do_update(struct shash_desc *desc,
        return lib_sha256_base_do_update(sctx, data, len, block_fn);
 }
 
+static inline int lib_sha256_base_do_update_blocks(
+       struct crypto_sha256_state *sctx, const u8 *data, unsigned int len,
+       crypto_sha256_block_fn *block_fn)
+{
+       unsigned int remain = len - round_down(len, SHA256_BLOCK_SIZE);
+
+       sctx->count += len - remain;
+       block_fn(sctx, data, len / SHA256_BLOCK_SIZE);
+       return remain;
+}
+
+static inline int sha256_base_do_update_blocks(
+       struct shash_desc *desc, const u8 *data, unsigned int len,
+       crypto_sha256_block_fn *block_fn)
+{
+       return lib_sha256_base_do_update_blocks(shash_desc_ctx(desc), data,
+                                               len, block_fn);
+}
+
+static inline int lib_sha256_base_do_finup(struct crypto_sha256_state *sctx,
+                                          const u8 *src, unsigned int len,
+                                          crypto_sha256_block_fn *block_fn)
+{
+       unsigned int bit_offset = SHA256_BLOCK_SIZE / 8 - 1;
+       union {
+               __be64 b64[SHA256_BLOCK_SIZE / 4];
+               u8 u8[SHA256_BLOCK_SIZE * 2];
+       } block = {};
+
+       if (len >= bit_offset * 8)
+               bit_offset += SHA256_BLOCK_SIZE / 8;
+       memcpy(&block, src, len);
+       block.u8[len] = 0x80;
+       sctx->count += len;
+       block.b64[bit_offset] = cpu_to_be64(sctx->count << 3);
+       block_fn(sctx, block.u8, (bit_offset + 1) * 8 / SHA256_BLOCK_SIZE);
+       memzero_explicit(&block, sizeof(block));
+
+       return 0;
+}
+
+static inline int sha256_base_do_finup(struct shash_desc *desc,
+                                      const u8 *src, unsigned int len,
+                                      crypto_sha256_block_fn *block_fn)
+{
+       struct crypto_sha256_state *sctx = shash_desc_ctx(desc);
+
+       if (len >= SHA256_BLOCK_SIZE) {
+               int remain;
+
+               remain = lib_sha256_base_do_update_blocks(sctx, src, len,
+                                                         block_fn);
+               src += len - remain;
+               len = remain;
+       }
+       return lib_sha256_base_do_finup(sctx, src, len, block_fn);
+}
+
 static inline int lib_sha256_base_do_finalize(struct sha256_state *sctx,
                                              sha256_block_fn *block_fn)
 {
@@ -111,15 +171,21 @@ static inline int sha256_base_do_finalize(struct shash_desc *desc,
        return lib_sha256_base_do_finalize(sctx, block_fn);
 }
 
-static inline int lib_sha256_base_finish(struct sha256_state *sctx, u8 *out,
-                                        unsigned int digest_size)
+static inline int __sha256_base_finish(u32 state[SHA256_DIGEST_SIZE / 4],
+                                      u8 *out, unsigned int digest_size)
 {
        __be32 *digest = (__be32 *)out;
        int i;
 
        for (i = 0; digest_size > 0; i++, digest_size -= sizeof(__be32))
-               put_unaligned_be32(sctx->state[i], digest++);
+               put_unaligned_be32(state[i], digest++);
+       return 0;
+}
 
+static inline int lib_sha256_base_finish(struct sha256_state *sctx, u8 *out,
+                                        unsigned int digest_size)
+{
+       __sha256_base_finish(sctx->state, out, digest_size);
        memzero_explicit(sctx, sizeof(*sctx));
        return 0;
 }
@@ -127,9 +193,9 @@ static inline int lib_sha256_base_finish(struct sha256_state *sctx, u8 *out,
 static inline int sha256_base_finish(struct shash_desc *desc, u8 *out)
 {
        unsigned int digest_size = crypto_shash_digestsize(desc->tfm);
-       struct sha256_state *sctx = shash_desc_ctx(desc);
+       struct crypto_sha256_state *sctx = shash_desc_ctx(desc);
 
-       return lib_sha256_base_finish(sctx, out, digest_size);
+       return __sha256_base_finish(sctx->state, out, digest_size);
 }
 
 #endif /* _CRYPTO_SHA256_BASE_H */