fscrypt: derive dirhash key for casefolded directories
[linux-block.git] / fs / crypto / keysetup.c
index f577bb6613f93f35c644ab500dd3872f4314ff07..74d61d827d91399149317530f7945454c58113b6 100644 (file)
@@ -13,7 +13,7 @@
 
 #include "fscrypt_private.h"
 
-static struct fscrypt_mode available_modes[] = {
+struct fscrypt_mode fscrypt_modes[] = {
        [FSCRYPT_MODE_AES_256_XTS] = {
                .friendly_name = "AES-256-XTS",
                .cipher_str = "xts(aes)",
@@ -51,10 +51,10 @@ select_encryption_mode(const union fscrypt_policy *policy,
                       const struct inode *inode)
 {
        if (S_ISREG(inode->i_mode))
-               return &available_modes[fscrypt_policy_contents_mode(policy)];
+               return &fscrypt_modes[fscrypt_policy_contents_mode(policy)];
 
        if (S_ISDIR(inode->i_mode) || S_ISLNK(inode->i_mode))
-               return &available_modes[fscrypt_policy_fnames_mode(policy)];
+               return &fscrypt_modes[fscrypt_policy_fnames_mode(policy)];
 
        WARN_ONCE(1, "fscrypt: filesystem tried to load encryption info for inode %lu, which is not encryptable (file type %d)\n",
                  inode->i_ino, (inode->i_mode & S_IFMT));
@@ -89,8 +89,11 @@ struct crypto_skcipher *fscrypt_allocate_skcipher(struct fscrypt_mode *mode,
                 * first time a mode is used.
                 */
                pr_info("fscrypt: %s using implementation \"%s\"\n",
-                       mode->friendly_name,
-                       crypto_skcipher_alg(tfm)->base.cra_driver_name);
+                       mode->friendly_name, crypto_skcipher_driver_name(tfm));
+       }
+       if (WARN_ON(crypto_skcipher_ivsize(tfm) != mode->ivsize)) {
+               err = -EINVAL;
+               goto err_free_tfm;
        }
        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_REQ_FORBID_WEAK_KEYS);
        err = crypto_skcipher_setkey(tfm, raw_key, mode->keysize);
@@ -126,7 +129,7 @@ static int setup_per_mode_key(struct fscrypt_info *ci,
        const struct inode *inode = ci->ci_inode;
        const struct super_block *sb = inode->i_sb;
        struct fscrypt_mode *mode = ci->ci_mode;
-       u8 mode_num = mode - available_modes;
+       const u8 mode_num = mode - fscrypt_modes;
        struct crypto_skcipher *tfm, *prev_tfm;
        u8 mode_key[FSCRYPT_MAX_KEY_SIZE];
        u8 hkdf_info[sizeof(mode_num) + sizeof(sb->s_uuid)];
@@ -171,10 +174,24 @@ done:
        return 0;
 }
 
+int fscrypt_derive_dirhash_key(struct fscrypt_info *ci,
+                              const struct fscrypt_master_key *mk)
+{
+       int err;
+
+       err = fscrypt_hkdf_expand(&mk->mk_secret.hkdf, HKDF_CONTEXT_DIRHASH_KEY,
+                                 ci->ci_nonce, FS_KEY_DERIVATION_NONCE_SIZE,
+                                 (u8 *)&ci->ci_dirhash_key,
+                                 sizeof(ci->ci_dirhash_key));
+       if (err)
+               return err;
+       ci->ci_dirhash_key_initialized = true;
+       return 0;
+}
+
 static int fscrypt_setup_v2_file_key(struct fscrypt_info *ci,
                                     struct fscrypt_master_key *mk)
 {
-       u8 derived_key[FSCRYPT_MAX_KEY_SIZE];
        int err;
 
        if (ci->ci_policy.v2.flags & FSCRYPT_POLICY_FLAG_DIRECT_KEY) {
@@ -186,14 +203,8 @@ static int fscrypt_setup_v2_file_key(struct fscrypt_info *ci,
                 * This ensures that the master key is consistently used only
                 * for HKDF, avoiding key reuse issues.
                 */
-               if (!fscrypt_mode_supports_direct_key(ci->ci_mode)) {
-                       fscrypt_warn(ci->ci_inode,
-                                    "Direct key flag not allowed with %s",
-                                    ci->ci_mode->friendly_name);
-                       return -EINVAL;
-               }
-               return setup_per_mode_key(ci, mk, mk->mk_direct_tfms,
-                                         HKDF_CONTEXT_DIRECT_KEY, false);
+               err = setup_per_mode_key(ci, mk, mk->mk_direct_tfms,
+                                        HKDF_CONTEXT_DIRECT_KEY, false);
        } else if (ci->ci_policy.v2.flags &
                   FSCRYPT_POLICY_FLAG_IV_INO_LBLK_64) {
                /*
@@ -202,21 +213,33 @@ static int fscrypt_setup_v2_file_key(struct fscrypt_info *ci,
                 * the IVs.  This format is optimized for use with inline
                 * encryption hardware compliant with the UFS or eMMC standards.
                 */
-               return setup_per_mode_key(ci, mk, mk->mk_iv_ino_lblk_64_tfms,
-                                         HKDF_CONTEXT_IV_INO_LBLK_64_KEY,
-                                         true);
+               err = setup_per_mode_key(ci, mk, mk->mk_iv_ino_lblk_64_tfms,
+                                        HKDF_CONTEXT_IV_INO_LBLK_64_KEY, true);
+       } else {
+               u8 derived_key[FSCRYPT_MAX_KEY_SIZE];
+
+               err = fscrypt_hkdf_expand(&mk->mk_secret.hkdf,
+                                         HKDF_CONTEXT_PER_FILE_KEY,
+                                         ci->ci_nonce,
+                                         FS_KEY_DERIVATION_NONCE_SIZE,
+                                         derived_key, ci->ci_mode->keysize);
+               if (err)
+                       return err;
+
+               err = fscrypt_set_derived_key(ci, derived_key);
+               memzero_explicit(derived_key, ci->ci_mode->keysize);
        }
-
-       err = fscrypt_hkdf_expand(&mk->mk_secret.hkdf,
-                                 HKDF_CONTEXT_PER_FILE_KEY,
-                                 ci->ci_nonce, FS_KEY_DERIVATION_NONCE_SIZE,
-                                 derived_key, ci->ci_mode->keysize);
        if (err)
                return err;
 
-       err = fscrypt_set_derived_key(ci, derived_key);
-       memzero_explicit(derived_key, ci->ci_mode->keysize);
-       return err;
+       /* Derive a secret dirhash key for directories that need it. */
+       if (S_ISDIR(ci->ci_inode->i_mode) && IS_CASEFOLDED(ci->ci_inode)) {
+               err = fscrypt_derive_dirhash_key(ci, mk);
+               if (err)
+                       return err;
+       }
+
+       return 0;
 }
 
 /*