linux/arch/arm/crypto/aes-neonbs-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/cbc.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

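/*
 * Low-level routines provided by the accompanying bit-sliced NEON assembly.
 * aesbs_convert_key() turns an expanded AES key into the bit-sliced format;
 * the remaining routines each process 'blocks' full AES blocks per call.
 */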
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

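/*
 * Per-transform context: the round count and the key schedule converted into
 * the bit-sliced format expected by the NEON code.  The rk[] buffer is sized
 * for the largest (AES-256) schedule in that format and aligned to the AES
 * block size.
 */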
struct aesbs_ctx {
        int     rounds;
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        struct crypto_cipher    *enc_tfm;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        struct crypto_cipher    *tweak_tfm;
};

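/*
 * Expand the user key with the generic AES code, derive the round count
 * (10/12/14 for 128/192/256 bit keys) and convert the encryption round keys
 * into bit-sliced form under kernel_neon_begin()/kernel_neon_end().
 */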
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = crypto_aes_expand_key(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

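/*
 * Common walk loop for ECB.  The NEON unit is held across the whole walk, so
 * the walk is done in atomic (non-sleeping) mode.  Only whole blocks are
 * handled, and while more input remains only a multiple of the 8-block walk
 * stride is processed per iteration; the remainder is left for the next pass.
 */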
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

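/*
 * CBC needs two keys in practice: the bit-sliced schedule for the NEON
 * decryption path, and the plain AES cipher used one block at a time on the
 * encryption path.
 */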
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = crypto_aes_expand_key(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return crypto_cipher_setkey(ctx->enc_tfm, in_key, key_len);
}

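/*
 * CBC encryption is inherently sequential (each block depends on the previous
 * ciphertext block), so it cannot use the 8-way bit-sliced code and instead
 * runs the generic CBC walk around a single-block AES cipher.
 */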
static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_cipher_encrypt_one(ctx->enc_tfm, dst, src);
}

static int cbc_encrypt(struct skcipher_request *req)
{
        return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
}

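/*
 * CBC decryption operates on independent ciphertext blocks, so it can be done
 * 8 blocks at a time in NEON; the assembly updates the IV as it goes so
 * chaining continues across walk iterations.
 */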
static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int cbc_init(struct crypto_tfm *tfm)
{
        struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

        ctx->enc_tfm = crypto_alloc_cipher("aes", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->enc_tfm);
}

static void cbc_exit(struct crypto_tfm *tfm)
{
        struct aesbs_cbc_ctx *ctx = crypto_tfm_ctx(tfm);

        crypto_free_cipher(ctx->enc_tfm);
}

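/*
 * CTR mode: whole blocks are handled by the NEON code.  When the request does
 * not end on a block boundary, the assembly emits the keystream for the final
 * partial block into 'buf', which is then XORed over the remaining input
 * bytes with crypto_xor_cpy().
 */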
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

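/*
 * An XTS key is two AES keys of equal size: the second half keys the tweak
 * cipher, the first half is converted into the bit-sliced schedule used for
 * the bulk data path.
 */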
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
        if (err)
                return err;

        return aesbs_setkey(tfm, in_key, key_len);
}

static int xts_init(struct crypto_tfm *tfm)
{
        struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

        ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
}

static void xts_exit(struct crypto_tfm *tfm)
{
        struct aesbs_xts_ctx *ctx = crypto_tfm_ctx(tfm);

        crypto_free_cipher(ctx->tweak_tfm);
}

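/*
 * Common walk loop for XTS.  The initial tweak is derived by encrypting the
 * IV with the tweak cipher; the NEON code takes it from there, advancing the
 * tweak across the blocks it processes.
 */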
static int __xts_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, true);
        if (err)
                return err;

        crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

        kernel_neon_begin();
        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
                   ctx->key.rounds, blocks, walk.iv);
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        kernel_neon_end();

        return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, aesbs_xts_decrypt);
}

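/*
 * All algorithms are registered as internal ("__" prefixed) implementations,
 * i.e. they may only be invoked when the NEON unit is usable.  The SIMD
 * wrappers created in aes_init() expose them under the normal names and
 * defer to cryptd when NEON cannot be used, e.g. in hard IRQ context.
 */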
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_init          = cbc_init,
        .base.cra_exit          = cbc_exit,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_init          = xts_init,
        .base.cra_exit          = xts_exit,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

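/* Tear down the SIMD wrappers first, then unregister the internal algorithms. */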
static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

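/*
 * Register the internal algorithms if the CPU has NEON, then create a SIMD
 * wrapper for each one.  The wrapper names are derived by stripping the "__"
 * prefix from the internal names.
 */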
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!(elf_hwcap & HWCAP_NEON))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

late_initcall(aes_init);
module_exit(aes_exit);