/* linux/arch/arm/crypto/aes-neonbs-glue.c */
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Bit sliced AES using NEON instructions
   4 *
   5 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
   6 */
   7
   8#include <asm/neon.h>
   9#include <asm/simd.h>
  10#include <crypto/aes.h>
  11#include <crypto/ctr.h>
  12#include <crypto/internal/cipher.h>
  13#include <crypto/internal/simd.h>
  14#include <crypto/internal/skcipher.h>
  15#include <crypto/scatterwalk.h>
  16#include <crypto/xts.h>
  17#include <linux/module.h>
  18
  19MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
  20MODULE_LICENSE("GPL v2");
  21
  22MODULE_ALIAS_CRYPTO("ecb(aes)");
  23MODULE_ALIAS_CRYPTO("cbc(aes)-all");
  24MODULE_ALIAS_CRYPTO("ctr(aes)");
  25MODULE_ALIAS_CRYPTO("xts(aes)");
  26
  27MODULE_IMPORT_NS(CRYPTO_INTERNAL);
  28
  29asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);
  30
  31asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
  32                                  int rounds, int blocks);
  33asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
  34                                  int rounds, int blocks);
  35
  36asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
  37                                  int rounds, int blocks, u8 iv[]);
  38
  39asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
  40                                  int rounds, int blocks, u8 ctr[], u8 final[]);
  41
  42asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
  43                                  int rounds, int blocks, u8 iv[], int);
  44asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
  45                                  int rounds, int blocks, u8 iv[], int);
  46
/*
 * Bit-sliced key state shared by every algorithm in this file.  It is the
 * complete context for "__ecb(aes)" and "__ctr(aes)", and is embedded as
 * the first member of each larger context below, so those contexts can be
 * passed to code that treats crypto_skcipher_ctx() as a struct aesbs_ctx
 * (aesbs_setkey() from aesbs_xts_setkey(), ctr_encrypt() from
 * ctr_encrypt_sync()).
 */
struct aesbs_ctx {
	int	rounds;		/* set to 6 + key_len / 4 by the setkey paths */
	/*
	 * Key schedule in the layout produced by aesbs_convert_key().
	 * NOTE(review): the size expression is dictated by the assembly's
	 * bit-sliced layout (presumably sized for the largest key) - confirm
	 * before changing.
	 */
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

/* "__cbc(aes)": NEON decryption key plus the fallback used for encryption. */
struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;		/* keep first (see aesbs_ctx) */
	struct crypto_skcipher	*enc_tfm;	/* "cbc(aes)" from cbc_init() */
};

/* "__xts(aes)": bit-sliced data key plus two single-block AES ciphers. */
struct aesbs_xts_ctx {
	struct aesbs_ctx	key;		/* keep first: aesbs_setkey() writes it */
	struct crypto_cipher	*cts_tfm;	/* data key, for ciphertext stealing */
	struct crypto_cipher	*tweak_tfm;	/* tweak key, encrypts the IV */
};

