linux/arch/arm64/crypto/aes-glue.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
   4 *
   5 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
   6 */
   7
   8#include <asm/neon.h>
   9#include <asm/hwcap.h>
  10#include <asm/simd.h>
  11#include <crypto/aes.h>
  12#include <crypto/ctr.h>
  13#include <crypto/sha2.h>
  14#include <crypto/internal/hash.h>
  15#include <crypto/internal/simd.h>
  16#include <crypto/internal/skcipher.h>
  17#include <crypto/scatterwalk.h>
  18#include <linux/module.h>
  19#include <linux/cpufeature.h>
  20#include <crypto/xts.h>
  21
  22#include "aes-ce-setkey.h"
  23
/*
 * Build-time backend selection.
 *
 * This file is compiled either with USE_V8_CRYPTO_EXTENSIONS defined (the
 * "ce" flavour) or without it (the "neon" flavour). The generic aes_*
 * names used throughout the file are remapped here to the ce_ or neon_
 * prefixed symbols.
 *
 * MODE   - suffix appended to every cra_driver_name registered below
 * PRIO   - base cra_priority of the registered algorithms
 * STRIDE - used by ctr_encrypt() to decide whether the final partial block
 *          was returned through buf[]; presumably the number of blocks the
 *          assembly handles per iteration - TODO confirm in aes-modes.S
 *
 * Note that aes_expandkey is only remapped in the "ce" flavour; in the
 * "neon" flavour the name is left alone and resolves to whatever the
 * included headers declare (presumably the generic routine from
 * <crypto/aes.h> - TODO confirm).
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define STRIDE                  5
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define STRIDE                  4
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
/*
 * The plain ecb/cbc/ctr/xts aliases are only advertised for the Crypto
 * Extensions build, or when CONFIG_CRYPTO_AES_ARM64_BS is not enabled. The
 * same condition guards the matching entries in aes_algs[].
 * NOTE(review): presumably the bit-sliced NEON driver owns these modes
 * otherwise - confirm against its Kconfig/Makefile.
 */
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
/* These aliases are advertised by both flavours unconditionally. */
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
  74
/*
 * Assembly routines, defined in aes-modes.S.
 *
 * The names are remapped to their ce_ or neon_ variants by the macros
 * earlier in this file. As used by the callers in this file:
 *   rk / rk1 / rk2 - an expanded key schedule (key_enc or key_dec)
 *   rounds         - 6 + key_length / 4
 *   blocks         - number of whole AES_BLOCK_SIZE blocks to process
 *   bytes          - byte count, which need not be a block multiple
 *   iv / ctr       - the walk's IV buffer
 */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

/* CBC with ciphertext stealing: handles the final (up to two) blocks. */
asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

/* finalbuf receives a trailing partial block in some cases; see ctr_encrypt(). */
asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 ctr[], u8 finalbuf[]);

/* first is non-zero only on the first call for a given request. */
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);

/* rk2 is the ESSIV key schedule; callers always pass its encryption form. */
asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);

/* NOTE(review): callers and return-value semantics are outside this view. */
asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                              int blocks, u8 dg[], int enc_before,
                              int enc_after);
 111
/*
 * Transform context for the XTS mode: two separately expanded AES keys.
 * xts_set_key() expands the first half of the supplied key into key1 and
 * the second half into key2. The XTS routines pass key1 as rk1 and
 * key2.key_enc as rk2.
 */
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};
 116
/*
 * Transform context for essiv(cbc(aes),sha256).
 *
 * key1 must stay the first member: the shared CBC walk helpers treat the
 * transform context as a plain struct crypto_aes_ctx.
 */
struct crypto_aes_essiv_cbc_ctx {
        /* expanded from the key supplied by the caller */
        struct crypto_aes_ctx key1;
        /* expanded from the "sha256" digest of that same key */
        struct crypto_aes_ctx __aligned(8) key2;
        /* hash transform allocated in essiv_cbc_init_tfm() */
        struct crypto_shash *hash;
};
 122
