linux/arch/arm64/crypto/aes-neonbs-glue.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

#include "aes-ctr-fallback.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

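/* bit sliced AES routines implemented in assembler */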
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);

struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        int     rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
        struct aesbs_ctx        key;
        u32                     enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
        struct aesbs_ctx        key;            /* must be first member */
        struct crypto_aes_ctx   fallback;
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
};

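/*
 * Expand the key using the generic AES key schedule, then convert it into the
 * bit sliced form used by the assembler routines.
 */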
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = crypto_aes_expand_key(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

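/*
 * Common handler for both ECB directions: while more data follows, the block
 * count is rounded down to a multiple of the walk stride so that any partial
 * batch is only handed to the bit sliced code in the final iteration.
 */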
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

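/*
 * CBC encryption cannot be parallelized, so keep an unconverted copy of the
 * key schedule in ctx->enc for the plain NEON encryption routine. Decryption
 * is parallelizable and uses the bit sliced key.
 */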
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = crypto_aes_expand_key(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                /* fall back to the non-bitsliced NEON implementation */
                kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

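/* CBC decryption can run in parallel, so it uses the bit sliced routine. */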
static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

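/*
 * For the synchronous ctr(aes) algorithm, keep the generic key schedule
 * around as well, so requests can still be handled when the NEON unit is
 * not usable.
 */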
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
                                 unsigned int key_len)
{
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err;

        err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
        kernel_neon_end();

        return 0;
}

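/*
 * If the request length is not a multiple of the block size, the assembler
 * routine returns the keystream for the final partial block in 'final', and
 * the tail of the request is xor'ed with it here.
 */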
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        u8 buf[AES_BLOCK_SIZE];
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
                u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

                if (walk.nbytes < walk.total) {
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);
                        final = NULL;
                }

                kernel_neon_begin();
                aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->rk, ctx->rounds, blocks, walk.iv, final);
                kernel_neon_end();

                if (final) {
                        u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
                        u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

                        crypto_xor_cpy(dst, src, final,
                                       walk.total % AES_BLOCK_SIZE);

                        err = skcipher_walk_done(&walk, 0);
                        break;
                }
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        return err;
}

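/*
 * The second half of the XTS key is expanded separately and kept in regular
 * form: it is only used to encrypt the IV into the initial tweak. The first
 * half is converted to bit sliced form by aesbs_setkey().
 */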
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
        if (err)
                return err;

        memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

        return aesbs_setkey(tfm, in_key, key_len);
}

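/*
 * Synchronous entry point for ctr(aes): fall back to a non-NEON
 * implementation when the SIMD unit cannot be used.
 */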
static int ctr_encrypt_sync(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

        if (!crypto_simd_usable())
                return aes_ctr_encrypt_fallback(&ctx->fallback, req);

        return ctr_encrypt(req);
}

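/*
 * Encrypt the IV with the tweak key using the plain NEON ECB routine to
 * produce the initial tweak, then process the data with the bit sliced XTS
 * code, which keeps the tweak up to date in walk.iv.
 */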
static int __xts_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        kernel_neon_begin();
        neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
        kernel_neon_end();

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
                   ctx->key.rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }
        return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, aesbs_xts_decrypt);
}

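/*
 * Algorithms marked CRYPTO_ALG_INTERNAL are only reachable through the simd
 * wrappers created in aes_init(). The plain ctr(aes) entry is registered
 * directly, since it can make progress even when the NEON unit is unusable.
 */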
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "__ecb(aes)",
        .base.cra_driver_name   = "__ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "__cbc(aes)",
        .base.cra_driver_name   = "__cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "__ctr(aes)",
        .base.cra_driver_name   = "__ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs",
        .base.cra_priority      = 250 - 1,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_ctr_setkey_sync,
        .encrypt                = ctr_encrypt_sync,
        .decrypt                = ctr_encrypt_sync,
}, {
        .base.cra_name          = "__xts(aes)",
        .base.cra_driver_name   = "__xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,
        .base.cra_flags         = CRYPTO_ALG_INTERNAL,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
                if (aes_simd_algs[i])
                        simd_skcipher_free(aes_simd_algs[i]);

        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

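/*
 * Register all algorithms, and create a simd wrapper (with the "__" prefix
 * dropped from the names) for each internal one.
 */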
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        if (!cpu_have_named_feature(ASIMD))
                return -ENODEV;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
                        continue;

                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }
        return 0;

unregister_simds:
        aes_exit();
        return err;
}

module_init(aes_init);
module_exit(aes_exit);