// SPDX-License-Identifier: GPL-2.0-only
/*
 * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/hwcap.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/sha2.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>

#include "aes-ce-setkey.h"

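/*
 * This translation unit is built in two flavours: with
 * USE_V8_CRYPTO_EXTENSIONS defined it becomes the ARMv8 Crypto Extensions
 * driver, and without it the NEON fallback for cores that lack the AES
 * instructions.  The aes_* names below resolve to the ce_* or neon_*
 * implementations accordingly, and the "ce" variant registers at a higher
 * priority (300 vs. 200) so it is preferred when both are available.
 */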
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE                    "ce"
#define PRIO                    300
#define aes_expandkey           ce_aes_expandkey
#define aes_ecb_encrypt         ce_aes_ecb_encrypt
#define aes_ecb_decrypt         ce_aes_ecb_decrypt
#define aes_cbc_encrypt         ce_aes_cbc_encrypt
#define aes_cbc_decrypt         ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         ce_aes_ctr_encrypt
#define aes_xts_encrypt         ce_aes_xts_encrypt
#define aes_xts_decrypt         ce_aes_xts_decrypt
#define aes_mac_update          ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#else
#define MODE                    "neon"
#define PRIO                    200
#define aes_ecb_encrypt         neon_aes_ecb_encrypt
#define aes_ecb_decrypt         neon_aes_ecb_decrypt
#define aes_cbc_encrypt         neon_aes_cbc_encrypt
#define aes_cbc_decrypt         neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt         neon_aes_ctr_encrypt
#define aes_xts_encrypt         neon_aes_xts_encrypt
#define aes_xts_decrypt         neon_aes_xts_decrypt
#define aes_mac_update          neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
#endif
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-modes.S */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 ctr[]);

asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
                                      int rounds, int blocks, u8 iv[],
                                      u32 const rk2[]);

asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
                              int blocks, u8 dg[], int enc_before,
                              int enc_after);

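/*
 * XTS uses two independent AES keys: key1 encrypts the data and key2
 * encrypts the IV to produce the initial tweak.  xts_set_key() below
 * expands each half of the user-supplied key separately.
 */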
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

struct crypto_aes_essiv_cbc_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
        struct crypto_shash *hash;
};

struct mac_tfm_ctx {
        struct crypto_aes_ctx key;
        u8 __aligned(8) consts[];
};

struct mac_desc_ctx {
        unsigned int len;
        u8 dg[AES_BLOCK_SIZE];
};

static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                               unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);

        return aes_expandkey(ctx, in_key, key_len);
}

static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
                                      const u8 *in_key, unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                    key_len / 2);
        return ret;
}

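/*
 * ESSIV (encrypted salt-sector IV): the IV generator key is not supplied
 * by the user but derived by hashing the data key with SHA-256, which ties
 * the two keys together without enlarging the user-visible key.
 */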
static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
                                            const u8 *in_key,
                                            unsigned int key_len)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        u8 digest[SHA256_DIGEST_SIZE];
        int ret;

        ret = aes_expandkey(&ctx->key1, in_key, key_len);
        if (ret)
                return ret;

        crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);

        return aes_expandkey(&ctx->key2, digest, sizeof(digest));
}

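/*
 * AES uses 10/12/14 rounds for 128/192/256-bit keys.  With key_length in
 * bytes (16/24/32), that is exactly 6 + key_length / 4, which is how all
 * the handlers below derive their round count.
 */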
static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

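/*
 * The CBC walk helpers take a caller-initialized walk so that the plain
 * cbc(aes) handlers and the cts(cbc(aes)) handlers further down can share
 * the bulk CBC processing.
 */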
static int cbc_encrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_enc, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_encrypt_walk(req, &walk);
}

static int cbc_decrypt_walk(struct skcipher_request *req,
                            struct skcipher_walk *walk)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err = 0, rounds = 6 + ctx->key_length / 4;
        unsigned int blocks;

        while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
                                ctx->key_dec, rounds, blocks, walk->iv);
                kernel_neon_end();
                err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
{
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;
        return cbc_decrypt_walk(req, &walk);
}

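/*
 * cts(cbc(aes)): everything up to the final two (possibly partial) blocks
 * is run through ordinary CBC via a subrequest; the remaining
 * AES_BLOCK_SIZE + tail bytes are then handed to the aes_cbc_cts_* asm
 * routines, which perform the ciphertext stealing on the last two blocks.
 */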
static int cts_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_encrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_enc, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
        struct scatterlist *src = req->src, *dst = req->dst;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct skcipher_walk walk;

        skcipher_request_set_tfm(&subreq, tfm);
        skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
                                      NULL, NULL);

        if (req->cryptlen <= AES_BLOCK_SIZE) {
                if (req->cryptlen < AES_BLOCK_SIZE)
                        return -EINVAL;
                cbc_blocks = 1;
        }

        if (cbc_blocks > 0) {
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           cbc_blocks * AES_BLOCK_SIZE,
                                           req->iv);

                err = skcipher_walk_virt(&walk, &subreq, false) ?:
                      cbc_decrypt_walk(&subreq, &walk);
                if (err)
                        return err;

                if (req->cryptlen == AES_BLOCK_SIZE)
                        return 0;

                dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
                if (req->dst != req->src)
                        dst = scatterwalk_ffwd(sg_dst, req->dst,
                                               subreq.cryptlen);
        }

        /* handle ciphertext stealing */
        skcipher_request_set_crypt(&subreq, src, dst,
                                   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                            ctx->key_dec, rounds, walk.nbytes, walk.iv);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        ctx->hash = crypto_alloc_shash("sha256", 0, 0);

        return PTR_ERR_OR_ZERO(ctx->hash);
}

static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
{
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

        crypto_free_shash(ctx->hash);
}

static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_enc, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_encrypt_walk(req, &walk);
}

static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key1.key_length / 4;
        struct skcipher_walk walk;
        unsigned int blocks;

        err = skcipher_walk_virt(&walk, req, false);

        blocks = walk.nbytes / AES_BLOCK_SIZE;
        if (blocks) {
                kernel_neon_begin();
                aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                      ctx->key1.key_dec, rounds, blocks,
                                      req->iv, ctx->key2.key_enc);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err ?: cbc_decrypt_walk(req, &walk);
}

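/*
 * The CTR asm routine takes a byte count.  A final partial block is
 * bounced through a stack buffer, right-aligned so the asm code can still
 * load and store a full block; only the tail bytes are copied back to the
 * destination afterwards.
 */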
static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, rounds = 6 + ctx->key_length / 4;
        struct skcipher_walk walk;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                const u8 *src = walk.src.virt.addr;
                unsigned int nbytes = walk.nbytes;
                u8 *dst = walk.dst.virt.addr;
                u8 buf[AES_BLOCK_SIZE];

                if (unlikely(nbytes < AES_BLOCK_SIZE))
                        src = dst = memcpy(buf + sizeof(buf) - nbytes,
                                           src, nbytes);
                else if (nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
                                walk.iv);
                kernel_neon_end();

                if (unlikely(nbytes < AES_BLOCK_SIZE))
                        memcpy(walk.dst.virt.addr,
                               buf + sizeof(buf) - nbytes, nbytes);

                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        return err;
}

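/*
 * For XTS, a message whose length is not a multiple of the block size
 * ends in ciphertext stealing.  If the walk cannot deliver the whole
 * message in one go, the request is split: a subrequest covers all but
 * the last two blocks, and the final AES_BLOCK_SIZE + tail bytes are
 * processed in a separate pass so the asm code sees them contiguously.
 */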
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_enc, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
        int tail = req->cryptlen % AES_BLOCK_SIZE;
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        err = skcipher_walk_virt(&walk, req, false);

        if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_walk_abort(&walk);

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
                err = skcipher_walk_virt(&walk, req, false);
        } else {
                tail = 0;
        }

        for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
                int nbytes = walk.nbytes;

                if (walk.nbytes < walk.total)
                        nbytes &= ~(AES_BLOCK_SIZE - 1);

                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }

        if (err || likely(!tail))
                return err;

        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, &subreq, false);
        if (err)
                return err;

        kernel_neon_begin();
        aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                        ctx->key1.key_dec, rounds, walk.nbytes,
                        ctx->key2.key_enc, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static struct skcipher_alg aes_algs[] = { {
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
        .base = {
                .cra_name               = "ecb(aes)",
                .cra_driver_name        = "ecb-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ecb_encrypt,
        .decrypt        = ecb_decrypt,
}, {
        .base = {
                .cra_name               = "cbc(aes)",
                .cra_driver_name        = "cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cbc_encrypt,
        .decrypt        = cbc_decrypt,
}, {
        .base = {
                .cra_name               = "ctr(aes)",
                .cra_driver_name        = "ctr-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .chunksize      = AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = ctr_encrypt,
        .decrypt        = ctr_encrypt,
}, {
        .base = {
                .cra_name               = "xts(aes)",
                .cra_driver_name        = "xts-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
}, {
#endif
        .base = {
                .cra_name               = "cts(cbc(aes))",
                .cra_driver_name        = "cts-cbc-aes-" MODE,
                .cra_priority           = PRIO,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = skcipher_aes_setkey,
        .encrypt        = cts_cbc_encrypt,
        .decrypt        = cts_cbc_decrypt,
}, {
        .base = {
                .cra_name               = "essiv(cbc(aes),sha256)",
                .cra_driver_name        = "essiv-cbc-aes-sha256-" MODE,
                .cra_priority           = PRIO + 1,
                .cra_blocksize          = AES_BLOCK_SIZE,
                .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
                .cra_module             = THIS_MODULE,
        },
        .min_keysize    = AES_MIN_KEY_SIZE,
        .max_keysize    = AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
        .setkey         = essiv_cbc_set_key,
        .encrypt        = essiv_cbc_encrypt,
        .decrypt        = essiv_cbc_decrypt,
        .init           = essiv_cbc_init_tfm,
        .exit           = essiv_cbc_exit_tfm,
} };

static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

        return aes_expandkey(&ctx->key, in_key, key_len);
}

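/*
 * Doubling in GF(2^128): shift left by one bit and, if the top bit was
 * set, reduce modulo the field polynomial x^128 + x^7 + x^2 + x + 1,
 * i.e. XOR in 0x87.  This is the subkey derivation step from the CMAC
 * specification (RFC 4493).
 */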
static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
{
        u64 a = be64_to_cpu(x->a);
        u64 b = be64_to_cpu(x->b);

        y->a = cpu_to_be64((a << 1) | (b >> 63));
        y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
}

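/*
 * CMAC subkeys: L = AES-ENC(K, 0^128), K1 = double(L), K2 = double(K1).
 * Both are stored in ctx->consts and XORed into the final block by
 * cmac_final(), depending on whether that block is complete.
 */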
static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
        int rounds = 6 + key_len / 4;
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        /* encrypt the zero vector */
        kernel_neon_begin();
        aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
                        rounds, 1);
        kernel_neon_end();

        cmac_gf128_mul_by_x(consts, consts);
        cmac_gf128_mul_by_x(consts + 1, consts);

        return 0;
}

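/*
 * XCBC (RFC 3566) derives three keys by encrypting the constant blocks
 * 0x01.., 0x02.. and 0x03.. under the user key: K1 replaces the CBC-MAC
 * key, while K2/K3 land in ctx->consts for use by cmac_final().
 */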
static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        static u8 const ks[3][AES_BLOCK_SIZE] = {
                { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
                { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
        };

        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;

        err = cbcmac_setkey(tfm, in_key, key_len);
        if (err)
                return err;

        kernel_neon_begin();
        aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
        aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();

        return cbcmac_setkey(tfm, key, sizeof(key));
}

static int mac_init(struct shash_desc *desc)
{
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        memset(ctx->dg, 0, AES_BLOCK_SIZE);
        ctx->len = 0;

        return 0;
}

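/*
 * When the NEON is usable, the asm routine consumes as many blocks as it
 * can before yielding and returns the number of blocks it did not
 * process, so the loop resumes where it left off.  Otherwise we fall back
 * to a scalar CBC-MAC built on the generic aes_encrypt().
 */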
static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
                          u8 dg[], int enc_before, int enc_after)
{
        int rounds = 6 + ctx->key_length / 4;

        if (crypto_simd_usable()) {
                int rem;

                do {
                        kernel_neon_begin();
                        rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
                                             dg, enc_before, enc_after);
                        kernel_neon_end();
                        in += (blocks - rem) * AES_BLOCK_SIZE;
                        blocks = rem;
                        enc_before = 0;
                } while (blocks);
        } else {
                if (enc_before)
                        aes_encrypt(ctx, dg, dg);

                while (blocks--) {
                        crypto_xor(dg, in, AES_BLOCK_SIZE);
                        in += AES_BLOCK_SIZE;

                        if (blocks || enc_after)
                                aes_encrypt(ctx, dg, dg);
                }
        }
}

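/*
 * mac_update() accumulates partial blocks by XORing them into ctx->dg in
 * place; a full final block is left unencrypted (ctx->len set to
 * AES_BLOCK_SIZE) until we know whether more data follows, since CMAC
 * must treat the last block specially.
 */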
static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        while (len > 0) {
                unsigned int l;

                if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
                    (ctx->len + len) > AES_BLOCK_SIZE) {

                        int blocks = len / AES_BLOCK_SIZE;

                        len %= AES_BLOCK_SIZE;

                        mac_do_update(&tctx->key, p, blocks, ctx->dg,
                                      (ctx->len != 0), (len != 0));

                        p += blocks * AES_BLOCK_SIZE;

                        if (!len) {
                                ctx->len = AES_BLOCK_SIZE;
                                break;
                        }
                        ctx->len = 0;
                }

                l = min(len, AES_BLOCK_SIZE - ctx->len);

                if (l <= AES_BLOCK_SIZE) {
                        crypto_xor(ctx->dg + ctx->len, p, l);
                        ctx->len += l;
                        len -= l;
                        p += l;
                }
        }

        return 0;
}

static int cbcmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);

        mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

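/*
 * cmac_final(): a complete final block is XORed with K1 (the first block
 * of consts); an incomplete one is padded with a single 0x80 bit followed
 * by zeroes and XORed with K2 (consts + AES_BLOCK_SIZE), per RFC 4493.
 */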
static int cmac_final(struct shash_desc *desc, u8 *out)
{
        struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
        struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
        u8 *consts = tctx->consts;

        if (ctx->len != AES_BLOCK_SIZE) {
                ctx->dg[ctx->len] ^= 0x80;
                consts += AES_BLOCK_SIZE;
        }

        mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);

        memcpy(out, ctx->dg, AES_BLOCK_SIZE);

        return 0;
}

static struct shash_alg mac_algs[] = { {
        .base.cra_name          = "cmac(aes)",
        .base.cra_driver_name   = "cmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = cmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "xcbc(aes)",
        .base.cra_driver_name   = "xcbc-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
                                  2 * AES_BLOCK_SIZE,
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cmac_final,
        .setkey                 = xcbc_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
}, {
        .base.cra_name          = "cbcmac(aes)",
        .base.cra_driver_name   = "cbcmac-aes-" MODE,
        .base.cra_priority      = PRIO,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
        .base.cra_module        = THIS_MODULE,

        .digestsize             = AES_BLOCK_SIZE,
        .init                   = mac_init,
        .update                 = mac_update,
        .final                  = cbcmac_final,
        .setkey                 = cbcmac_setkey,
        .descsize               = sizeof(struct mac_desc_ctx),
} };

static void aes_exit(void)
{
        crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
        int err;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
        if (err)
                goto unregister_ciphers;

        return 0;

unregister_ciphers:
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        return err;
}

#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);