/*
 * Transform context shared by the MAC algorithms: the expanded AES key
 * followed by a variable-sized, 8-byte aligned area for derived constants.
 * cmac_setkey() accesses consts[] as an array of be128.
 */
struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};
 127
/*
 * Per-request MAC state.
 * NOTE(review): the users of this struct are outside this view. Presumably
 * dg[] holds the running digest block and len counts bytes absorbed into
 * the current block - confirm against the MAC update/final code.
 */
struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};
 132
 133static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
 134                               unsigned int key_len)
 135{
 136        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 137
 138        return aes_expandkey(ctx, in_key, key_len);
 139}
 140
 141static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
 142                                      const u8 *in_key, unsigned int key_len)
 143{
 144        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 145        int ret;
 146
 147        ret = xts_verify_key(tfm, in_key, key_len);
 148        if (ret)
 149                return ret;
 150
 151        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
 152        if (!ret)
 153                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
 154                                    key_len / 2);
 155        return ret;
 156}
 157
 158static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
 159                                            const u8 *in_key,
 160                                            unsigned int key_len)
 161{
 162        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 163        u8 digest[SHA256_DIGEST_SIZE];
 164        int ret;
 165
 166        ret = aes_expandkey(&ctx->key1, in_key, key_len);
 167        if (ret)
 168                return ret;
 169
 170        crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
 171
 172        return aes_expandkey(&ctx->key2, digest, sizeof(digest));
 173}
 174
 175static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
 176{
 177        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 178        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 179        int err, rounds = 6 + ctx->key_length / 4;
 180        struct skcipher_walk walk;
 181        unsigned int blocks;
 182
 183        err = skcipher_walk_virt(&walk, req, false);
 184
 185        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 186                kernel_neon_begin();
 187                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 188                                ctx->key_enc, rounds, blocks);
 189                kernel_neon_end();
 190                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 191        }
 192        return err;
 193}
 194
 195static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
 196{
 197        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 198        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 199        int err, rounds = 6 + ctx->key_length / 4;
 200        struct skcipher_walk walk;
 201        unsigned int blocks;
 202
 203        err = skcipher_walk_virt(&walk, req, false);
 204
 205        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
 206                kernel_neon_begin();
 207                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 208                                ctx->key_dec, rounds, blocks);
 209                kernel_neon_end();
 210                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 211        }
 212        return err;
 213}
 214
 215static int cbc_encrypt_walk(struct skcipher_request *req,
 216                            struct skcipher_walk *walk)
 217{
 218        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 219        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 220        int err = 0, rounds = 6 + ctx->key_length / 4;
 221        unsigned int blocks;
 222
 223        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
 224                kernel_neon_begin();
 225                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
 226                                ctx->key_enc, rounds, blocks, walk->iv);
 227                kernel_neon_end();
 228                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 229        }
 230        return err;
 231}
 232
 233static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
 234{
 235        struct skcipher_walk walk;
 236        int err;
 237
 238        err = skcipher_walk_virt(&walk, req, false);
 239        if (err)
 240                return err;
 241        return cbc_encrypt_walk(req, &walk);
 242}
 243
 244static int cbc_decrypt_walk(struct skcipher_request *req,
 245                            struct skcipher_walk *walk)
 246{
 247        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 248        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 249        int err = 0, rounds = 6 + ctx->key_length / 4;
 250        unsigned int blocks;
 251
 252        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
 253                kernel_neon_begin();
 254                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
 255                                ctx->key_dec, rounds, blocks, walk->iv);
 256                kernel_neon_end();
 257                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
 258        }
 259        return err;
 260}
 261
 262static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
 263{
 264        struct skcipher_walk walk;
 265        int err;
 266
 267        err = skcipher_walk_virt(&walk, req, false);
 268        if (err)
 269                return err;
 270        return cbc_decrypt_walk(req, &walk);
 271}
 272
 273static int cts_cbc_encrypt(struct skcipher_request *req)
 274{
 275        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 276        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 277        int err, rounds = 6 + ctx->key_length / 4;
 278        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
 279        struct scatterlist *src = req->src, *dst = req->dst;
 280        struct scatterlist sg_src[2], sg_dst[2];
 281        struct skcipher_request subreq;
 282        struct skcipher_walk walk;
 283
 284        skcipher_request_set_tfm(&subreq, tfm);
 285        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
 286                                      NULL, NULL);
 287
 288        if (req->cryptlen <= AES_BLOCK_SIZE) {
 289                if (req->cryptlen < AES_BLOCK_SIZE)
 290                        return -EINVAL;
 291                cbc_blocks = 1;
 292        }
 293
 294        if (cbc_blocks > 0) {
 295                skcipher_request_set_crypt(&subreq, req->src, req->dst,
 296                                           cbc_blocks * AES_BLOCK_SIZE,
 297                                           req->iv);
 298
 299                err = skcipher_walk_virt(&walk, &subreq, false) ?:
 300                      cbc_encrypt_walk(&subreq, &walk);
 301                if (err)
 302                        return err;
 303
 304                if (req->cryptlen == AES_BLOCK_SIZE)
 305                        return 0;
 306
 307                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
 308                if (req->dst != req->src)
 309                        dst = scatterwalk_ffwd(sg_dst, req->dst,
 310                                               subreq.cryptlen);
 311        }
 312
 313        /* handle ciphertext stealing */
 314        skcipher_request_set_crypt(&subreq, src, dst,
 315                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
 316                                   req->iv);
 317
 318        err = skcipher_walk_virt(&walk, &subreq, false);
 319        if (err)
 320                return err;
 321
 322        kernel_neon_begin();
 323        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 324                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
 325        kernel_neon_end();
 326
 327        return skcipher_walk_done(&walk, 0);
 328}
 329
 330static int cts_cbc_decrypt(struct skcipher_request *req)
 331{
 332        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 333        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 334        int err, rounds = 6 + ctx->key_length / 4;
 335        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
 336        struct scatterlist *src = req->src, *dst = req->dst;
 337        struct scatterlist sg_src[2], sg_dst[2];
 338        struct skcipher_request subreq;
 339        struct skcipher_walk walk;
 340
 341        skcipher_request_set_tfm(&subreq, tfm);
 342        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
 343                                      NULL, NULL);
 344
 345        if (req->cryptlen <= AES_BLOCK_SIZE) {
 346                if (req->cryptlen < AES_BLOCK_SIZE)
 347                        return -EINVAL;
 348                cbc_blocks = 1;
 349        }
 350
 351        if (cbc_blocks > 0) {
 352                skcipher_request_set_crypt(&subreq, req->src, req->dst,
 353                                           cbc_blocks * AES_BLOCK_SIZE,
 354                                           req->iv);
 355
 356                err = skcipher_walk_virt(&walk, &subreq, false) ?:
 357                      cbc_decrypt_walk(&subreq, &walk);
 358                if (err)
 359                        return err;
 360
 361                if (req->cryptlen == AES_BLOCK_SIZE)
 362                        return 0;
 363
 364                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
 365                if (req->dst != req->src)
 366                        dst = scatterwalk_ffwd(sg_dst, req->dst,
 367                                               subreq.cryptlen);
 368        }
 369
 370        /* handle ciphertext stealing */
 371        skcipher_request_set_crypt(&subreq, src, dst,
 372                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
 373                                   req->iv);
 374
 375        err = skcipher_walk_virt(&walk, &subreq, false);
 376        if (err)
 377                return err;
 378
 379        kernel_neon_begin();
 380        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 381                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
 382        kernel_neon_end();
 383
 384        return skcipher_walk_done(&walk, 0);
 385}
 386
 387static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
 388{
 389        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 390
 391        ctx->hash = crypto_alloc_shash("sha256", 0, 0);
 392
 393        return PTR_ERR_OR_ZERO(ctx->hash);
 394}
 395
 396static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
 397{
 398        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 399
 400        crypto_free_shash(ctx->hash);
 401}
 402
 403static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
 404{
 405        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 406        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 407        int err, rounds = 6 + ctx->key1.key_length / 4;
 408        struct skcipher_walk walk;
 409        unsigned int blocks;
 410
 411        err = skcipher_walk_virt(&walk, req, false);
 412
 413        blocks = walk.nbytes / AES_BLOCK_SIZE;
 414        if (blocks) {
 415                kernel_neon_begin();
 416                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 417                                      ctx->key1.key_enc, rounds, blocks,
 418                                      req->iv, ctx->key2.key_enc);
 419                kernel_neon_end();
 420                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 421        }
 422        return err ?: cbc_encrypt_walk(req, &walk);
 423}
 424
 425static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
 426{
 427        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 428        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
 429        int err, rounds = 6 + ctx->key1.key_length / 4;
 430        struct skcipher_walk walk;
 431        unsigned int blocks;
 432
 433        err = skcipher_walk_virt(&walk, req, false);
 434
 435        blocks = walk.nbytes / AES_BLOCK_SIZE;
 436        if (blocks) {
 437                kernel_neon_begin();
 438                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 439                                      ctx->key1.key_dec, rounds, blocks,
 440                                      req->iv, ctx->key2.key_enc);
 441                kernel_neon_end();
 442                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 443        }
 444        return err ?: cbc_decrypt_walk(req, &walk);
 445}
 446
