linux/arch/arm64/crypto/aes-glue.c
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"
#include "aes-ctr-fallback.h"

#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define aes_setkey              ce_aes_setkey
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define aes_setkey              crypto_aes_set_key
#define aes_expandkey           crypto_aes_expand_key
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");
#endif

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int blocks, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int blocks, u32 const rk2[], u8 iv[],
                                int first);

asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                               int blocks, u8 dg[], int enc_before,
                               int enc_after);

struct cts_cbc_req_ctx {
        struct scatterlist sg_src[2];
        struct scatterlist sg_dst[2];
        struct skcipher_request subreq;
};

struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                               unsigned int key_len)
{
        return aes_setkey(crypto_skcipher_tfm(tfm), in_key, key_len);
}

static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                    key_len / 2);
        if (!ret)
                return 0;

        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
        return -EINVAL;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
        return 0;
}

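/*
 * CBC with ciphertext stealing: all complete blocks except the last two are
 * processed as plain CBC through a subrequest, and the trailing two-block
 * tail (whose final block may be partial) is handed to the
 * aes_cbc_cts_{en,de}crypt helpers, which perform the actual stealing.
 * A request of exactly one block degenerates to plain CBC.
 */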
static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&rctx->subreq, tfm);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                unsigned int blocks;

                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &rctx->subreq, false);

                while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                        kernel_neon_begin();
                        aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                        ctx->key_enc, rounds, blocks, walk.iv);
                        kernel_neon_end();
                        err = skcipher_walk_done(&walk,
                                                 walk.nbytes % AES_BLOCK_SIZE);
                }
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
                                             rctx->subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
                                               rctx->subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&rctx->subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &rctx->subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&rctx->subreq, tfm);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                unsigned int blocks;

                skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &rctx->subreq, false);

                while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                        kernel_neon_begin();
                        aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                        ctx->key_dec, rounds, blocks, walk.iv);
                        kernel_neon_end();
                        err = skcipher_walk_done(&walk,
                                                 walk.nbytes % AES_BLOCK_SIZE);
                }
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
                                             rctx->subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
                                               rctx->subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&rctx->subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &rctx->subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        if (walk.nbytes) {
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
                unsigned int nbytes = walk.nbytes;
                u8 *tdst = walk.dst.virt.addr;
                u8 *tsrc = walk.src.virt.addr;

                /*
                 * Tell aes_ctr_encrypt() to process a tail block.
                 */
                blocks = -1;

                kernel_neon_begin();
                aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
                                blocks, walk.iv);
                kernel_neon_end();
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
                err = skcipher_walk_done(&walk, 0);
        }

        return err;
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

        if (!may_use_simd())
                return aes_ctr_encrypt_fallback(ctx, req);

        return ctr_encrypt(req);
}

static int xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, blocks,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }

        return err;
}

static int xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, blocks,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }

        return err;
}

static struct skcipher_alg aes_algs[] = { {
        .base = {
                .cra_name               = "__ecb(aes)",
                .cra_driver_name        = "__ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "__cbc(aes)",
                .cra_driver_name        = "__cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
        .init           = cts_cbc_init_tfm,
}, {
        .base = {
                .cra_name               = "__ctr(aes)",
                .cra_driver_name        = "__ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO - 1,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt_sync,
        .decrypt        = ctr_encrypt_sync,
}, {
        .base = {
                .cra_name               = "__xts(aes)",
                .cra_driver_name        = "__xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->key, in_key, key_len);
        if (err)
                crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);

        return err;
}

static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
        u64 a = be64_to_cpu(x->a);
        u64 b = be64_to_cpu(x->b);

        y->a = cpu_to_be64((a << 1) | (b >> 63));
        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        int rounds = 6 + key_len / 4;
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        /* encrypt the zero vector */
        kernel_neon_begin();
        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
                        rounds, 1);
        kernel_neon_end();

        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);

        return 0;
}

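/*
 * XCBC key derivation: K1, K2 and K3 are obtained by encrypting the constant
 * blocks 0x01.., 0x02.. and 0x03.. under the user-supplied key. The MAC then
 * runs as a CBC-MAC keyed with K1, while K2 and K3 are kept in ctx->consts
 * for use on the final block.
 */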
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        static u8 const ks[3][AES_BLOCK_SIZE] = {
                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
        };

        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        kernel_neon_begin();
        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();

        return cbcmac_setkey(tfm, key, sizeof(key));
}

static int mac_init(struct shash_desc *desc)
{
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->dg, 0, AES_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

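/*
 * Feed 'blocks' full blocks of input into the CBC-MAC state 'dg', using the
 * SIMD helper when the NEON unit is usable and falling back to the scalar
 * AES cipher otherwise. 'enc_before'/'enc_after' control whether the digest
 * is encrypted before the first block is XOR-ed in and after the last one,
 * which lets mac_update() split a message across several calls.
 */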
static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
                          u8 dg[], int enc_before, int enc_after)
{
        int rounds = 6 + ctx->key_length / 4;

        if (may_use_simd()) {
                kernel_neon_begin();
                aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
                               enc_after);
                kernel_neon_end();
        } else {
                if (enc_before)
                        __aes_arm64_encrypt(ctx->key_enc, dg, dg, rounds);

                while (blocks--) {
                        crypto_xor(dg, in, AES_BLOCK_SIZE);
                        in += AES_BLOCK_SIZE;

                        if (blocks || enc_after)
                                __aes_arm64_encrypt(ctx->key_enc, dg, dg,
                                                    rounds);
                }
        }
}

static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        while (len > 0) {
                unsigned int l;

                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
                    (ctx->len + len) > AES_BLOCK_SIZE) {

                        int blocks = len / AES_BLOCK_SIZE;

                        len %= AES_BLOCK_SIZE;

                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
                                      (ctx->len != 0), (len != 0));

                        p += blocks * AES_BLOCK_SIZE;

                        if (!len) {
                                ctx->len = AES_BLOCK_SIZE;
                                break;
                        }
                        ctx->len = 0;
                }

                l = min(len, AES_BLOCK_SIZE - ctx->len);

                if (l <= AES_BLOCK_SIZE) {
                        crypto_xor(ctx->dg + ctx->len, p, l);
                        ctx->len += l;
                        len -= l;
                        p += l;
                }
        }

        return 0;
}

static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        mac_do_update(&tctx->key, NULL, 0, ctx->dg, 1, 0);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static int cmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
        u8 *consts = tctx->consts;

        if (ctx->len != AES_BLOCK_SIZE) {
                ctx->dg[ctx->len] ^= 0x80;
                consts += AES_BLOCK_SIZE;
        }

        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static struct shash_alg mac_algs[] = { {
        .base.cra_name          = "cmac(aes)",
        .base.cra_driver_name   = "cmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = cmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "xcbc(aes)",
        .base.cra_driver_name   = "xcbc-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = xcbc_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "cbcmac(aes)",
        .base.cra_driver_name   = "cbcmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cbcmac_final,
        .setkey                 = cbcmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        if (err)
                goto unregister_ciphers;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }

        return 0;

unregister_simds:
        aes_exit();
        return err;
unregister_ciphers:
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        return err;
}

#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
#endif
module_exit(aes_exit);