linux/arch/arm/crypto/aes-ce-glue.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * aes-ce-glue.c - wrapper code for ARMv8 AES
   4 *
   5 * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
   6 */
   7
   8#include <asm/hwcap.h>
   9#include <asm/neon.h>
  10#include <asm/simd.h>
  11#include <asm/unaligned.h>
  12#include <crypto/aes.h>
  13#include <crypto/ctr.h>
  14#include <crypto/internal/simd.h>
  15#include <crypto/internal/skcipher.h>
  16#include <crypto/scatterwalk.h>
  17#include <linux/cpufeature.h>
  18#include <linux/module.h>
  19#include <crypto/xts.h>
  20
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/*
 * NEON / Crypto Extensions primitives, defined in aes-ce-core.S.
 *
 * Every caller in this file invokes them between kernel_neon_begin() and
 * kernel_neon_end(). The rk/rk1/rk2 arguments are expanded schedules taken
 * from a struct crypto_aes_ctx, and rounds always comes from num_rounds().
 */

/* Key-schedule helpers, used only by ce_aes_expandkey(). */
asmlinkage u32 ce_aes_sub(u32 input);
asmlinkage void ce_aes_invert(void *dst, void *src);

/* ECB: blocks is a count of whole AES_BLOCK_SIZE blocks. */
asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks);
asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks);

/*
 * CBC: blocks counts whole blocks and iv is the chaining value.
 * The CTS variants take a byte count instead; they are only called for the
 * final chunk of a cts(cbc(aes)) request.
 */
asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 iv[]);
asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 iv[]);
asmlinkage void ce_aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int bytes, u8 const iv[]);
asmlinkage void ce_aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int bytes, u8 const iv[]);

/*
 * CTR: blocks counts whole blocks. ctr_encrypt() passes blocks == -1 with
 * in == NULL to have a single tail block written to out.
 */
asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                   int rounds, int blocks, u8 ctr[]);

/*
 * XTS: bytes is a byte count; rk1 is the data key schedule and callers pass
 * the second XTS key's encryption schedule as rk2. first is nonzero only on
 * the first call made for a request.
 */
asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                   int rounds, int bytes, u8 iv[],
                                   u32 const rk2[], int first);
asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                   int rounds, int bytes, u8 iv[],
                                   u32 const rk2[], int first);
  52
/*
 * One AES block viewed as raw bytes. ce_aes_expandkey() casts the u32
 * round-key arrays to this type so that whole round keys can be indexed
 * per round and copied by plain struct assignment.
 */
struct aes_block {
        u8 b[AES_BLOCK_SIZE];
};
  56
  57static int num_rounds(struct crypto_aes_ctx *ctx)
  58{
  59        /*
  60         * # of rounds specified by AES:
  61         * 128 bit key          10 rounds
  62         * 192 bit key          12 rounds
  63         * 256 bit key          14 rounds
  64         * => n byte key        => 6 + (n/4) rounds
  65         */
  66        return 6 + ctx->key_length / 4;
  67}
  68
/*
 * ce_aes_expandkey - expand a raw AES key into round-key schedules
 * @ctx:     receives key_length, the encryption schedule (key_enc) and the
 *           decryption schedule (key_dec)
 * @in_key:  raw key bytes, read as little-endian words with
 *           get_unaligned_le32(), so no particular alignment is required
 * @key_len: key length in bytes: AES_KEYSIZE_128, _192 or _256
 *
 * The NEON unit is claimed around the expansion because ce_aes_sub() and
 * ce_aes_invert() are NEON routines.
 *
 * Return: 0 on success, -EINVAL if @key_len is not a valid AES key size.
 */
