linux/arch/arm/crypto/aes-ce-glue.c
/*
 * aes-ce-glue.c - wrapper code for ARMv8 AES
 *
 * Copyright (C) 2015 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/hwcap.h>
#include <asm/neon.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/cpufeature.h>
#include <linux/module.h>

MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

/* defined in aes-ce-core.S */
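/*
 * Note on the interface below, inferred from how the glue code in this
 * file invokes the core routines: 'rounds' is the number of AES rounds
 * for the expanded key, 'blocks' is the number of 16-byte blocks to
 * process, and 'iv'/'ctr' carry the per-request IV or counter.  A
 * negative 'blocks' value asks the CTR routine to emit keystream for a
 * final partial block.  For XTS, 'rk1' handles the data, 'rk2' encrypts
 * the tweak, and 'first' is nonzero only on the first call of a request.
 */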
asmlinkage u32 ce_aes_sub(u32 input);
asmlinkage void ce_aes_invert(void *dst, void *src);

asmlinkage void ce_aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks);
asmlinkage void ce_aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks);

asmlinkage void ce_aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 iv[]);
asmlinkage void ce_aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 iv[]);

asmlinkage void ce_aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                   int rounds, int blocks, u8 ctr[]);

asmlinkage void ce_aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
                                   int rounds, int blocks, u8 iv[],
                                   u8 const rk2[], int first);
asmlinkage void ce_aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
                                   int rounds, int blocks, u8 iv[],
                                   u8 const rk2[], int first);

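/*
 * A single 16-byte AES block, used below to copy and invert the expanded
 * round keys one block at a time.
 */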
struct aes_block {
        u8 b[AES_BLOCK_SIZE];
};

static int num_rounds(struct crypto_aes_ctx *ctx)
{
        /*
         * # of rounds specified by AES:
         * 128 bit key          10 rounds
         * 192 bit key          12 rounds
         * 256 bit key          14 rounds
         * => n byte key        => 6 + (n/4) rounds
         */
        return 6 + ctx->key_length / 4;
}

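/*
 * Expand the user-supplied key into the encryption and decryption round
 * key schedules, using ce_aes_sub() for the SubWord step.  For
 * illustration: a 16-byte (128-bit) key gives 6 + 16/4 = 10 rounds,
 * i.e. 11 round keys (44 u32 words) in key_enc, while a 32-byte
 * (256-bit) key gives 14 rounds and 15 round keys.
 */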
static int ce_aes_expandkey(struct crypto_aes_ctx *ctx, const u8 *in_key,
                            unsigned int key_len)
{
        /*
         * The AES key schedule round constants
         */
        static u8 const rcon[] = {
                0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36,
        };

        u32 kwords = key_len / sizeof(u32);
        struct aes_block *key_enc, *key_dec;
        int i, j;

        if (key_len != AES_KEYSIZE_128 &&
            key_len != AES_KEYSIZE_192 &&
            key_len != AES_KEYSIZE_256)
                return -EINVAL;

        memcpy(ctx->key_enc, in_key, key_len);
        ctx->key_length = key_len;

        kernel_neon_begin();
        for (i = 0; i < sizeof(rcon); i++) {
                u32 *rki = ctx->key_enc + (i * kwords);
                u32 *rko = rki + kwords;

#ifndef CONFIG_CPU_BIG_ENDIAN
                rko[0] = ror32(ce_aes_sub(rki[kwords - 1]), 8);
                rko[0] = rko[0] ^ rki[0] ^ rcon[i];
#else
                rko[0] = rol32(ce_aes_sub(rki[kwords - 1]), 8);
                rko[0] = rko[0] ^ rki[0] ^ (rcon[i] << 24);
#endif
                rko[1] = rko[0] ^ rki[1];
                rko[2] = rko[1] ^ rki[2];
                rko[3] = rko[2] ^ rki[3];

                if (key_len == AES_KEYSIZE_192) {
                        if (i >= 7)
                                break;
                        rko[4] = rko[3] ^ rki[4];
                        rko[5] = rko[4] ^ rki[5];
                } else if (key_len == AES_KEYSIZE_256) {
                        if (i >= 6)
                                break;
                        rko[4] = ce_aes_sub(rko[3]) ^ rki[4];
                        rko[5] = rko[4] ^ rki[5];
                        rko[6] = rko[5] ^ rki[6];
                        rko[7] = rko[6] ^ rki[7];
                }
        }

        /*
         * Generate the decryption keys for the Equivalent Inverse Cipher.
         * This involves reversing the order of the round keys, and applying
         * the Inverse Mix Columns transformation on all but the first and
         * the last one.
         */
        key_enc = (struct aes_block *)ctx->key_enc;
        key_dec = (struct aes_block *)ctx->key_dec;
        j = num_rounds(ctx);

        key_dec[0] = key_enc[j];
        for (i = 1, j--; j > 0; i++, j--)
                ce_aes_invert(key_dec + i, key_enc + j);
        key_dec[i] = key_enc[0];

        kernel_neon_end();
        return 0;
}

static int ce_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                         unsigned int key_len)
{
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = ce_aes_expandkey(ctx, in_key, key_len);
        if (!ret)
                return 0;

        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
        return -EINVAL;
}

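/*
 * XTS uses two independent AES keys: key1 encrypts/decrypts the data
 * blocks, while key2 is only ever used to encrypt the tweak (note that
 * xts_decrypt() below still passes key2.key_enc to the core routine).
 */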
struct crypto_aes_xts_ctx {
        struct crypto_aes_ctx key1;
        struct crypto_aes_ctx __aligned(8) key2;
};

static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
                       unsigned int key_len)
{
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int ret;

        ret = xts_verify_key(tfm, in_key, key_len);
        if (ret)
                return ret;

        ret = ce_aes_expandkey(&ctx->key1, in_key, key_len / 2);
        if (!ret)
                ret = ce_aes_expandkey(&ctx->key2, &in_key[key_len / 2],
                                       key_len / 2);
        if (!ret)
                return 0;

        crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
        return -EINVAL;
}

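/*
 * The mode handlers below all follow the same pattern: walk the request
 * in virtual address space, hand full AES blocks to the NEON core between
 * kernel_neon_begin()/kernel_neon_end(), and return any remainder smaller
 * than a block via skcipher_walk_done().  Only CTR consumes such a tail,
 * by XORing keystream into it.
 */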
static int ecb_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int blocks;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_enc, num_rounds(ctx), blocks);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int ecb_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int blocks;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_dec, num_rounds(ctx), blocks);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int blocks;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
                                   walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        unsigned int blocks;
        int err;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_dec, num_rounds(ctx), blocks,
                                   walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        kernel_neon_end();
        return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err, blocks;

        err = skcipher_walk_virt(&walk, req, true);

        kernel_neon_begin();
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                ce_aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                   (u8 *)ctx->key_enc, num_rounds(ctx), blocks,
                                   walk.iv);
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        if (walk.nbytes) {
                u8 __aligned(8) tail[AES_BLOCK_SIZE];
                unsigned int nbytes = walk.nbytes;
                u8 *tdst = walk.dst.virt.addr;
                u8 *tsrc = walk.src.virt.addr;

                /*
                 * A negative block count tells ce_aes_ctr_encrypt() to
                 * process a final, partial block; the keystream for it is
                 * returned in 'tail' and XORed into the destination below.
                 */
                blocks = -1;
 285
 286                ce_aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc,
 287                                   num_rounds(ctx), blocks, walk.iv);
 288                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
 289                err = skcipher_walk_done(&walk, 0);
 290        }
 291        kernel_neon_end();
 292
 293        return err;
 294}
 295
 296static int xts_encrypt(struct skcipher_request *req)
 297{
 298        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 299        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 300        int err, first, rounds = num_rounds(&ctx->key1);
 301        struct skcipher_walk walk;
 302        unsigned int blocks;
 303
 304        err = skcipher_walk_virt(&walk, req, true);
 305
 306        kernel_neon_begin();
 307        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
 308                ce_aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 309                                   (u8 *)ctx->key1.key_enc, rounds, blocks,
 310                                   walk.iv, (u8 *)ctx->key2.key_enc, first);
 311                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 312        }
 313        kernel_neon_end();
 314
 315        return err;
 316}
 317
 318static int xts_decrypt(struct skcipher_request *req)
 319{
 320        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 321        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 322        int err, first, rounds = num_rounds(&ctx->key1);
 323        struct skcipher_walk walk;
 324        unsigned int blocks;
 325
 326        err = skcipher_walk_virt(&walk, req, true);
 327
 328        kernel_neon_begin();
 329        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
 330                ce_aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 331                                   (u8 *)ctx->key1.key_dec, rounds, blocks,
 332                                   walk.iv, (u8 *)ctx->key2.key_enc, first);
 333                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
 334        }
 335        kernel_neon_end();
 336
 337        return err;
 338}
 339
 340static struct skcipher_alg aes_algs[] = { {
 341        .base = {
 342                .cra_name               = "__ecb(aes)",
 343                .cra_driver_name        = "__ecb-aes-ce",
 344                .cra_priority           = 300,
 345                .cra_flags              = CRYPTO_ALG_INTERNAL,
 346                .cra_blocksize          = AES_BLOCK_SIZE,
 347                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 348                .cra_module             = THIS_MODULE,
 349        },
 350        .min_keysize    = AES_MIN_KEY_SIZE,
 351        .max_keysize    = AES_MAX_KEY_SIZE,
 352        .setkey         = ce_aes_setkey,
 353        .encrypt        = ecb_encrypt,
 354        .decrypt        = ecb_decrypt,
 355}, {
 356        .base = {
 357                .cra_name               = "__cbc(aes)",
 358                .cra_driver_name        = "__cbc-aes-ce",
 359                .cra_priority           = 300,
 360                .cra_flags              = CRYPTO_ALG_INTERNAL,
 361                .cra_blocksize          = AES_BLOCK_SIZE,
 362                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 363                .cra_module             = THIS_MODULE,
 364        },
 365        .min_keysize    = AES_MIN_KEY_SIZE,
 366        .max_keysize    = AES_MAX_KEY_SIZE,
 367        .ivsize         = AES_BLOCK_SIZE,
 368        .setkey         = ce_aes_setkey,
 369        .encrypt        = cbc_encrypt,
 370        .decrypt        = cbc_decrypt,
 371}, {
 372        .base = {
 373                .cra_name               = "__ctr(aes)",
 374                .cra_driver_name        = "__ctr-aes-ce",
 375                .cra_priority           = 300,
 376                .cra_flags              = CRYPTO_ALG_INTERNAL,
 377                .cra_blocksize          = 1,
 378                .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
 379                .cra_module             = THIS_MODULE,
 380        },
 381        .min_keysize    = AES_MIN_KEY_SIZE,
 382        .max_keysize    = AES_MAX_KEY_SIZE,
 383        .ivsize         = AES_BLOCK_SIZE,
 384        .chunksize      = AES_BLOCK_SIZE,
 385        .setkey         = ce_aes_setkey,
 386        .encrypt        = ctr_encrypt,
 387        .decrypt        = ctr_encrypt,
 388}, {
 389        .base = {
 390                .cra_name               = "__xts(aes)",
 391                .cra_driver_name        = "__xts-aes-ce",
 392                .cra_priority           = 300,
 393                .cra_flags              = CRYPTO_ALG_INTERNAL,
 394                .cra_blocksize          = AES_BLOCK_SIZE,
 395                .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
 396                .cra_module             = THIS_MODULE,
 397        },
 398        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
 399        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
 400        .ivsize         = AES_BLOCK_SIZE,
 401        .setkey         = xts_set_key,
 402        .encrypt        = xts_encrypt,
 403        .decrypt        = xts_decrypt,
 404} };
 405
 406static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
 407
 408static void aes_exit(void)
 409{
 410        int i;
 411
 412        for (i = 0; i < ARRAY_SIZE(aes_simd_algs) && aes_simd_algs[i]; i++)
 413                simd_skcipher_free(aes_simd_algs[i]);
 414
 415        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
 416}
 417
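/*
 * The algorithms above are registered as internal-only ("__"-prefixed,
 * CRYPTO_ALG_INTERNAL) implementations.  aes_init() wraps each of them in
 * a simd skcipher that defers work via cryptd when the NEON unit cannot
 * be used directly; the "+ 2" below skips the "__" prefix to derive the
 * public algorithm and driver names.
 */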
static int __init aes_init(void)
{
        struct simd_skcipher_alg *simd;
        const char *basename;
        const char *algname;
        const char *drvname;
        int err;
        int i;

        err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
        if (err)
                return err;

        for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
                algname = aes_algs[i].base.cra_name + 2;
                drvname = aes_algs[i].base.cra_driver_name + 2;
                basename = aes_algs[i].base.cra_driver_name;
                simd = simd_skcipher_create_compat(algname, drvname, basename);
                err = PTR_ERR(simd);
                if (IS_ERR(simd))
                        goto unregister_simds;

                aes_simd_algs[i] = simd;
        }

        return 0;

unregister_simds:
        aes_exit();
        return err;
}

module_cpu_feature_match(AES, aes_init);
module_exit(aes_exit);