/*
 * CTR mode en/decryption using the NEON/CE assembly routine.
 *
 * Each walk step is passed to aes_ctr_encrypt() with a byte count. For
 * every step except the last one of the request the count is rounded down
 * to a whole number of blocks, so only the final step can carry a partial
 * block.
 *
 * Returns the status of the last walk operation.
 */
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                const u8 *src = walk.src.virt.addr;
                unsigned int nbytes = walk.nbytes;
                u8 *dst = walk.dst.virt.addr;
                /* bounce buffer for a short input and for a returned tail */
                u8 buf[AES_BLOCK_SIZE];
                unsigned int tail;

                /*
                 * Less than one block in total: read the input from buf[]
                 * instead of from the walk's source mapping.
                 */
                if (unlikely(nbytes < AES_BLOCK_SIZE))
                        src = memcpy(buf, src, nbytes);
                /* not the last step: only process whole blocks now */
                else if (nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
                                walk.iv, buf);
                kernel_neon_end();

                /* bytes left over after the last full STRIDE-block group */
                tail = nbytes % (STRIDE * AES_BLOCK_SIZE);
                if (tail > 0 && tail < AES_BLOCK_SIZE)
                        /*
                         * The final partial block could not be returned using
                         * an overlapping store, so it was passed via buf[]
                         * instead.
                         */
                        memcpy(dst + nbytes - tail, buf, tail);

                /* report how many bytes of this step were left unprocessed */
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        return err;
}
 487
 488static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
 489{
 490        const struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
 491        unsigned long flags;
 492
 493        /*
 494         * Temporarily disable interrupts to avoid races where
 495         * cachelines are evicted when the CPU is interrupted
 496         * to do something else.
 497         */
 498        local_irq_save(flags);
 499        aes_encrypt(ctx, dst, src);
 500        local_irq_restore(flags);
 501}
 502
 503static int __maybe_unused ctr_encrypt_sync(struct skcipher_request *req)
 504{
 505        if (!crypto_simd_usable())
 506                return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);
 507
 508        return ctr_encrypt(req);
 509}
 510