static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
                            unsigned int key_len)
{
        /*
         * The AES key schedule round constants
         */
        static u8 const rcon[] = {
                0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
        };

        /* key length in 32-bit words (4, 6 or 8 once validated below) */
        u32 kwords = key_len / sizeof(u32);
        struct aes_block *key_enc, *key_dec;
        int i, j;

        if (key_len != AES_KEYSIZE_128 &&
            key_len != AES_KEYSIZE_192 &&
            key_len != AES_KEYSIZE_256)
                return -EINVAL;

        /* The first kwords words of the schedule are the key itself. */
        ctx->key_length = key_len;
        for (i = 0; i < kwords; i++)
                ctx->key_enc[i] = get_unaligned_le32(in_key + i * sizeof(u32));

        kernel_neon_begin();
        /*
         * Each iteration derives the next kwords words (rko) from the
         * previous kwords words (rki). 128-bit keys consume all ten round
         * constants (44 words in total). 192- and 256-bit keys break out
         * early once 52 or 60 words exist, i.e. num_rounds(ctx) + 1 round
         * keys.
         */
        for (i = 0; i < sizeof(rcon); i++) {
                u32 *rki = ctx->key_enc + (i * kwords);
                u32 *rko = rki + kwords;

                /*
                 * Rotate the substituted last word of the previous group and
                 * mix in the round constant (ce_aes_sub() presumably applies
                 * the S-box to each byte).
                 */
                rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
                rko[0] = rko[0] ^ rki[0] ^ rcon[i];
                rko[1] = rko[0] ^ rki[1];
                rko[2] = rko[1] ^ rki[2];
                rko[3] = rko[2] ^ rki[3];

                if (key_len == AES_KEYSIZE_192) {
                        if (i >= 7)
                                break;
                        rko[4] = rko[3] ^ rki[4];
                        rko[5] = rko[4] ^ rki[5];
                } else if (key_len == AES_KEYSIZE_256) {
                        if (i >= 6)
                                break;
                        /* 256-bit keys substitute (without rotating) mid-group */
                        rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
                        rko[5] = rko[4] ^ rki[5];
                        rko[6] = rko[5] ^ rki[6];
                        rko[7] = rko[6] ^ rki[7];
                }
        }

        /*
         * Generate the decryption keys for the Equivalent Inverse Cipher.
         * This involves reversing the order of the round keys, and applying
         * the Inverse Mix Columns transformation on all but the first and
         * the last one.
         */
        key_enc = (struct aes_block *)ctx->key_enc;
        key_dec = (struct aes_block *)ctx->key_dec;
        j = num_rounds(ctx);

        /* last encryption round key becomes the first decryption key */
        key_dec[0] = key_enc[j];
        for (i = 1, j--; j > 0; i++, j--)
                ce_aes_invert(key_dec + i, key_enc + j);
        /* and the original key becomes the last one, untransformed */
        key_dec[i] = key_enc[0];

        kernel_neon_end();
        return 0;
}
 136
 137static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 138                         unsigned int key_len)
 139{
 140        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 141
 142        return ce_aes_expandkey(ctx, in_key, key_len);
 143}
 144