/* "ctr-aes-neonbs-sync": bit-sliced key plus a generic AES key schedule. */
struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;	/* used by ctr_encrypt_one() when SIMD is unusable */
};
  67
  68static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
  69                        unsigned int key_len)
  70{
  71        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
  72        struct crypto_aes_ctx rk;
  73        int err;
  74
  75        err = aes_expandkey(&rk, in_key, key_len);
  76        if (err)
  77                return err;
  78
  79        ctx->rounds = 6 + key_len / 4;
  80
  81        kernel_neon_begin();
  82        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
  83        kernel_neon_end();
  84
  85        return 0;
  86}
  87
  88static int __ecb_crypt(struct skcipher_request *req,
  89                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
  90                                  int rounds, int blocks))
  91{
  92        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
  93        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
  94        struct skcipher_walk walk;
  95        int err;
  96
  97        err = skcipher_walk_virt(&walk, req, false);
  98
  99        while (walk.nbytes >= AES_BLOCK_SIZE) {
 100                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 101
 102                if (walk.nbytes < walk.total)
 103                        blocks = round_down(blocks,
 104                                            walk.stride / AES_BLOCK_SIZE);
 105
 106                kernel_neon_begin();
 107                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
 108                   ctx->rounds, blocks);
 109                kernel_neon_end();
 110                err = skcipher_walk_done(&walk,
 111                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
 112        }
 113
 114        return err;
 115}
 116
 117static int ecb_encrypt(struct skcipher_request *req)
 118{
 119        return __ecb_crypt(req, aesbs_ecb_encrypt);
 120}
 121
 122static int ecb_decrypt(struct skcipher_request *req)
 123{
 124        return __ecb_crypt(req, aesbs_ecb_decrypt);
 125}
 126
 127static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 128                            unsigned int key_len)
 129{
 130        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 131        struct crypto_aes_ctx rk;
 132        int err;
 133
 134        err = aes_expandkey(&rk, in_key, key_len);
 135        if (err)
 136                return err;
 137
 138        ctx->key.rounds = 6 + key_len / 4;
 139
 140        kernel_neon_begin();
 141        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
 142        kernel_neon_end();
 143        memzero_explicit(&rk, sizeof(rk));
 144
 145        return crypto_skcipher_setkey(ctx->enc_tfm, in_key, key_len);
 146}
 147
 148static int cbc_encrypt(struct skcipher_request *req)
 149{
 150        struct skcipher_request *subreq = skcipher_request_ctx(req);
 151        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 152        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 153
 154        skcipher_request_set_tfm(subreq, ctx->enc_tfm);
 155        skcipher_request_set_callback(subreq,
 156                                      skcipher_request_flags(req),
 157                                      NULL, NULL);
 158        skcipher_request_set_crypt(subreq, req->src, req->dst,
 159                                   req->cryptlen, req->iv);
 160
 161        return crypto_skcipher_encrypt(subreq);
 162}
 163
 164static int cbc_decrypt(struct skcipher_request *req)
 165{
 166        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 167        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 168        struct skcipher_walk walk;
 169        int err;
 170
 171        err = skcipher_walk_virt(&walk, req, false);
 172
 173        while (walk.nbytes >= AES_BLOCK_SIZE) {
 174                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
 175
 176                if (walk.nbytes < walk.total)
 177                        blocks = round_down(blocks,
 178                                            walk.stride / AES_BLOCK_SIZE);
 179
 180                kernel_neon_begin();
 181                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 182                                  ctx->key.rk, ctx->key.rounds, blocks,
 183                                  walk.iv);
 184                kernel_neon_end();
 185                err = skcipher_walk_done(&walk,
 186                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
 187        }
 188
 189        return err;
 190}
 191
 192static int cbc_init(struct crypto_skcipher *tfm)
 193{
 194        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 195        unsigned int reqsize;
 196
 197        ctx->enc_tfm = crypto_alloc_skcipher("cbc(aes)", 0, CRYPTO_ALG_ASYNC |
 198                                             CRYPTO_ALG_NEED_FALLBACK);
 199        if (IS_ERR(ctx->enc_tfm))
 200                return PTR_ERR(ctx->enc_tfm);
 201
 202        reqsize = sizeof(struct skcipher_request);
 203        reqsize += crypto_skcipher_reqsize(ctx->enc_tfm);
 204        crypto_skcipher_set_reqsize(tfm, reqsize);
 205
 206        return 0;
 207}
 208
 209static void cbc_exit(struct crypto_skcipher *tfm)
 210{
 211        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 212
 213        crypto_free_skcipher(ctx->enc_tfm);
 214}
 215
 216static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
 217                                 unsigned int key_len)
 218{
 219        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
 220        int err;
 221
 222        err = aes_expandkey(&ctx->fallback, in_key, key_len);
 223        if (err)
 224                return err;
 225
 226        ctx->key.rounds = 6 + key_len / 4;
 227
 228        kernel_neon_begin();
 229        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
 230        kernel_neon_end();
 231
 232        return 0;
 233}
 234
