// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"

#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);

asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                               int blocks, u8 dg[], int enc_before,
                               int enc_after);

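/*
 * A note on the asm interface above: the ECB, CBC and ESSIV entry points
 * take their length in 16-byte blocks, while the CTS and XTS entry points
 * take a byte count, since those modes must deal with a partial final
 * block themselves. aes_ctr_encrypt() additionally accepts a negative
 * block count to request a single block of keystream for a partial tail
 * (see ctr_encrypt() below).
 */
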
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

struct crypto_aes_essiv_cbc_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
        struct crypto_shash *hash;
};

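/*
 * In both of the two-key contexts above, key1 is the AES data key. For
 * XTS, key2 holds the independent tweak key, i.e. the second half of the
 * key material passed to xts_set_key(); for ESSIV, key2 is derived in
 * essiv_cbc_set_key() by expanding the SHA-256 digest of the data key, as
 * the ESSIV construction requires.
 */
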
struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                               unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = aes_expandkey(ctx, in_key, key_len);
        if (ret)
                crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);

        return ret;
}

static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
                                      const u8 *in_key, unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                    key_len / 2);
        if (!ret)
                return 0;

        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
        return -EINVAL;
}

static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
                                            const u8 *in_key,
                                            unsigned int key_len)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        SHASH_DESC_ON_STACK(desc, ctx->hash);
        u8 digest[SHA256_DIGEST_SIZE];
        int ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len);
        if (ret)
                goto out;

        desc->tfm = ctx->hash;
        crypto_shash_digest(desc, in_key, key_len, digest);

        ret = aes_expandkey(&ctx->key2, digest, sizeof(digest));
        if (ret)
                goto out;

        return 0;
out:
        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
        return -EINVAL;
}

static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

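/*
 * The "rounds = 6 + ctx->key_length / 4" expression seen throughout this
 * file is simply the standard AES round count recovered from the key size
 * in bytes:
 *
 *      AES-128: 6 + 16 / 4 = 10 rounds
 *      AES-192: 6 + 24 / 4 = 12 rounds
 *      AES-256: 6 + 32 / 4 = 14 rounds
 */
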
static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_encrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_dec, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_decrypt_walk(req, &walk);
}

static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_encrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

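/*
 * As an illustration of the split above: for req->cryptlen == 50 bytes,
 * cbc_blocks == DIV_ROUND_UP(50, 16) - 2 == 2, so the first 32 bytes take
 * the plain CBC path via cbc_encrypt_walk(), and the remaining 18 bytes -
 * the final full block plus a 2 byte tail - go to aes_cbc_cts_encrypt(),
 * which performs the actual ciphertext stealing. Requests shorter than one
 * block are rejected, since ciphertext stealing needs at least
 * AES_BLOCK_SIZE bytes to work with.
 */
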
static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_decrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        ctx->hash = crypto_alloc_shash("sha256", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->hash);
}

static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_shash(ctx->hash);
}

static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_enc, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_dec, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_decrypt_walk(req, &walk);
}

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        if (walk.nbytes) {
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
                unsigned int nbytes = walk.nbytes;
                u8 *tdst = walk.dst.virt.addr;
                u8 *tsrc = walk.src.virt.addr;

                /*
                 * Tell aes_ctr_encrypt() to process a tail block.
                 */
                blocks = -1;

                kernel_neon_begin();
                aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
                                blocks, walk.iv);
                kernel_neon_end();
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
                err = skcipher_walk_done(&walk, 0);
        }

        return err;
}

static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        unsigned long flags;

        /*
         * Temporarily disable interrupts to avoid races where
         * cachelines are evicted when the CPU is interrupted
         * to do something else: the AES library code prefetches
         * its lookup tables to limit cache timing leaks, and an
         * eviction halfway through a block would defeat that.
         */
        local_irq_save(flags);
        aes_encrypt(ctx, dst, src);
        local_irq_restore(flags);
}

static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
{
        if (!crypto_simd_usable())
                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

        return ctr_encrypt(req);
}

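/*
 * ctr_encrypt_sync() is the entry point of the non-internal "ctr(aes)"
 * algorithm, which may be invoked in contexts where the NEON register file
 * is off limits (crypto_simd_usable() returns false). In that case the
 * request is processed block by block through crypto_ctr_encrypt_walk(),
 * using ctr_encrypt_one() and the scalar AES library code instead of the
 * asm routines.
 */
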
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_enc, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

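/*
 * A sketch of the tail handling above: for req->cryptlen == 100 bytes,
 * tail == 4 and xts_blocks == DIV_ROUND_UP(100, 16) - 2 == 5, so the main
 * loop covers 80 bytes and the final aes_xts_encrypt() call handles the
 * last full block plus the tail (20 bytes) using ciphertext stealing. When
 * the walk can map the whole request in one pass, tail is cleared and the
 * main loop lets the asm code deal with any partial final block directly.
 */
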
static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_dec, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !defined(CONFIG_CRYPTO_AES_ARM64_BS)
        .base = {
                .cra_name               = "__ecb(aes)",
                .cra_driver_name        = "__ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "__cbc(aes)",
                .cra_driver_name        = "__cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__ctr(aes)",
                .cra_driver_name        = "__ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO - 1,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt_sync,
        .decrypt        = ctr_encrypt_sync,
}, {
        .base = {
                .cra_name               = "__xts(aes)",
                .cra_driver_name        = "__xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
}, {
#endif
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
}, {
        .base = {
                .cra_name               = "__essiv(cbc(aes),sha256)",
                .cra_driver_name        = "__essiv-cbc-aes-sha256-" MODE,
                .cra_priority           = PRIO + 1,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = essiv_cbc_set_key,
        .encrypt        = essiv_cbc_encrypt,
        .decrypt        = essiv_cbc_decrypt,
        .init           = essiv_cbc_init_tfm,
        .exit           = essiv_cbc_exit_tfm,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int err;

        err = aes_expandkey(&ctx->key, in_key, key_len);
        if (err)
                crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);

        return err;
}

static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
        u64 a = be64_to_cpu(x->a);
        u64 b = be64_to_cpu(x->b);

        y->a = cpu_to_be64((a << 1) | (b >> 63));
        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

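/*
 * cmac_gf128_mul_by_x() doubles a 128-bit value in the GF(2^128) field
 * used by CMAC: shift left by one bit, and if a bit was carried out of
 * the top, reduce by XORing in 0x87, the low byte of the field polynomial
 * x^128 + x^7 + x^2 + x + 1.
 */
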
static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        int rounds = 6 + key_len / 4;
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        /* encrypt the zero vector */
        kernel_neon_begin();
        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
                        rounds, 1);
        kernel_neon_end();

        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);

        return 0;
}

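/*
 * The two constants computed above are the usual CMAC subkeys: with
 * L = AES-ENC(K, 0^128), consts[0] = L * x and consts[1] = L * x^2 in
 * GF(2^128). cmac_final() XORs the first into a complete final block and
 * the second into a padded partial one.
 */
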
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        static u8 const ks[3][AES_BLOCK_SIZE] = {
                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
        };

        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        kernel_neon_begin();
        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();

        return cbcmac_setkey(tfm, key, sizeof(key));
}

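/*
 * The above is the AES-XCBC key derivation from RFC 3566: three subkeys
 * are produced by encrypting the constant blocks 0x01..01, 0x02..02 and
 * 0x03..03 under the user key. K1 (from ks[0]) replaces the user key for
 * the CBC-MAC itself - hence the second cbcmac_setkey() call - while K2
 * and K3 end up in ctx->consts, in the same slots where CMAC keeps its
 * derived constants.
 */
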
static int mac_init(struct shash_desc *desc)
{
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->dg, 0, AES_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
                          u8 dg[], int enc_before, int enc_after)
{
        int rounds = 6 + ctx->key_length / 4;

        if (crypto_simd_usable()) {
                kernel_neon_begin();
                aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
                               enc_after);
                kernel_neon_end();
        } else {
                if (enc_before)
                        aes_encrypt(ctx, dg, dg);

                while (blocks--) {
                        crypto_xor(dg, in, AES_BLOCK_SIZE);
                        in += AES_BLOCK_SIZE;

                        if (blocks || enc_after)
                                aes_encrypt(ctx, dg, dg);
                }
        }
}

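/*
 * The scalar else-branch of mac_do_update() above mirrors what the
 * aes_mac_update() asm routine does: an optional pre-encryption of the
 * digest (enc_before, set when resuming from a partially filled block),
 * then a plain CBC-MAC pass over the input using the AES library code,
 * with the final encryption skipped unless enc_after is set.
 */
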
static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        while (len > 0) {
                unsigned int l;

                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
                    (ctx->len + len) > AES_BLOCK_SIZE) {

                        int blocks = len / AES_BLOCK_SIZE;

                        len %= AES_BLOCK_SIZE;

                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
                                      (ctx->len != 0), (len != 0));

                        p += blocks * AES_BLOCK_SIZE;

                        if (!len) {
                                ctx->len = AES_BLOCK_SIZE;
                                break;
                        }
                        ctx->len = 0;
                }

                l = min(len, AES_BLOCK_SIZE - ctx->len);

                if (l <= AES_BLOCK_SIZE) {
                        crypto_xor(ctx->dg + ctx->len, p, l);
                        ctx->len += l;
                        len -= l;
                        p += l;
                }
        }

        return 0;
}

static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static int cmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
        u8 *consts = tctx->consts;

        if (ctx->len != AES_BLOCK_SIZE) {
                ctx->dg[ctx->len] ^= 0x80;
                consts += AES_BLOCK_SIZE;
        }

        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

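/*
 * The 0x80 byte XORed in above is the leading byte of the standard 10*
 * padding for a short final block: a single 1 bit followed by zero bits.
 * XORing zero bytes is a no-op, so only that first pad byte needs to be
 * applied to the running CBC state in ctx->dg before the masked final
 * encryption.
 */
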
static struct shash_alg mac_algs[] = { {
        .base.cra_name          = "cmac(aes)",
        .base.cra_driver_name   = "cmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = cmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "xcbc(aes)",
        .base.cra_driver_name   = "xcbc-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = xcbc_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "cbcmac(aes)",
        .base.cra_driver_name   = "cbcmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cbcmac_final,
        .setkey                 = cbcmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

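/*
 * Registration sketch: the CRYPTO_ALG_INTERNAL algorithms above (the ones
 * whose names carry a "__" prefix) are not meant to be used directly.
 * aes_init() wraps each of them via simd_skcipher_create_compat(), which
 * registers an async algorithm under the unprefixed name that falls back
 * to an asynchronous helper whenever the NEON unit cannot be used in the
 * caller's context; the "+ 2" pointer arithmetic below simply strips the
 * "__" prefix.
 */
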
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        if (err)
                goto unregister_ciphers;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }

        return 0;

unregister_simds:
        aes_exit();
        return err;
unregister_ciphers:
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        return err;
}

#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);