/*
 * XTS transform context: two independently expanded AES keys.
 * key1 (both schedules) is used for the data; only key2's encryption
 * schedule is used, passed to the assembly as rk2.
 * NOTE(review): the reason for __aligned(8) on key2 is not visible here.
 */
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};
 149
 150static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
 151                       unsigned int key_len)
 152{
 153        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 154        int ret;
 155
 156        ret = xts_verify_key(tfm, in_key, key_len);
 157        if (ret)
 158                return ret;
 159
 160        ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
 161        if (!ret)
 162                ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
 163                                       key_len / 2);
 164        return ret;
 165}
 166
 167static int ecb_encrypt(struct skcipher_request *req)
 168{
 169        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 170        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 171        struct skcipher_walk walk;
 172        unsigned int blocks;
 173        int err;
 174
 175        err = skcipher_walk_virt(&walk, req, false);
 176
 177        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 178                kernel_neon_begin();
 179                ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 180                                   ctx->key_enc, num_rounds(ctx), blocks);
 181                kernel_neon_end();
 182                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 183        }
 184        return err;
 185}
 186
 187static int ecb_decrypt(struct skcipher_request *req)
 188{
 189        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 190        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 191        struct skcipher_walk walk;
 192        unsigned int blocks;
 193        int err;
 194
 195        err = skcipher_walk_virt(&walk, req, false);
 196
 197        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 198                kernel_neon_begin();
 199                ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 200                                   ctx->key_dec, num_rounds(ctx), blocks);
 201                kernel_neon_end();
 202                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 203        }
 204        return err;
 205}
 206
 207static int cbc_encrypt_walk(struct skcipher_request *req,
 208                            struct skcipher_walk *walk)
 209{
 210        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 211        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 212        unsigned int blocks;
 213        int err = 0;
 214
 215        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
 216                kernel_neon_begin();
 217                ce_aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
 218                                   ctx->key_enc, num_rounds(ctx), blocks,
 219                                   walk->iv);
 220                kernel_neon_end();
 221                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 222        }
 223        return err;
 224}
 225
 226static int cbc_encrypt(struct skcipher_request *req)
 227{
 228        struct skcipher_walk walk;
 229        int err;
 230
 231        err = skcipher_walk_virt(&walk, req, false);
 232        if (err)
 233                return err;
 234        return cbc_encrypt_walk(req, &walk);
 235}
 236
 237static int cbc_decrypt_walk(struct skcipher_request *req,
 238                            struct skcipher_walk *walk)
 239{
 240        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 241        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 242        unsigned int blocks;
 243        int err = 0;
 244
 245        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
 246                kernel_neon_begin();
 247                ce_aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
 248                                   ctx->key_dec, num_rounds(ctx), blocks,
 249                                   walk->iv);
 250                kernel_neon_end();
 251                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 252        }
 253        return err;
 254}
 255
 256static int cbc_decrypt(struct skcipher_request *req)
 257{
 258        struct skcipher_walk walk;
 259        int err;
 260
 261        err = skcipher_walk_virt(&walk, req, false);
 262        if (err)
 263                return err;
 264        return cbc_decrypt_walk(req, &walk);
 265}
 266
/*
 * cts_cbc_encrypt - CBC encryption with ciphertext stealing
 *
 * Requests shorter than AES_BLOCK_SIZE are rejected with -EINVAL, and a
 * request of exactly one block is plain CBC. Otherwise every block except
 * the last two (the final one possibly partial) is encrypted with regular
 * CBC through an on-stack subrequest. The remaining 17..32 bytes are then
 * passed to ce_aes_cbc_cts_encrypt() in a single call.
 *
 * NOTE(review): the single final call assumes one walk step maps the whole
 * tail; the 2 * AES_BLOCK_SIZE walksize declared for this algorithm appears
 * to be what guarantees that - confirm.
 *
 * Return: 0 or a negative errno.
 */
static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        /* blocks handled by plain CBC: all but the final two */
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;
        int err;

        /* subrequest reuses the caller's flags, without a completion callback */
        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                /* exactly one block: no stealing, plain CBC */
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_encrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                /*
                 * Skip past the CBC-processed prefix; dst only needs its own
                 * scatterlist when the operation is not in place.
                 */
                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        ce_aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                               ctx->key_enc, num_rounds(ctx), walk.nbytes,
                               walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
 324
/*
 * cts_cbc_decrypt - CBC decryption with ciphertext stealing
 *
 * Inverse of cts_cbc_encrypt(). Requests shorter than AES_BLOCK_SIZE are
 * rejected with -EINVAL, and a request of exactly one block is plain CBC.
 * Otherwise all but the last two (possibly partial) blocks are decrypted
 * with regular CBC through an on-stack subrequest. The remaining 17..32
 * bytes go to ce_aes_cbc_cts_decrypt() in a single call.
 *
 * NOTE(review): as for encryption, the final call assumes one walk step
 * covers the whole tail (walksize 2 * AES_BLOCK_SIZE) - confirm.
 *
 * Return: 0 or a negative errno.
 */
