crypto: x86/aes-xts - make non-AVX implementation use new glue code
authorEric Biggers <ebiggers@google.com>
Sun, 7 Apr 2024 21:22:31 +0000 (17:22 -0400)
committerHerbert Xu <herbert@gondor.apana.org.au>
Fri, 12 Apr 2024 07:07:53 +0000 (15:07 +0800)
Make the non-AVX implementation of AES-XTS (xts-aes-aesni) use the new
glue code that was introduced for the AVX implementations of AES-XTS.
This reduces code size, and it improves the performance of xts-aes-aesni
due to the optimization for messages that don't span page boundaries.

This required moving the new glue functions higher up in the file and
allowing the IV encryption function to be specified by the caller.

Signed-off-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/x86/crypto/aes-xts-avx-x86_64.S
arch/x86/crypto/aesni-intel_asm.S
arch/x86/crypto/aesni-intel_glue.c

index b8005d0205f89ff36ae7ea5385affc249170cc8f..fcaf64a2f8c6222716e82d3fb6f37a87d6bae72b 100644 (file)
 
 // void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
 //                        u8 iv[AES_BLOCK_SIZE]);
-SYM_FUNC_START(aes_xts_encrypt_iv)
+SYM_TYPED_FUNC_START(aes_xts_encrypt_iv)
        vmovdqu         (%rsi), %xmm0
        vpxor           0*16(%rdi), %xmm0, %xmm0
        vaesenc         1*16(%rdi), %xmm0, %xmm0
index 7ecb55cae3d61c00f515eb886028dcbb58f88bc7..1cb55eea2efaeb8c37ced4578803309965093fa9 100644 (file)
@@ -2843,10 +2843,10 @@ SYM_FUNC_END(aesni_ctr_enc)
        pxor KEY, IV;
 
 /*
- * void aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *dst,
- *                       const u8 *src, unsigned int len, le128 *iv)
+ * void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *dst,
+ *                   const u8 *src, unsigned int len, le128 *iv)
  */
-SYM_FUNC_START(aesni_xts_encrypt)
+SYM_FUNC_START(aesni_xts_enc)
        FRAME_BEGIN
 #ifndef __x86_64__
        pushl IVP
@@ -2995,13 +2995,13 @@ SYM_FUNC_START(aesni_xts_encrypt)
 
        movups STATE, (OUTP)
        jmp .Lxts_enc_ret
-SYM_FUNC_END(aesni_xts_encrypt)
+SYM_FUNC_END(aesni_xts_enc)
 
 /*
- * void aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *dst,
- *                       const u8 *src, unsigned int len, le128 *iv)
+ * void aesni_xts_dec(const struct crypto_aes_ctx *ctx, u8 *dst,
+ *                   const u8 *src, unsigned int len, le128 *iv)
  */
-SYM_FUNC_START(aesni_xts_decrypt)
+SYM_FUNC_START(aesni_xts_dec)
        FRAME_BEGIN
 #ifndef __x86_64__
        pushl IVP
@@ -3157,4 +3157,4 @@ SYM_FUNC_START(aesni_xts_decrypt)
 
        movups STATE, (OUTP)
        jmp .Lxts_dec_ret
-SYM_FUNC_END(aesni_xts_decrypt)
+SYM_FUNC_END(aesni_xts_dec)
index 0b37a470325b5c2700faf27ca2a6e2604e35387c..e7d21000cb05dfe373019cb339553e88b5695d38 100644 (file)
@@ -107,11 +107,11 @@ asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
 #define AVX_GEN2_OPTSIZE 640
 #define AVX_GEN4_OPTSIZE 4096
 
-asmlinkage void aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *out,
-                                 const u8 *in, unsigned int len, u8 *iv);
+asmlinkage void aesni_xts_enc(const struct crypto_aes_ctx *ctx, u8 *out,
+                             const u8 *in, unsigned int len, u8 *iv);
 
-asmlinkage void aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *out,
-                                 const u8 *in, unsigned int len, u8 *iv);
+asmlinkage void aesni_xts_dec(const struct crypto_aes_ctx *ctx, u8 *out,
+                             const u8 *in, unsigned int len, u8 *iv);
 
 #ifdef CONFIG_X86_64
 
@@ -875,7 +875,7 @@ static int helper_rfc4106_decrypt(struct aead_request *req)
 }
 #endif
 