/*
 * XTS encryption, including ciphertext stealing for requests whose length
 * is not a multiple of the block size.
 *
 * Requests shorter than one block are rejected with -EINVAL.
 *
 * If the length is block-aligned, or the walk presents the whole request
 * in a single step, one pass over the walk handles everything and the
 * assembly routine is given the full byte count. Otherwise the walk is
 * restarted on a subrequest that excludes the last full block and the
 * partial tail, and those final bytes are processed in a separate call at
 * the end.
 */
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                /* whole blocks that precede the last full block + tail */
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                /* from here on, req refers to the stack subrequest */
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                /* no separate tail pass needed */
                tail = 0;
        }

        /* first is 1 only for the first call to the assembly routine */
        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                /* not the last step: only process whole blocks now */
                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        /*
         * tail is only non-zero if the subrequest path was taken, so req
         * is &subreq here and req->cryptlen is the amount already done.
         * Skip past it in both scatterlists.
         */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        /* last full block plus the partial block, in one go */
        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_enc, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
 582
/*
 * XTS decryption, including ciphertext stealing for requests whose length
 * is not a multiple of the block size.
 *
 * Requests shorter than one block are rejected with -EINVAL.
 *
 * The flow mirrors xts_encrypt(): a single pass when the length is
 * block-aligned or the walk presents everything in one step, otherwise a
 * pass over a shortened subrequest followed by a separate call for the
 * last full block plus the partial tail. Data is decrypted with key1's
 * decryption schedule, while key2 is always passed in its encryption form.
 */