static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        /* blocks handled by plain CBC: all but the final two */
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;
        int err;

        /* subrequest reuses the caller's flags, without a completion callback */
        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                /* exactly one block: no stealing, plain CBC */
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_decrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                /*
                 * Skip past the CBC-processed prefix; dst only needs its own
                 * scatterlist when the operation is not in place.
                 */
                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        ce_aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                               ctx->key_dec, num_rounds(ctx), walk.nbytes,
                               walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
 382
 383static int ctr_encrypt(struct skcipher_request *req)
 384{
 385        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 386        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 387        struct skcipher_walk walk;
 388        int err, blocks;
 389
 390        err = skcipher_walk_virt(&walk, req, false);
 391
 392        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 393                kernel_neon_begin();
 394                ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 395                                   ctx->key_enc, num_rounds(ctx), blocks,
 396                                   walk.iv);
 397                kernel_neon_end();
 398                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 399        }
 400        if (walk.nbytes) {
 401                u8 __aligned(8) tail[AES_BLOCK_SIZE];
 402                unsigned int nbytes = walk.nbytes;
 403                u8 *tdst = walk.dst.virt.addr;
 404                u8 *tsrc = walk.src.virt.addr;
 405
 406                /*
 407                 * Tell aes_ctr_encrypt() to process a tail block.
 408                 */
 409                blocks = -1;
 410
 411                kernel_neon_begin();
 412                ce_aes_ctr_encrypt(tail, NULL, ctx->key_enc, num_rounds(ctx),
 413                                   blocks, walk.iv);
 414                kernel_neon_end();
 415                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
 416                err = skcipher_walk_done(&walk, 0);
 417        }
 418        return err;
 419}
 420
 421static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
 422{
 423        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 424        unsigned long flags;
 425
 426        /*
 427         * Temporarily disable interrupts to avoid races where
 428         * cachelines are evicted when the CPU is interrupted
 429         * to do something else.
 430         */
 431        local_irq_save(flags);
 432        aes_encrypt(ctx, dst, src);
 433        local_irq_restore(flags);
 434}
 435
 436static int ctr_encrypt_sync(struct skcipher_request *req)
 437{
 438        if (!crypto_simd_usable())
 439                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
 440
 441        return ctr_encrypt(req);
 442}
 443
/*
 * xts_encrypt - XTS encryption with ciphertext stealing
 *
 * Requests shorter than one block are rejected with -EINVAL.
 *
 * If the length is not a multiple of AES_BLOCK_SIZE and the first walk step
 * does not cover the whole request, the walk is restarted on a subrequest
 * covering all but the last two (partial) blocks. The remaining
 * AES_BLOCK_SIZE + tail bytes are then processed by one final call.
 * Otherwise the main loop handles everything and tail is cleared.
 *
 * "first" is 1 for the first call into the assembly for this request and
 * 0 afterwards, including for the final stealing call when the loop ran.
 *
 * Return: 0 or a negative errno.
 */
static int xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = num_rounds(&ctx->key1);
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        /*
         * NOTE(review): on this path err from the first walk is overwritten
         * by the restarted walk; confirm that is intended.
         */
        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                /* only the last step may pass a non-block-multiple length */
                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   ctx->key1.key_enc, rounds, nbytes, walk.iv,
                                   ctx->key2.key_enc, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* req is the subrequest here; skip past what the loop processed */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        kernel_neon_begin();
        ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                           ctx->key1.key_enc, rounds, walk.nbytes, walk.iv,
                           ctx->key2.key_enc, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
 515
/*
 * xts_decrypt - XTS decryption with ciphertext stealing
 *
 * Same structure as xts_encrypt(). The data are processed with key1's
 * decryption schedule, while key2's encryption schedule is passed as rk2.
 * Requests shorter than one block are rejected with -EINVAL.
 *
 * Return: 0 or a negative errno.
 */