-static int xts_aesni_setkey(struct crypto_skcipher *tfm, const u8 *key,
+static int xts_setkey_aesni(struct crypto_skcipher *tfm, const u8 *key,
                            unsigned int keylen)
 {
        struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
@@ -896,108 +896,152 @@ static int xts_aesni_setkey(struct crypto_skcipher *tfm, const u8 *key,
        return aes_set_key_common(&ctx->tweak_ctx, key + keylen, keylen);
 }
 
-static int xts_crypt(struct skcipher_request *req, bool encrypt)
+typedef void (*xts_encrypt_iv_func)(const struct crypto_aes_ctx *tweak_key,
+                                   u8 iv[AES_BLOCK_SIZE]);
+typedef void (*xts_crypt_func)(const struct crypto_aes_ctx *key,
+                              const u8 *src, u8 *dst, size_t len,
+                              u8 tweak[AES_BLOCK_SIZE]);
+
+/* This handles cases where the source and/or destination span pages. */
+static noinline int
+xts_crypt_slowpath(struct skcipher_request *req, xts_crypt_func crypt_func)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
+       const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
        int tail = req->cryptlen % AES_BLOCK_SIZE;
+       struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;
+       struct scatterlist *src, *dst;
        int err;
 
-       if (req->cryptlen < AES_BLOCK_SIZE)
-               return -EINVAL;
-
-       err = skcipher_walk_virt(&walk, req, false);
-       if (!walk.nbytes)
-               return err;
-
-       if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
-               int blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
-
-               skcipher_walk_abort(&walk);
-
+       /*
+        * If the message length isn't divisible by the AES block size, then
+        * separate off the last full block and the partial block.  This ensures
+        * that they are processed in the same call to the assembly function,
+        * which is required for ciphertext stealing.
+        */
+       if (tail) {
                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
-                                          blocks * AES_BLOCK_SIZE, req->iv);
+                                          req->cryptlen - tail - AES_BLOCK_SIZE,
+                                          req->iv);
                req = &subreq;
-
-               err = skcipher_walk_virt(&walk, req, false);
-               if (!walk.nbytes)
-                       return err;
-       } else {
-               tail = 0;
        }
 
-       kernel_fpu_begin();
-
-       /* calculate first value of T */
-       aesni_enc(&ctx->tweak_ctx, walk.iv, walk.iv);
+       err = skcipher_walk_virt(&walk, req, false);
 
-       while (walk.nbytes > 0) {
-               int nbytes = walk.nbytes;
+       while (walk.nbytes) {
+               unsigned int nbytes = walk.nbytes;
 
                if (nbytes < walk.total)
-                       nbytes &= ~(AES_BLOCK_SIZE - 1);
-
-               if (encrypt)
-                       aesni_xts_encrypt(&ctx->crypt_ctx,
-                                         walk.dst.virt.addr, walk.src.virt.addr,
-                                         nbytes, walk.iv);
-               else
-                       aesni_xts_decrypt(&ctx->crypt_ctx,
-                                         walk.dst.virt.addr, walk.src.virt.addr,
-                                         nbytes, walk.iv);
-               kernel_fpu_end();
+                       nbytes = round_down(nbytes, AES_BLOCK_SIZE);
 
+               kernel_fpu_begin();
+               (*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr,
+                             walk.dst.virt.addr, nbytes, req->iv);
+               kernel_fpu_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
-
-               if (walk.nbytes > 0)
-                       kernel_fpu_begin();
        }
 
-       if (unlikely(tail > 0 && !err)) {
-               struct scatterlist sg_src[2], sg_dst[2];
-               struct scatterlist *src, *dst;
+       if (err || !tail)
+               return err;
 
-               dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
-               if (req->dst != req->src)
-                       dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+       /* Do ciphertext stealing with the last full block and partial block. */
 
-               skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
-                                          req->iv);
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
 
-               err = skcipher_walk_virt(&walk, &subreq, false);
-               if (err)
-                       return err;
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                  req->iv);
 
-               kernel_fpu_begin();
-               if (encrypt)
-                       aesni_xts_encrypt(&ctx->crypt_ctx,
-                                         walk.dst.virt.addr, walk.src.virt.addr,
-                                         walk.nbytes, walk.iv);
-               else
-                       aesni_xts_decrypt(&ctx->crypt_ctx,
-                                         walk.dst.virt.addr, walk.src.virt.addr,
-                                         walk.nbytes, walk.iv);
-               kernel_fpu_end();
+       err = skcipher_walk_virt(&walk, req, false);
+       if (err)
+               return err;
 