static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                /* whole blocks that precede the last full block + tail */
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                /* from here on, req refers to the stack subrequest */
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                /* no separate tail pass needed */
                tail = 0;
        }

        /* first is 1 only for the first call to the assembly routine */
        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                /* not the last step: only process whole blocks now */
                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        /*
         * tail is only non-zero if the subrequest path was taken, so req
         * is &subreq here and req->cryptlen is the amount already done.
         * Skip past it in both scatterlists.
         */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        /* last full block plus the partial block, in one go */
        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;


        kernel_neon_begin();
        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_dec, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}
 655
/*
 * skcipher algorithms provided by this file.
 *
 * Entries whose name starts with "__" carry CRYPTO_ALG_INTERNAL.
 * NOTE(review): presumably they are wrapped by SIMD helpers registered
 * outside this view - confirm in the module init code.
 *
 * The ecb/cbc/ctr/xts entries are only compiled in under the same
 * condition that guards their MODULE_ALIAS_CRYPTO lines. The "}, {" just
 * before the #endif closes the last conditional entry, so the initializer
 * stays well-formed either way.
 */
static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
        .base = {
                .cra_name               = "__ecb(aes)",
                .cra_driver_name        = "__ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "__cbc(aes)",
                .cra_driver_name        = "__cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        /* CTR: block size 1; decryption uses the same routine as encryption */
        .base = {
                .cra_name               = "__ctr(aes)",
                .cra_driver_name        = "__ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        /*
         * Non-internal CTR variant at slightly lower priority. Its handler
         * falls back to a non-SIMD cipher when SIMD is not usable.
         */
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO - 1,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt_sync,
        .decrypt        = ctr_encrypt_sync,
}, {
        /* XTS: double-length key, two-block walk size */
        .base = {
                .cra_name               = "__xts(aes)",
                .cra_driver_name        = "__xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
}, {
#endif
        /* CBC with ciphertext stealing: always provided, two-block walk size */
        .base = {
                .cra_name               = "__cts(cbc(aes))",
                .cra_driver_name        = "__cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
}, {
        /* ESSIV-CBC: always provided, needs init/exit for its hash transform */
        .base = {
                .cra_name               = "__essiv(cbc(aes),sha256)",
                .cra_driver_name        = "__essiv-cbc-aes-sha256-" MODE,
                .cra_priority           = PRIO + 1,
                .cra_flags              = CRYPTO_ALG_INTERNAL,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = essiv_cbc_set_key,
        .encrypt        = essiv_cbc_encrypt,
        .decrypt        = essiv_cbc_decrypt,
        .init           = essiv_cbc_init_tfm,
        .exit           = essiv_cbc_exit_tfm,
} };
 775
 776static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 777                         unsigned int key_len)
 778{
 779        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 780
 781        return aes_expandkey(&ctx->key, in_key, key_len);
 782}
 783
 784static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
 785{
 786        u64 a = be64_to_cpu(x->a);
 787        u64 b = be64_to_cpu(x->b);
 788
 789        y->a = cpu_to_be64((a << 1) | (b >> 63));
 790        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
 791}
 792
 793static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
 794                       unsigned int key_len)
 795{
 796        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 797        be128 *consts = (be128 *)ctx->consts;
 798        int rounds = 6 + key_len / 4;
 799        int err;
 800
 801        err = cbcmac_setkey(tfm, in_key, key_len);
 802        if (err)
 803                return err;
 804
 805        /* encrypt the zero vector */
 806        kernel_neon_begin();
 807        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
 808                        rounds, 1);
 809        kernel_neon_end();
 810
 811        cmac_gf128_mul_by_x(consts, consts);
 812        cmac_gf128_mul_by_x(consts + 1, consts);
 813
 814        return 0;
 815}
 816
 817static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
 818                       unsigned int key_len)
 819{
 820        static u8 const ks[3][AES_BLOCK_SIZE] = {
 821                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
 822                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
 823                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
 824        };
 825
 826        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
 827        int rounds = 6 + key_len / 4;
 828        u8 key[AES_BLOCK_SIZE];
 829        int err;
 830
 831        err = cbcmac_setkey(tfm, in_key, key_len);
 832        if (err)
 833                return err;
 834
 835        kernel_neon_begin();
 836        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
 837        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
 838        kernel_neon_end();
 839
 840        return cbcmac_setkey(tfm, key, sizeof(key));
 841}
 842
 843static int mac_init(struct shash_desc *desc)
 844{
 845        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 846
 847        memset(ctx->dg, 0, AES_BLOCK_SIZE);
 848        ctx->len = 0;
 849
 850        return 0;
 851}
 852
 853static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
 854                          u8 dg[], int enc_before, int enc_after)
 855{
 856        int rounds = 6 + ctx->key_length / 4;
 857
 858        if (crypto_simd_usable()) {
 859                int rem;
 860
 861                do {
 862                        kernel_neon_begin();
 863                        rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
 864                                             dg, enc_before, enc_after);
 865                        kernel_neon_end();
 866                        in += (blocks - rem) * AES_BLOCK_SIZE;
 867                        blocks = rem;
 868                        enc_before = 0;
 869                } while (blocks);
 870        } else {
 871                if (enc_before)
 872                        aes_encrypt(ctx, dg, dg);
 873
 874                while (blocks--) {
 875                        crypto_xor(dg, in, AES_BLOCK_SIZE);
 876                        in += AES_BLOCK_SIZE;
 877
 878                        if (blocks || enc_after)
 879                                aes_encrypt(ctx, dg, dg);
 880                }
 881        }
 882}
 883
 884static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
 885{
 886        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 887        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 888
 889        while (len > 0) {
 890                unsigned int l;
 891
 892                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
 893                    (ctx->len + len) > AES_BLOCK_SIZE) {
 894
 895                        int blocks = len / AES_BLOCK_SIZE;
 896
 897                        len %= AES_BLOCK_SIZE;
 898
 899                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
 900                                      (ctx->len != 0), (len != 0));
 901
 902                        p += blocks * AES_BLOCK_SIZE;
 903
 904                        if (!len) {
 905                                ctx->len = AES_BLOCK_SIZE;
 906                                break;
 907                        }
 908                        ctx->len = 0;
 909                }
 910
 911                l = min(len, AES_BLOCK_SIZE - ctx->len);
 912
 913                if (l <= AES_BLOCK_SIZE) {
 914                        crypto_xor(ctx->dg + ctx->len, p, l);
 915                        ctx->len += l;
 916                        len -= l;
 917                        p += l;
 918                }
 919        }
 920
 921        return 0;
 922}
 923
 924static int cbcmac_final(struct shash_desc *desc, u8 *out)
 925{
 926        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 927        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 928
 929        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
 930
 931        memcpy(out, ctx->dg, AES_BLOCK_SIZE);
 932
 933        return 0;
 934}
 935
 936static int cmac_final(struct shash_desc *desc, u8 *out)
 937{
 938        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
 939        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
 940        u8 *consts = tctx->consts;
 941
 942        if (ctx->len != AES_BLOCK_SIZE) {
 943                ctx->dg[ctx->len] ^= 0x80;
 944                consts += AES_BLOCK_SIZE;
 945        }
 946
 947        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
 948
 949        memcpy(out, ctx->dg, AES_BLOCK_SIZE);
 950
 951        return 0;
 952}
 953
 954static struct shash_alg mac_algs[] = { {
 955        .base.cra_name          = "cmac(aes)",
 956        .base.cra_driver_name   = "cmac-aes-" MODE,
 957        .base.cra_priority      = PRIO,
 958        .base.cra_blocksize     = AES_BLOCK_SIZE,
 959        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
 960                                  2 * AES_BLOCK_SIZE,
 961        .base.cra_module        = THIS_MODULE,
 962
 963        .digestsize             = AES_BLOCK_SIZE,
 964        .init                   = mac_init,
 965        .update                 = mac_update,
 966        .final                  = cmac_final,
 967        .setkey                 = cmac_setkey,
 968        .descsize               = sizeof(struct mac_desc_ctx),
 969}, {
 970        .base.cra_name          = "xcbc(aes)",
 971        .base.cra_driver_name   = "xcbc-aes-" MODE,
 972        .base.cra_priority      = PRIO,
 973        .base.cra_blocksize     = AES_BLOCK_SIZE,
 974        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
 975                                  2 * AES_BLOCK_SIZE,
 976        .base.cra_module        = THIS_MODULE,
 977
 978        .digestsize             = AES_BLOCK_SIZE,
 979        .init                   = mac_init,
 980        .update                 = mac_update,
 981        .final                  = cmac_final,
 982        .setkey                 = xcbc_setkey,
 983        .descsize               = sizeof(struct mac_desc_ctx),
 984}, {
 985        .base.cra_name          = "cbcmac(aes)",
 986        .base.cra_driver_name   = "cbcmac-aes-" MODE,
 987        .base.cra_priority      = PRIO,
 988        .base.cra_blocksize     = 1,
 989        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
 990        .base.cra_module        = THIS_MODULE,
 991
 992        .digestsize             = AES_BLOCK_SIZE,
 993        .init                   = mac_init,
 994        .update                 = mac_update,
 995        .final                  = cbcmac_final,
 996        .setkey                 = cbcmac_setkey,
 997        .descsize               = sizeof(struct mac_desc_ctx),
 998} };
 999
