linux/arch/arm64/crypto/aes-glue.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
   4 *
   5 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
   6 */
   7
   8#include <asm/neon.h>
   9#include <asm/hwcap.h>
  10#include <asm/simd.h>
  11#include <crypto/aes.h>
  12#include <crypto/internal/hash.h>
  13#include <crypto/internal/simd.h>
  14#include <crypto/internal/skcipher.h>
  15#include <crypto/scatterwalk.h>
  16#include <linux/module.h>
  17#include <linux/cpufeature.h>
  18#include <crypto/xts.h>
  19
  20#include "aes-ce-setkey.h"
  21#include "aes-ctr-fallback.h"
  22
  23#ifdef USE_V8_CRYPTO_EXTENSIONS
  24#define MODE                    "ce"
  25#define PRIO                    300
  26#define aes_setkey              ce_aes_setkey
  27#define aes_expandkey           ce_aes_expandkey
  28#define aes_ecb_encrypt         ce_aes_ecb_encrypt
  29#define aes_ecb_decrypt         ce_aes_ecb_decrypt
  30#define aes_cbc_encrypt         ce_aes_cbc_encrypt
  31#define aes_cbc_decrypt         ce_aes_cbc_decrypt
  32#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
  33#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
  34#define aes_ctr_encrypt         ce_aes_ctr_encrypt
  35#define aes_xts_encrypt         ce_aes_xts_encrypt
  36#define aes_xts_decrypt         ce_aes_xts_decrypt
  37#define aes_mac_update          ce_aes_mac_update
  38MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
  39#else
  40#define MODE                    "neon"
  41#define PRIO                    200
  42#define aes_setkey              crypto_aes_set_key
  43#define aes_expandkey           crypto_aes_expand_key
  44#define aes_ecb_encrypt         neon_aes_ecb_encrypt
  45#define aes_ecb_decrypt         neon_aes_ecb_decrypt
  46#define aes_cbc_encrypt         neon_aes_cbc_encrypt
  47#define aes_cbc_decrypt         neon_aes_cbc_decrypt
  48#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
  49#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
  50#define aes_ctr_encrypt         neon_aes_ctr_encrypt
  51#define aes_xts_encrypt         neon_aes_xts_encrypt
  52#define aes_xts_decrypt         neon_aes_xts_decrypt
  53#define aes_mac_update          neon_aes_mac_update
  54MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
  55MODULE_ALIAS_CRYPTO("ecb(aes)");
  56MODULE_ALIAS_CRYPTO("cbc(aes)");
  57MODULE_ALIAS_CRYPTO("ctr(aes)");
  58MODULE_ALIAS_CRYPTO("xts(aes)");
  59MODULE_ALIAS_CRYPTO("cmac(aes)");
  60MODULE_ALIAS_CRYPTO("xcbc(aes)");
  61MODULE_ALIAS_CRYPTO("cbcmac(aes)");
  62#endif
  63
  64MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
  65MODULE_LICENSE("GPL v2");
  66
  67/* defined in aes-modes.S */
  68asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
  69                                int rounds, int blocks);
  70asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
  71                                int rounds, int blocks);
  72
  73asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
  74                                int rounds, int blocks, u8 iv[]);
  75asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
  76                                int rounds, int blocks, u8 iv[]);
  77
  78asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
  79                                int rounds, int bytes, u8 const iv[]);
  80asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
  81                                int rounds, int bytes, u8 const iv[]);
  82
  83asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
  84                                int rounds, int blocks, u8 ctr[]);
  85
  86asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
  87                                int rounds, int blocks, u32 const rk2[], u8 iv[],
  88                                int first);
  89asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
  90                                int rounds, int blocks, u32 const rk2[], u8 iv[],
  91                                int first);
  92
  93asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
  94                               int blocks, u8 dg[], int enc_before,
  95                               int enc_after);
  96
  97struct cts_cbc_req_ctx {
  98        struct scatterlist sg_src[2];
  99        struct scatterlist sg_dst[2];
 100        struct skcipher_request subreq;
 101};
 102
 103struct crypto_aes_xts_ctx {
 104        struct crypto_aes_ctx key1;
 105        struct crypto_aes_ctx __aligned(8) key2;
 106};
 107
 108struct mac_tfm_ctx {
 109        struct crypto_aes_ctx key;
 110        u8 __aligned(8) consts[];
 111};
 112
 113struct mac_desc_ctx {
 114        unsigned int len;
 115        u8 dg[AES_BLOCK_SIZE];
 116};
 117
 118static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 119                               unsigned int key_len)
 120{
 121        return aes_setkey(crypto_skcipher_tfm(tfm), in_key, key_len);
 122}
 123
 124static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
 125                       unsigned int key_len)
 126{
 127        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 128        int ret;
 129
 130        ret = xts_verify_key(tfm, in_key, key_len);
 131        if (ret)
 132                return ret;
 133
 134        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
 135        if (!ret)
 136                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
 137                                    key_len / 2);
 138        if (!ret)
 139                return 0;
 140
 141        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
 142        return -EINVAL;
 143}
 144
 145static int ecb_encrypt(struct skcipher_request *req)
 146{
 147        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 148        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 149        int err, rounds = 6 + ctx->key_length / 4;
 150        struct skcipher_walk walk;
 151        unsigned int blocks;
 152
 153        err = skcipher_walk_virt(&walk, req, false);
 154
 155        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 156                kernel_neon_begin();
 157                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 158                                ctx->key_enc, rounds, blocks);
 159                kernel_neon_end();
 160                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 161        }
 162        return err;
 163}
 164
 165static int ecb_decrypt(struct skcipher_request *req)
 166{
 167        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 168        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 169        int err, rounds = 6 + ctx->key_length / 4;
 170        struct skcipher_walk walk;
 171        unsigned int blocks;
 172
 173        err = skcipher_walk_virt(&walk, req, false);
 174
 175        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 176                kernel_neon_begin();
 177                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 178                                ctx->key_dec, rounds, blocks);
 179                kernel_neon_end();
 180                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 181        }
 182        return err;
 183}
 184
 185static int cbc_encrypt(struct skcipher_request *req)
 186{
 187        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 188        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 189        int err, rounds = 6 + ctx->key_length / 4;
 190        struct skcipher_walk walk;
 191        unsigned int blocks;
 192
 193        err = skcipher_walk_virt(&walk, req, false);
 194
 195        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 196                kernel_neon_begin();
 197                aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 198                                ctx->key_enc, rounds, blocks, walk.iv);
 199                kernel_neon_end();
 200                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 201        }
 202        return err;
 203}
 204
 205static int cbc_decrypt(struct skcipher_request *req)
 206{
 207        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 208        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 209        int err, rounds = 6 + ctx->key_length / 4;
 210        struct skcipher_walk walk;
 211        unsigned int blocks;
 212
 213        err = skcipher_walk_virt(&walk, req, false);
 214
 215        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 216                kernel_neon_begin();
 217                aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 218                                ctx->key_dec, rounds, blocks, walk.iv);
 219                kernel_neon_end();
 220                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 221        }
 222        return err;
 223}
 224
 225static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
 226{
 227        crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
 228        return 0;
 229}
 230
 231static int cts_cbc_encrypt(struct skcipher_request *req)
 232{
 233        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 234        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 235        struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
 236        int err, rounds = 6 + ctx->key_length / 4;
 237        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
 238        struct scatterlist *src = req->src, *dst = req->dst;
 239        struct skcipher_walk walk;
 240
 241        skcipher_request_set_tfm(&rctx->subreq, tfm);
 242
 243        if (req->cryptlen <= AES_BLOCK_SIZE) {
 244                if (req->cryptlen < AES_BLOCK_SIZE)
 245                        return -EINVAL;
 246                cbc_blocks = 1;
 247        }
 248
 249        if (cbc_blocks > 0) {
 250                unsigned int blocks;
 251
 252                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
 253                                           cbc_blocks * AES_BLOCK_SIZE,
 254                                           req->iv);
 255
 256                err = skcipher_walk_virt(&walk, &rctx->subreq, false);
 257
 258                while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 259                        kernel_neon_begin();
 260                        aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 261                                        ctx->key_enc, rounds, blocks, walk.iv);
 262                        kernel_neon_end();
 263                        err = skcipher_walk_done(&walk,
 264                                                 walk.nbytes % AES_BLOCK_SIZE);
 265                }
 266                if (err)
 267                        return err;
 268
 269                if (req->cryptlen == AES_BLOCK_SIZE)
 270                        return 0;
 271
 272                dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
 273                                             rctx->subreq.cryptlen);
 274                if (req->dst != req->src)
 275                        dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
 276                                               rctx->subreq.cryptlen);
 277        }
 278
 279        /* handle ciphertext stealing */
 280        skcipher_request_set_crypt(&rctx->subreq, src, dst,
 281                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
 282                                   req->iv);
 283
 284        err = skcipher_walk_virt(&walk, &rctx->subreq, false);
 285        if (err)
 286                return err;
 287
 288        kernel_neon_begin();
 289        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 290                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
 291        kernel_neon_end();
 292
 293        return skcipher_walk_done(&walk, 0);
 294}
 295
 296static int cts_cbc_decrypt(struct skcipher_request *req)
 297{
 298        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 299        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 300        struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
 301        int err, rounds = 6 + ctx->key_length / 4;
 302        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
 303        struct scatterlist *src = req->src, *dst = req->dst;
 304        struct skcipher_walk walk;
 305
 306        skcipher_request_set_tfm(&rctx->subreq, tfm);
 307
 308        if (req->cryptlen <= AES_BLOCK_SIZE) {
 309                if (req->cryptlen < AES_BLOCK_SIZE)
 310                        return -EINVAL;
 311                cbc_blocks = 1;
 312        }
 313
 314        if (cbc_blocks > 0) {
 315                unsigned int blocks;
 316
 317                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
 318                                           cbc_blocks * AES_BLOCK_SIZE,
 319                                           req->iv);
 320
 321                err = skcipher_walk_virt(&walk, &rctx->subreq, false);
 322
 323                while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 324                        kernel_neon_begin();
 325                        aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 326                                        ctx->key_dec, rounds, blocks, walk.iv);
 327                        kernel_neon_end();
 328                        err = skcipher_walk_done(&walk,
 329                                                 walk.nbytes % AES_BLOCK_SIZE);
 330                }
 331                if (err)
 332                        return err;
 333
 334                if (req->cryptlen == AES_BLOCK_SIZE)
 335                        return 0;
 336
 337                dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
 338                                             rctx->subreq.cryptlen);
 339                if (req->dst != req->src)
 340                        dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
 341                                               rctx->subreq.cryptlen);
 342        }
 343
 344        /* handle ciphertext stealing */
 345        skcipher_request_set_crypt(&rctx->subreq, src, dst,
 346                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
 347                                   req->iv);
 348
 349        err = skcipher_walk_virt(&walk, &rctx->subreq, false);
 350        if (err)
 351                return err;
 352
 353        kernel_neon_begin();
 354        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 355                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
 356        kernel_neon_end();
 357
 358        return skcipher_walk_done(&walk, 0);
 359}
 360
 361static int ctr_encrypt(struct skcipher_request *req)
 362{
 363        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 364        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 365        int err, rounds = 6 + ctx->key_length / 4;
 366        struct skcipher_walk walk;
 367        int blocks;
 368
 369        err = skcipher_walk_virt(&walk, req, false);
 370
 371        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 372                kernel_neon_begin();
 373                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 374                                ctx->key_enc, rounds, blocks, walk.iv);
 375                kernel_neon_end();
 376                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 377        }
 378        if (walk.nbytes) {
 379                u8 __aligned(8) tail[AES_BLOCK_SIZE];
 380                unsigned int nbytes = walk.nbytes;
 381                u8 *tdst = walk.dst.virt.addr;
 382                u8 *tsrc = walk.src.virt.addr;
 383
 384                /*
 385                 * Tell aes_ctr_encrypt() to process a tail block.
 386                 */
 387                blocks = -1;
 388
 389                kernel_neon_begin();
 390                aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
 391                                blocks, walk.iv);
 392                kernel_neon_end();
 393                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
 394                err = skcipher_walk_done(&walk, 0);
 395        }
 396
 397        return err;
 398}
 399
 /*
 * Entry point of the directly registered ctr(aes): uses the NEON path when
 * SIMD is usable, otherwise aes_ctr_encrypt_fallback().
 */
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (crypto_simd_usable())
		return ctr_encrypt(req);

	return aes_ctr_encrypt_fallback(ctx, req);
}
 410
 411static int xts_encrypt(struct skcipher_request *req)
 412{
 413        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 414        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 415        int err, first, rounds = 6 + ctx->key1.key_length / 4;
 416        struct skcipher_walk walk;
 417        unsigned int blocks;
 418
 419        err = skcipher_walk_virt(&walk, req, false);
 420
 421        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
 422                kernel_neon_begin();
 423                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 424                                ctx->key1.key_enc, rounds, blocks,
 425                                ctx->key2.key_enc, walk.iv, first);
 426                kernel_neon_end();
 427                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 428        }
 429
 430        return err;
 431}
 432
 433static int xts_decrypt(struct skcipher_request *req)
 434{
 435        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 436        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 437        int err, first, rounds = 6 + ctx->key1.key_length / 4;
 438        struct skcipher_walk walk;
 439        unsigned int blocks;
 440
 441        err = skcipher_walk_virt(&walk, req, false);
 442
 443        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
 444                kernel_neon_begin();
 445                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 446                                ctx->key1.key_dec, rounds, blocks,
 447                                ctx->key2.key_enc, walk.iv, first);
 448                kernel_neon_end();
 449                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 450        }
 451
 452        return err;
 453}
 454
 455static struct skcipher_alg aes_algs[] = { {
 456        .base = {
 457                .cra_name               = "__ecb(aes)",
 458                .cra_driver_name        = "__ecb-aes-" MODE,
 459                .cra_priority           = PRIO,
 460                .cra_flags              = CRYPTO_ALG_INTERNAL,
 461                .cra_blocksize          = AES_BLOCK_SIZE,
 462                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 463                .cra_module             = THIS_MODULE,
 464        },
 465        .min_keysize    = AES_MIN_KEY_SIZE,
 466        .max_keysize    = AES_MAX_KEY_SIZE,
 467        .setkey         = skcipher_aes_setkey,
 468        .encrypt        = ecb_encrypt,
 469        .decrypt        = ecb_decrypt,
 470}, {
 471        .base = {
 472                .cra_name               = "__cbc(aes)",
 473                .cra_driver_name        = "__cbc-aes-" MODE,
 474                .cra_priority           = PRIO,
 475                .cra_flags              = CRYPTO_ALG_INTERNAL,
 476                .cra_blocksize          = AES_BLOCK_SIZE,
 477                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 478                .cra_module             = THIS_MODULE,
 479        },
 480        .min_keysize    = AES_MIN_KEY_SIZE,
 481        .max_keysize    = AES_MAX_KEY_SIZE,
 482        .ivsize         = AES_BLOCK_SIZE,
 483        .setkey         = skcipher_aes_setkey,
 484        .encrypt        = cbc_encrypt,
 485        .decrypt        = cbc_decrypt,
 486}, {
 487        .base = {
 488                .cra_name               = "__cts(cbc(aes))",
 489                .cra_driver_name        = "__cts-cbc-aes-" MODE,
 490                .cra_priority           = PRIO,
 491                .cra_flags              = CRYPTO_ALG_INTERNAL,
 492                .cra_blocksize          = AES_BLOCK_SIZE,
 493                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 494                .cra_module             = THIS_MODULE,
 495        },
 496        .min_keysize    = AES_MIN_KEY_SIZE,
 497        .max_keysize    = AES_MAX_KEY_SIZE,
 498        .ivsize         = AES_BLOCK_SIZE,
 499        .walksize       = 2 * AES_BLOCK_SIZE,
 500        .setkey         = skcipher_aes_setkey,
 501        .encrypt        = cts_cbc_encrypt,
 502        .decrypt        = cts_cbc_decrypt,
 503        .init           = cts_cbc_init_tfm,
 504}, {
 505        .base = {
 506                .cra_name               = "__ctr(aes)",
 507                .cra_driver_name        = "__ctr-aes-" MODE,
 508                .cra_priority           = PRIO,
 509                .cra_flags              = CRYPTO_ALG_INTERNAL,
 510                .cra_blocksize          = 1,
 511                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 512                .cra_module             = THIS_MODULE,
 513        },
 514        .min_keysize    = AES_MIN_KEY_SIZE,
 515        .max_keysize    = AES_MAX_KEY_SIZE,
 516        .ivsize         = AES_BLOCK_SIZE,
 517        .chunksize      = AES_BLOCK_SIZE,
 518        .setkey         = skcipher_aes_setkey,
 519        .encrypt        = ctr_encrypt,
 520        .decrypt        = ctr_encrypt,
 521}, {
 522        .base = {
 523                .cra_name               = "ctr(aes)",
 524                .cra_driver_name        = "ctr-aes-" MODE,
 525                .cra_priority           = PRIO - 1,
 526                .cra_blocksize          = 1,
 527                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 528                .cra_module             = THIS_MODULE,
 529        },
 530        .min_keysize    = AES_MIN_KEY_SIZE,
 531        .max_keysize    = AES_MAX_KEY_SIZE,
 532        .ivsize         = AES_BLOCK_SIZE,
 533        .chunksize      = AES_BLOCK_SIZE,
 534        .setkey         = skcipher_aes_setkey,
 535        .encrypt        = ctr_encrypt_sync,
 536        .decrypt        = ctr_encrypt_sync,
 537}, {
 538        .base = {
 539                .cra_name               = "__xts(aes)",
 540                .cra_driver_name        = "__xts-aes-" MODE,
 541                .cra_priority           = PRIO,
 542                .cra_flags              = CRYPTO_ALG_INTERNAL,
 543                .cra_blocksize          = AES_BLOCK_SIZE,
 544                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
 545                .cra_module             = THIS_MODULE,
 546        },
 547        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
 548        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
 549        .ivsize         = AES_BLOCK_SIZE,
 550        .setkey         = xts_set_key,
 551        .encrypt        = xts_encrypt,
 552        .decrypt        = xts_decrypt,
 553} };
 554
 555static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 556                         unsigned int key_len)
 557{
 558        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 559        int err;
 560
 561        err = aes_expandkey(&ctx->key, in_key, key_len);
 562        if (err)
 563                crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
 564
 565        return err;
 566}
 567
 568static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
 569{
 570        u64 a = be64_to_cpu(x->a);
 571        u64 b = be64_to_cpu(x->b);
 572
 573        y->a = cpu_to_be64((a << 1) | (b >> 63));
 574        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
 575}
 576
 577static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 578                       unsigned int key_len)
 579{
 580        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 581        be128 *consts = (be128 *)ctx->consts;
 582        int rounds = 6 + key_len / 4;
 583        int err;
 584
 585        err = cbcmac_setkey(tfm, in_key, key_len);
 586        if (err)
 587                return err;
 588
 589        /* encrypt the zero vector */
 590        kernel_neon_begin();
 591        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
 592                        rounds, 1);
 593        kernel_neon_end();
 594
 595        cmac_gf128_mul_by_x(consts, consts);
 596        cmac_gf128_mul_by_x(consts + 1, consts);
 597
 598        return 0;
 599}
 600
 601static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
 602                       unsigned int key_len)
 603{
 604        static u8 const ks[3][AES_BLOCK_SIZE] = {
 605                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
 606                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
 607                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
 608        };
 609
 610        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 611        int rounds = 6 + key_len / 4;
 612        u8 key[AES_BLOCK_SIZE];
 613        int err;
 614
 615        err = cbcmac_setkey(tfm, in_key, key_len);
 616        if (err)
 617                return err;
 618
 619        kernel_neon_begin();
 620        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
 621        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
 622        kernel_neon_end();
 623
 624        return cbcmac_setkey(tfm, key, sizeof(key));
 625}
 626
 627static int mac_init(struct shash_desc *desc)
 628{
 629        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 630
 631        memset(ctx->dg, 0, AES_BLOCK_SIZE);
 632        ctx->len = 0;
 633
 634        return 0;
 635}
 636
 637static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
 638                          u8 dg[], int enc_before, int enc_after)
 639{
 640        int rounds = 6 + ctx->key_length / 4;
 641
 642        if (crypto_simd_usable()) {
 643                kernel_neon_begin();
 644                aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
 645                               enc_after);
 646                kernel_neon_end();
 647        } else {
 648                if (enc_before)
 649                        __aes_arm64_encrypt(ctx->key_enc, dg, dg, rounds);
 650
 651                while (blocks--) {
 652                        crypto_xor(dg, in, AES_BLOCK_SIZE);
 653                        in += AES_BLOCK_SIZE;
 654
 655                        if (blocks || enc_after)
 656                                __aes_arm64_encrypt(ctx->key_enc, dg, dg,
 657                                                    rounds);
 658                }
 659        }
 660}
 661
 662static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
 663{
 664        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 665        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 666
 667        while (len > 0) {
 668                unsigned int l;
 669
 670                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
 671                    (ctx->len + len) > AES_BLOCK_SIZE) {
 672
 673                        int blocks = len / AES_BLOCK_SIZE;
 674
 675                        len %= AES_BLOCK_SIZE;
 676
 677                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
 678                                      (ctx->len != 0), (len != 0));
 679
 680                        p += blocks * AES_BLOCK_SIZE;
 681
 682                        if (!len) {
 683                                ctx->len = AES_BLOCK_SIZE;
 684                                break;
 685                        }
 686                        ctx->len = 0;
 687                }
 688
 689                l = min(len, AES_BLOCK_SIZE - ctx->len);
 690
 691                if (l <= AES_BLOCK_SIZE) {
 692                        crypto_xor(ctx->dg + ctx->len, p, l);
 693                        ctx->len += l;
 694                        len -= l;
 695                        p += l;
 696                }
 697        }
 698
 699        return 0;
 700}
 701
 702static int cbcmac_final(struct shash_desc *desc, u8 *out)
 703{
 704        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 705        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 706
 707        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
 708
 709        memcpy(out, ctx->dg, AES_BLOCK_SIZE);
 710
 711        return 0;
 712}
 713
 714static int cmac_final(struct shash_desc *desc, u8 *out)
 715{
 716        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 717        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 718        u8 *consts = tctx->consts;
 719
 720        if (ctx->len != AES_BLOCK_SIZE) {
 721                ctx->dg[ctx->len] ^= 0x80;
 722                consts += AES_BLOCK_SIZE;
 723        }
 724
 725        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
 726
 727        memcpy(out, ctx->dg, AES_BLOCK_SIZE);
 728
 729        return 0;
 730}
 731
 732static struct shash_alg mac_algs[] = { {
 733        .base.cra_name          = "cmac(aes)",
 734        .base.cra_driver_name   = "cmac-aes-" MODE,
 735        .base.cra_priority      = PRIO,
 736        .base.cra_blocksize     = AES_BLOCK_SIZE,
 737        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
 738                                  2 * AES_BLOCK_SIZE,
 739        .base.cra_module        = THIS_MODULE,
 740
 741        .digestsize             = AES_BLOCK_SIZE,
 742        .init                   = mac_init,
 743        .update                 = mac_update,
 744        .final                  = cmac_final,
 745        .setkey                 = cmac_setkey,
 746        .descsize               = sizeof(struct mac_desc_ctx),
 747}, {
 748        .base.cra_name          = "xcbc(aes)",
 749        .base.cra_driver_name   = "xcbc-aes-" MODE,
 750        .base.cra_priority      = PRIO,
 751        .base.cra_blocksize     = AES_BLOCK_SIZE,
 752        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
 753                                  2 * AES_BLOCK_SIZE,
 754        .base.cra_module        = THIS_MODULE,
 755
 756        .digestsize             = AES_BLOCK_SIZE,
 757        .init                   = mac_init,
 758        .update                 = mac_update,
 759        .final                  = cmac_final,
 760        .setkey                 = xcbc_setkey,
 761        .descsize               = sizeof(struct mac_desc_ctx),
 762}, {
 763        .base.cra_name          = "cbcmac(aes)",
 764        .base.cra_driver_name   = "cbcmac-aes-" MODE,
 765        .base.cra_priority      = PRIO,
 766        .base.cra_blocksize     = 1,
 767        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
 768        .base.cra_module        = THIS_MODULE,
 769
 770        .digestsize             = AES_BLOCK_SIZE,
 771        .init                   = mac_init,
 772        .update                 = mac_update,
 773        .final                  = cbcmac_final,
 774        .setkey                 = cbcmac_setkey,
 775        .descsize               = sizeof(struct mac_desc_ctx),
 776} };
 777
 778static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
 779
 780static void aes_exit(void)
 781{
 782        int i;
 783
 784        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
 785                if (aes_simd_algs[i])
 786                        simd_skcipher_free(aes_simd_algs[i]);
 787
 788        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
 789        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 790}
 791
 792static int __init aes_init(void)
 793{
 794        struct simd_skcipher_alg *simd;
 795        const char *basename;
 796        const char *algname;
 797        const char *drvname;
 798        int err;
 799        int i;
 800
 801        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 802        if (err)
 803                return err;
 804
 805        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
 806        if (err)
 807                goto unregister_ciphers;
 808
 809        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
 810                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
 811                        continue;
 812
 813                algname = aes_algs[i].base.cra_name + 2;
 814                drvname = aes_algs[i].base.cra_driver_name + 2;
 815                basename = aes_algs[i].base.cra_driver_name;
 816                simd = simd_skcipher_create_compat(algname, drvname, basename);
 817                err = PTR_ERR(simd);
 818                if (IS_ERR(simd))
 819                        goto unregister_simds;
 820
 821                aes_simd_algs[i] = simd;
 822        }
 823
 824        return 0;
 825
 826unregister_simds:
 827        aes_exit();
 828        return err;
 829unregister_ciphers:
 830        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 831        return err;
 832}
 833
 #ifndef USE_V8_CRYPTO_EXTENSIONS
/*
 * NEON build: load unconditionally, and export two of the NEON routines
 * (presumably for use by other arm64 crypto modules).
 */
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
#else
/* Crypto Extensions build: bind only on CPUs advertising the AES feature. */
module_cpu_feature_match(AES, aes_init);
#endif
module_exit(aes_exit);
 842