-               err = skcipher_walk_done(&walk, 0);
+       kernel_fpu_begin();
+       (*crypt_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
+                     walk.nbytes, req->iv);
+       kernel_fpu_end();
+
+       return skcipher_walk_done(&walk, 0);
+}
+
+/* __always_inline to avoid indirect call in fastpath */
+static __always_inline int
+xts_crypt(struct skcipher_request *req, xts_encrypt_iv_func encrypt_iv,
+         xts_crypt_func crypt_func)
+{
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
+       const unsigned int cryptlen = req->cryptlen;
+       struct scatterlist *src = req->src;
+       struct scatterlist *dst = req->dst;
+
+       if (unlikely(cryptlen < AES_BLOCK_SIZE))
+               return -EINVAL;
+
+       kernel_fpu_begin();
+       (*encrypt_iv)(&ctx->tweak_ctx, req->iv);
+
+       /*
+        * In practice, virtually all XTS plaintexts and ciphertexts are either
+        * 512 or 4096 bytes, aligned such that they don't span page boundaries.
+        * To optimize the performance of these cases, and also any other case
+        * where no page boundary is spanned, the below fast-path handles
+        * single-page sources and destinations as efficiently as possible.
+        */
+       if (likely(src->length >= cryptlen && dst->length >= cryptlen &&
+                  src->offset + cryptlen <= PAGE_SIZE &&
+                  dst->offset + cryptlen <= PAGE_SIZE)) {
+               struct page *src_page = sg_page(src);
+               struct page *dst_page = sg_page(dst);
+               void *src_virt = kmap_local_page(src_page) + src->offset;
+               void *dst_virt = kmap_local_page(dst_page) + dst->offset;
+
+               (*crypt_func)(&ctx->crypt_ctx, src_virt, dst_virt, cryptlen,
+                             req->iv);
+               kunmap_local(dst_virt);
+               kunmap_local(src_virt);
+               kernel_fpu_end();
+               return 0;
        }
-       return err;
+       kernel_fpu_end();
+       return xts_crypt_slowpath(req, crypt_func);
+}
+
+static void aesni_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
+                                u8 iv[AES_BLOCK_SIZE])
+{
+       aesni_enc(tweak_key, iv, iv);
+}
+
+static void aesni_xts_encrypt(const struct crypto_aes_ctx *key,
+                             const u8 *src, u8 *dst, size_t len,
+                             u8 tweak[AES_BLOCK_SIZE])
+{
+       aesni_xts_enc(key, dst, src, len, tweak);
 }
 
-static int xts_encrypt(struct skcipher_request *req)
+static void aesni_xts_decrypt(const struct crypto_aes_ctx *key,
+                             const u8 *src, u8 *dst, size_t len,
+                             u8 tweak[AES_BLOCK_SIZE])
 {
-       return xts_crypt(req, true);
+       aesni_xts_dec(key, dst, src, len, tweak);
 }
 
-static int xts_decrypt(struct skcipher_request *req)
+static int xts_encrypt_aesni(struct skcipher_request *req)
 {
-       return xts_crypt(req, false);
+       return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_encrypt);
+}
+
+static int xts_decrypt_aesni(struct skcipher_request *req)
+{
+       return xts_crypt(req, aesni_xts_encrypt_iv, aesni_xts_decrypt);
 }
 
 static struct crypto_alg aesni_cipher_alg = {
@@ -1101,9 +1145,9 @@ static struct skcipher_alg aesni_skciphers[] = {
                .max_keysize    = 2 * AES_MAX_KEY_SIZE,
                .ivsize         = AES_BLOCK_SIZE,
                .walksize       = 2 * AES_BLOCK_SIZE,
-               .setkey         = xts_aesni_setkey,
-               .encrypt        = xts_encrypt,
-               .decrypt        = xts_decrypt,
+               .setkey         = xts_setkey_aesni,
+               .encrypt        = xts_encrypt_aesni,
+               .decrypt        = xts_decrypt_aesni,
        }
 };
 
@@ -1139,121 +1183,6 @@ static struct simd_skcipher_alg *aesni_simd_xctr;
 asmlinkage void aes_xts_encrypt_iv(const struct crypto_aes_ctx *tweak_key,
                                   u8 iv[AES_BLOCK_SIZE]);
 