/*
 * ctr_encrypt - CTR-mode en/decryption with the bit-sliced NEON code
 * @req: skcipher request; walk.iv is passed to the NEON code as the counter
 *
 * The request is walked in virtually-mapped chunks.  On every chunk except
 * the last, the block count is rounded down to whole walk strides.  If the
 * total length is not a multiple of AES_BLOCK_SIZE, the last chunk passes
 * @buf as the "final" argument, and the trailing partial block is produced
 * by XOR-ing the source tail with @buf.
 *
 * Serves both .encrypt and .decrypt of "__ctr(aes)".  ctr_encrypt_sync()
 * also calls it when SIMD is usable.
 *
 * Return: 0 on success or a negative errno from the skcipher walker.
 */
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		/* Non-NULL only when the request ends in a partial block. */
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		/*
		 * Not the last chunk: process whole strides only and defer
		 * any partial-block handling to the last chunk.
		 */
		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			/* The partial block follows the 'blocks' whole blocks. */
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			/*
			 * buf presumably holds the keystream block generated
			 * by aesbs_ctr_encrypt() for the tail.
			 */
			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			/* Whole chunk consumed; this was the last one. */
			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
 276
 277static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
 278{
 279        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
 280        unsigned long flags;
 281
 282        /*
 283         * Temporarily disable interrupts to avoid races where
 284         * cachelines are evicted when the CPU is interrupted
 285         * to do something else.
 286         */
 287        local_irq_save(flags);
 288        aes_encrypt(&ctx->fallback, dst, src);
 289        local_irq_restore(flags);
 290}
 291
 292static int ctr_encrypt_sync(struct skcipher_request *req)
 293{
 294        if (!crypto_simd_usable())
 295                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
 296
 297        return ctr_encrypt(req);
 298}
 299
 300static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 301                            unsigned int key_len)
 302{
 303        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 304        int err;
 305
 306        err = xts_verify_key(tfm, in_key, key_len);
 307        if (err)
 308                return err;
 309
 310        key_len /= 2;
 311        err = crypto_cipher_setkey(ctx->cts_tfm, in_key, key_len);
 312        if (err)
 313                return err;
 314        err = crypto_cipher_setkey(ctx->tweak_tfm, in_key + key_len, key_len);
 315        if (err)
 316                return err;
 317
 318        return aesbs_setkey(tfm, in_key, key_len);
 319}
 320
 321static int xts_init(struct crypto_skcipher *tfm)
 322{
 323        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 324
 325        ctx->cts_tfm = crypto_alloc_cipher("aes", 0, 0);
 326        if (IS_ERR(ctx->cts_tfm))
 327                return PTR_ERR(ctx->cts_tfm);
 328
 329        ctx->tweak_tfm = crypto_alloc_cipher("aes", 0, 0);
 330        if (IS_ERR(ctx->tweak_tfm))
 331                crypto_free_cipher(ctx->cts_tfm);
 332
 333        return PTR_ERR_OR_ZERO(ctx->tweak_tfm);
 334}
 335
 336static void xts_exit(struct crypto_skcipher *tfm)
 337{
 338        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 339
 340        crypto_free_cipher(ctx->tweak_tfm);
 341        crypto_free_cipher(ctx->cts_tfm);
 342}
 343