static int xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = num_rounds(&ctx->key1);
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        /*
         * Partial final block that the first walk step does not reach:
         * restart on a subrequest covering all but the last two blocks.
         * NOTE(review): err from the first walk is overwritten here; confirm.
         */
        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                /* only the last step may pass a non-block-multiple length */
                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   ctx->key1.key_dec, rounds, nbytes, walk.iv,
                                   ctx->key2.key_enc, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* req is the subrequest here; skip past what the loop processed */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        kernel_neon_begin();
        ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                           ctx->key1.key_dec, rounds, walk.nbytes, walk.iv,
                           ctx->key2.key_enc, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
 587
/*
 * Algorithms registered by aes_init().
 *
 * Entries flagged CRYPTO_ALG_INTERNAL carry a "__" name prefix. aes_init()
 * wraps each of them in a SIMD helper whose names are the same strings with
 * that prefix removed. The "ctr-aes-ce-sync" entry is not internal: it is
 * registered as-is and falls back to scalar code when SIMD is unusable.
 *
 * The CTS and XTS entries set walksize to 2 * AES_BLOCK_SIZE, presumably
 * so the final two blocks needed for ciphertext stealing are mapped
 * together - confirm against the skcipher walk documentation.
 */
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-ce",
        .base.cra_priority      = 300,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct crypto_aes_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .setkey                 = ce_aes_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-ce",
        .base.cra_priority      = 300,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct crypto_aes_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = ce_aes_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "__cts(cbc(aes))",
        .base.cra_driver_name   = "__cts-cbc-aes-ce",
        .base.cra_priority      = 300,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct crypto_aes_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .walksize               = 2 * AES_BLOCK_SIZE,
        .setkey                 = ce_aes_setkey,
        .encrypt                = cts_cbc_encrypt,
        .decrypt                = cts_cbc_decrypt,
}, {
        /* stream mode: blocksize 1, processed in AES_BLOCK_SIZE chunks */
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-ce",
        .base.cra_priority      = 300,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct crypto_aes_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .setkey                 = ce_aes_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        /*
         * Not internal, so aes_init() creates no SIMD wrapper for it and its
         * aes_simd_algs[] slot stays NULL. Priority is one below the others.
         */
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-ce-sync",
        .base.cra_priority      = 300 - 1,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct crypto_aes_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .setkey                 = ce_aes_setkey,
        .encrypt                = ctr_encrypt_sync,
        .decrypt                = ctr_encrypt_sync,
}, {
        /* two AES keys, hence the doubled key sizes */
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-ce",
        .base.cra_priority      = 300,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct crypto_aes_xts_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .walksize               = 2 * AES_BLOCK_SIZE,
        .setkey                 = xts_set_key,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

/*
 * SIMD wrappers created by aes_init(), indexed like aes_algs[]. Slots for
 * non-internal entries, or for entries not reached yet, remain NULL.
 */
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
 683
 684static void aes_exit(void)
 685{
 686        int i;
 687
 688        for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
 689                simd_skcipher_free(aes_simd_algs[i]);
 690
 691        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 692}
 693
 694static int __init aes_init(void)
 695{
 696        struct simd_skcipher_alg *simd;
 697        const char *basename;
 698        const char *algname;
 699        const char *drvname;
 700        int err;
 701        int i;
 702
 703        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 704        if (err)
 705                return err;
 706
 707        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
 708                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
 709                        continue;
 710
 711                algname = aes_algs[i].base.cra_name + 2;
 712                drvname = aes_algs[i].base.cra_driver_name + 2;
 713                basename = aes_algs[i].base.cra_driver_name;
 714                simd = simd_skcipher_create_compat(algname, drvname, basename);
 715                err = PTR_ERR(simd);
 716                if (IS_ERR(simd))
 717                        goto unregister_simds;
 718
 719                aes_simd_algs[i] = simd;
 720        }
 721
 722        return 0;
 723
 724unregister_simds:
 725        aes_exit();
 726        return err;
 727}
 728
/*
 * Module entry points. module_cpu_feature_match() presumably ties aes_init()
 * to CPUs advertising the AES extension - confirm against
 * <linux/cpufeature.h>.
 */
module_cpu_feature_match(AES, aes_init);
module_exit(aes_exit);
 731