-typedef void (*xts_asm_func)(const struct crypto_aes_ctx *key,
-                            const u8 *src, u8 *dst, size_t len,
-                            u8 tweak[AES_BLOCK_SIZE]);
-
-/* This handles cases where the source and/or destination span pages. */
-static noinline int
-xts_crypt_slowpath(struct skcipher_request *req, xts_asm_func asm_func)
-{
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
-       int tail = req->cryptlen % AES_BLOCK_SIZE;
-       struct scatterlist sg_src[2], sg_dst[2];
-       struct skcipher_request subreq;
-       struct skcipher_walk walk;
-       struct scatterlist *src, *dst;
-       int err;
-
-       /*
-        * If the message length isn't divisible by the AES block size, then
-        * separate off the last full block and the partial block.  This ensures
-        * that they are processed in the same call to the assembly function,
-        * which is required for ciphertext stealing.
-        */
-       if (tail) {
-               skcipher_request_set_tfm(&subreq, tfm);
-               skcipher_request_set_callback(&subreq,
-                                             skcipher_request_flags(req),
-                                             NULL, NULL);
-               skcipher_request_set_crypt(&subreq, req->src, req->dst,
-                                          req->cryptlen - tail - AES_BLOCK_SIZE,
-                                          req->iv);
-               req = &subreq;
-       }
-
-       err = skcipher_walk_virt(&walk, req, false);
-
-       while (walk.nbytes) {
-               unsigned int nbytes = walk.nbytes;
-
-               if (nbytes < walk.total)
-                       nbytes = round_down(nbytes, AES_BLOCK_SIZE);
-
-               kernel_fpu_begin();
-               (*asm_func)(&ctx->crypt_ctx, walk.src.virt.addr,
-                           walk.dst.virt.addr, nbytes, req->iv);
-               kernel_fpu_end();
-               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
-       }
-
-       if (err || !tail)
-               return err;
-
-       /* Do ciphertext stealing with the last full block and partial block. */
-
-       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
-       if (req->dst != req->src)
-               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
-
-       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
-                                  req->iv);
-
-       err = skcipher_walk_virt(&walk, req, false);
-       if (err)
-               return err;
-
-       kernel_fpu_begin();
-       (*asm_func)(&ctx->crypt_ctx, walk.src.virt.addr, walk.dst.virt.addr,
-                   walk.nbytes, req->iv);
-       kernel_fpu_end();
-
-       return skcipher_walk_done(&walk, 0);
-}
-
-/* __always_inline to avoid indirect call in fastpath */
-static __always_inline int
-xts_crypt2(struct skcipher_request *req, xts_asm_func asm_func)
-{
-       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       const struct aesni_xts_ctx *ctx = aes_xts_ctx(tfm);
-       const unsigned int cryptlen = req->cryptlen;
-       struct scatterlist *src = req->src;
-       struct scatterlist *dst = req->dst;
-
-       if (unlikely(cryptlen < AES_BLOCK_SIZE))
-               return -EINVAL;
-
-       kernel_fpu_begin();
-       aes_xts_encrypt_iv(&ctx->tweak_ctx, req->iv);
-
-       /*
-        * In practice, virtually all XTS plaintexts and ciphertexts are either
-        * 512 or 4096 bytes, aligned such that they don't span page boundaries.
-        * To optimize the performance of these cases, and also any other case
-        * where no page boundary is spanned, the below fast-path handles
-        * single-page sources and destinations as efficiently as possible.
-        */
-       if (likely(src->length >= cryptlen && dst->length >= cryptlen &&
-                  src->offset + cryptlen <= PAGE_SIZE &&
-                  dst->offset + cryptlen <= PAGE_SIZE)) {
-               struct page *src_page = sg_page(src);
-               struct page *dst_page = sg_page(dst);
-               void *src_virt = kmap_local_page(src_page) + src->offset;
-               void *dst_virt = kmap_local_page(dst_page) + dst->offset;
-
-               (*asm_func)(&ctx->crypt_ctx, src_virt, dst_virt, cryptlen,
-                           req->iv);
-               kunmap_local(dst_virt);
-               kunmap_local(src_virt);
-               kernel_fpu_end();
-               return 0;
-       }
-       kernel_fpu_end();
-       return xts_crypt_slowpath(req, asm_func);
-}
-
 #define DEFINE_XTS_ALG(suffix, driver_name, priority)                         \
                                                                               \
 asmlinkage void aes_xts_encrypt_##suffix(const struct crypto_aes_ctx *key,     \
@@ -1265,12 +1194,12 @@ asmlinkage void aes_xts_decrypt_##suffix(const struct crypto_aes_ctx *key,     \
                                                                               \
 static int xts_encrypt_##suffix(struct skcipher_request *req)                 \
 {                                                                             \
-       return xts_crypt2(req, aes_xts_encrypt_##suffix);                      \
+       return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_encrypt_##suffix);   \
 }                                                                             \
                                                                               \
 static int xts_decrypt_##suffix(struct skcipher_request *req)                 \
 {                                                                             \
-       return xts_crypt2(req, aes_xts_decrypt_##suffix);                      \
+       return xts_crypt(req, aes_xts_encrypt_iv, aes_xts_decrypt_##suffix);   \
 }                                                                             \
                                                                               \
 static struct skcipher_alg aes_xts_alg_##suffix = {                           \
@@ -1287,7 +1216,7 @@ static struct skcipher_alg aes_xts_alg_##suffix = {                              \
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,                                \
        .ivsize         = AES_BLOCK_SIZE,                                      \
        .walksize       = 2 * AES_BLOCK_SIZE,                                  \
-       .setkey         = xts_aesni_setkey,                                    \
+       .setkey         = xts_setkey_aesni,                                    \
        .encrypt        = xts_encrypt_##suffix,                                \
        .decrypt        = xts_decrypt_##suffix,                                \
 };                                                                            \