1000static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
1001
1002static void aes_exit(void)
1003{
1004        int i;
1005
1006        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
1007                if (aes_simd_algs[i])
1008                        simd_skcipher_free(aes_simd_algs[i]);
1009
1010        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1011        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1012}
1013
1014static int __init aes_init(void)
1015{
1016        struct simd_skcipher_alg *simd;
1017        const char *basename;
1018        const char *algname;
1019        const char *drvname;
1020        int err;
1021        int i;
1022
1023        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1024        if (err)
1025                return err;
1026
1027        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
1028        if (err)
1029                goto unregister_ciphers;
1030
1031        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
1032                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
1033                        continue;
1034
1035                algname = aes_algs[i].base.cra_name + 2;
1036                drvname = aes_algs[i].base.cra_driver_name + 2;
1037                basename = aes_algs[i].base.cra_driver_name;
1038                simd = simd_skcipher_create_compat(algname, drvname, basename);
1039                err = PTR_ERR(simd);
1040                if (IS_ERR(simd))
1041                        goto unregister_simds;
1042
1043                aes_simd_algs[i] = simd;
1044        }
1045
1046        return 0;
1047
1048unregister_simds:
1049        aes_exit();
1050        return err;
1051unregister_ciphers:
1052        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
1053        return err;
1054}
1055
/*
 * Module entry points. The Crypto Extensions build is hooked up through
 * module_cpu_feature_match() on the AES feature, while the plain NEON build
 * uses module_init() and additionally exports a few of its NEON routines
 * (presumably for use by other modules - confirm against their users).
 */
#ifndef USE_V8_CRYPTO_EXTENSIONS
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#else
module_cpu_feature_match(AES, aes_init);
#endif
module_exit(aes_exit);
1066