/*
 * __xts_crypt - common XTS path, including ciphertext stealing (CTS)
 * @req:     skcipher request; req->iv is encrypted with the tweak key to
 *           form the initial tweak
 * @encrypt: true to encrypt, false to decrypt
 * @fn:      bit-sliced NEON routine (aesbs_xts_encrypt/aesbs_xts_decrypt)
 *
 * Requests shorter than one AES block are rejected.  If the length is not a
 * multiple of AES_BLOCK_SIZE, the whole blocks are first processed through
 * an on-stack sub-request.  The last full block and the partial tail are
 * then handled with ciphertext stealing via ctx->cts_tfm.
 *
 * Return: 0 on success, -EINVAL for requests shorter than one block, or an
 * error from the skcipher walker.
 */
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/*
	 * With a partial tail, walk only the whole blocks.  From here on,
	 * req->cryptlen excludes the tail, so it is also the source offset
	 * of the tail bytes used by the CTS step below.
	 */
	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	/* Turn the IV into the initial tweak, in place, with the tweak key. */
	crypto_cipher_encrypt_one(ctx->tweak_tfm, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		/*
		 * Set only for the final chunk of a decryption followed by a
		 * CTS tail.  NOTE(review): presumably asks the assembly to
		 * reorder the last tweak(s) for CTS decryption - confirm
		 * against the .S implementation.
		 */
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	/* buf[0..16): last full output block already written to dst. */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	/* Its first 'tail' bytes become the final partial output block. */
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	/* Overwrite the head of buf with the source tail bytes. */
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	/*
	 * XEX step on the combined block.  req->iv is assumed to hold the
	 * tweak for this block as left by the walk - TODO confirm.
	 */
	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		crypto_cipher_encrypt_one(ctx->cts_tfm, buf, buf);
	else
		crypto_cipher_decrypt_one(ctx->cts_tfm, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	/* Write the full block and the partial tail back in one go. */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}
 415
 416static int xts_encrypt(struct skcipher_request *req)
 417{
 418        return __xts_crypt(req, true, aesbs_xts_encrypt);
 419}
 420
 421static int xts_decrypt(struct skcipher_request *req)
 422{
 423        return __xts_crypt(req, false, aesbs_xts_decrypt);
 424}
 425
/*
 * Algorithms provided by this driver.
 *
 * Entries flagged CRYPTO_ALG_INTERNAL use a "__" prefix in both names.
 * aes_init() strips those two characters and wraps each such entry with
 * simd_skcipher_create_compat().  The "ctr-aes-neonbs-sync" entry is not
 * internal, so it gets no SIMD wrapper.
 *
 * aes_init() indexes aes_simd_algs[] in parallel with this array, so keep
 * the two in step when adding or reordering entries.
 */
static struct skcipher_alg aes_algs[] = { {
	/* ECB: bit-sliced NEON in both directions. */
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	/* CBC: NEON decryption, encryption via the fallback from cbc_init(). */
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL |
				  CRYPTO_ALG_NEED_FALLBACK,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
	.init			= cbc_init,
	.exit			= cbc_exit,
}, {
	/* CTR: blocksize 1; decryption is the same operation as encryption. */
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	/*
	 * Directly registered synchronous CTR.  It falls back to scalar AES
	 * (ctr_encrypt_sync) when SIMD is unusable.  Its priority is one
	 * below the other entries.
	 */
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	/* XTS: the key is two AES keys back to back, hence the 2x sizes. */
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
	.init			= xts_init,
	.exit			= xts_exit,
} };
 512
/*
 * SIMD wrappers created by aes_init(), indexed in parallel with aes_algs[].
 * Entries stay NULL for algorithms without CRYPTO_ALG_INTERNAL, and for any
 * entry not reached before an aes_init() failure.
 */
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
 514
 515static void aes_exit(void)
 516{
 517        int i;
 518
 519        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
 520                if (aes_simd_algs[i])
 521                        simd_skcipher_free(aes_simd_algs[i]);
 522
 523        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 524}
 525
 526static int __init aes_init(void)
 527{
 528        struct simd_skcipher_alg *simd;
 529        const char *basename;
 530        const char *algname;
 531        const char *drvname;
 532        int err;
 533        int i;
 534
 535        if (!(elf_hwcap & HWCAP_NEON))
 536                return -ENODEV;
 537
 538        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 539        if (err)
 540                return err;
 541
 542        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
 543                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
 544                        continue;
 545
 546                algname = aes_algs[i].base.cra_name + 2;
 547                drvname = aes_algs[i].base.cra_driver_name + 2;
 548                basename = aes_algs[i].base.cra_driver_name;
 549                simd = simd_skcipher_create_compat(algname, drvname, basename);
 550                err = PTR_ERR(simd);
 551                if (IS_ERR(simd))
 552                        goto unregister_simds;
 553
 554                aes_simd_algs[i] = simd;
 555        }
 556        return 0;
 557
 558unregister_simds:
 559        aes_exit();
 560        return err;
 561}
 562
/*
 * aes_init() runs as a late initcall.  aes_exit() is the module exit hook
 * and is also the error-unwind path inside aes_init().
 */
late_initcall(aes_init);
module_exit(aes_exit);
 565