/* linux/crypto/cbc.c */
/*
 * CBC: Cipher Block Chaining mode
 *
 * Copyright (c) 2006 Herbert Xu <herbert@gondor.apana.org.au>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 */
  12
  13#include <crypto/algapi.h>
  14#include <linux/err.h>
  15#include <linux/init.h>
  16#include <linux/kernel.h>
  17#include <linux/module.h>
  18#include <linux/scatterlist.h>
  19#include <linux/slab.h>
  20
  21struct crypto_cbc_ctx {
  22        struct crypto_cipher *child;
  23        void (*xor)(u8 *dst, const u8 *src, unsigned int bs);
  24};
  25
  26static int crypto_cbc_setkey(struct crypto_tfm *parent, const u8 *key,
  27                             unsigned int keylen)
  28{
  29        struct crypto_cbc_ctx *ctx = crypto_tfm_ctx(parent);
  30        struct crypto_cipher *child = ctx->child;
  31        int err;
  32
  33        crypto_cipher_clear_flags(child, CRYPTO_TFM_REQ_MASK);
  34        crypto_cipher_set_flags(child, crypto_tfm_get_flags(parent) &
  35                                       CRYPTO_TFM_REQ_MASK);
  36        err = crypto_cipher_setkey(child, key, keylen);
  37        crypto_tfm_set_flags(parent, crypto_cipher_get_flags(child) &
  38                                     CRYPTO_TFM_RES_MASK);
  39        return err;
  40}
  41
  42static int crypto_cbc_encrypt_segment(struct blkcipher_desc *desc,
  43                                      struct blkcipher_walk *walk,
  44                                      struct crypto_cipher *tfm,
  45                                      void (*xor)(u8 *, const u8 *,
  46                                                  unsigned int))
  47{
  48        void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
  49                crypto_cipher_alg(tfm)->cia_encrypt;
  50        int bsize = crypto_cipher_blocksize(tfm);
  51        unsigned int nbytes = walk->nbytes;
  52        u8 *src = walk->src.virt.addr;
  53        u8 *dst = walk->dst.virt.addr;
  54        u8 *iv = walk->iv;
  55
  56        do {
  57                xor(iv, src, bsize);
  58                fn(crypto_cipher_tfm(tfm), dst, iv);
  59                memcpy(iv, dst, bsize);
  60
  61                src += bsize;
  62                dst += bsize;
  63        } while ((nbytes -= bsize) >= bsize);
  64
  65        return nbytes;
  66}
  67
  68static int crypto_cbc_encrypt_inplace(struct blkcipher_desc *desc,
  69                                      struct blkcipher_walk *walk,
  70                                      struct crypto_cipher *tfm,
  71                                      void (*xor)(u8 *, const u8 *,
  72                                                  unsigned int))
  73{
  74        void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
  75                crypto_cipher_alg(tfm)->cia_encrypt;
  76        int bsize = crypto_cipher_blocksize(tfm);
  77        unsigned int nbytes = walk->nbytes;
  78        u8 *src = walk->src.virt.addr;
  79        u8 *iv = walk->iv;
  80
  81        do {
  82                xor(src, iv, bsize);
  83                fn(crypto_cipher_tfm(tfm), src, src);
  84                iv = src;
  85
  86                src += bsize;
  87        } while ((nbytes -= bsize) >= bsize);
  88
  89        memcpy(walk->iv, iv, bsize);
  90
  91        return nbytes;
  92}
  93
  94static int crypto_cbc_encrypt(struct blkcipher_desc *desc,
  95                              struct scatterlist *dst, struct scatterlist *src,
  96                              unsigned int nbytes)
  97{
  98        struct blkcipher_walk walk;
  99        struct crypto_blkcipher *tfm = desc->tfm;
 100        struct crypto_cbc_ctx *ctx = crypto_blkcipher_ctx(tfm);
 101        struct crypto_cipher *child = ctx->child;
 102        void (*xor)(u8 *, const u8 *, unsigned int bs) = ctx->xor;
 103        int err;
 104
 105        blkcipher_walk_init(&walk, dst, src, nbytes);
 106        err = blkcipher_walk_virt(desc, &walk);
 107
 108        while ((nbytes = walk.nbytes)) {
 109                if (walk.src.virt.addr == walk.dst.virt.addr)
 110                        nbytes = crypto_cbc_encrypt_inplace(desc, &walk, child,
 111                                                            xor);
 112                else
 113                        nbytes = crypto_cbc_encrypt_segment(desc, &walk, child,
 114                                                            xor);
 115                err = blkcipher_walk_done(desc, &walk, nbytes);
 116        }
 117
 118        return err;
 119}
 120
 121static int crypto_cbc_decrypt_segment(struct blkcipher_desc *desc,
 122                                      struct blkcipher_walk *walk,
 123                                      struct crypto_cipher *tfm,
 124                                      void (*xor)(u8 *, const u8 *,
 125                                                  unsigned int))
 126{
 127        void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
 128                crypto_cipher_alg(tfm)->cia_decrypt;
 129        int bsize = crypto_cipher_blocksize(tfm);
 130        unsigned int nbytes = walk->nbytes;
 131        u8 *src = walk->src.virt.addr;
 132        u8 *dst = walk->dst.virt.addr;
 133        u8 *iv = walk->iv;
 134
 135        do {
 136                fn(crypto_cipher_tfm(tfm), dst, src);
 137                xor(dst, iv, bsize);
 138                iv = src;
 139
 140                src += bsize;
 141                dst += bsize;
 142        } while ((nbytes -= bsize) >= bsize);
 143
 144        memcpy(walk->iv, iv, bsize);
 145
 146        return nbytes;
 147}
 148
 149static int crypto_cbc_decrypt_inplace(struct blkcipher_desc *desc,
 150                                      struct blkcipher_walk *walk,
 151                                      struct crypto_cipher *tfm,
 152                                      void (*xor)(u8 *, const u8 *,
 153                                                  unsigned int))
 154{
 155        void (*fn)(struct crypto_tfm *, u8 *, const u8 *) =
 156                crypto_cipher_alg(tfm)->cia_decrypt;
 157        int bsize = crypto_cipher_blocksize(tfm);
 158        unsigned long alignmask = crypto_cipher_alignmask(tfm);
 159        unsigned int nbytes = walk->nbytes;
 160        u8 *src = walk->src.virt.addr;
 161        u8 stack[bsize + alignmask];
 162        u8 *first_iv = (u8 *)ALIGN((unsigned long)stack, alignmask + 1);
 163
 164        memcpy(first_iv, walk->iv, bsize);
 165
 166        /* Start of the last block. */
 167        src += nbytes - nbytes % bsize - bsize;
 168        memcpy(walk->iv, src, bsize);
 169
 170        for (;;) {
 171                fn(crypto_cipher_tfm(tfm), src, src);
 172                if ((nbytes -= bsize) < bsize)
 173                        break;
 174                xor(src, src - bsize, bsize);
 175                src -= bsize;
 176        }
 177
 178        xor(src, first_iv, bsize);
 179
 180        return nbytes;
 181}
 182
 183static int crypto_cbc_decrypt(struct blkcipher_desc *desc,
 184                              struct scatterlist *dst, struct scatterlist *src,
 185                              unsigned int nbytes)
 186{
 187        struct blkcipher_walk walk;
 188        struct crypto_blkcipher *tfm = desc->tfm;
 189        struct crypto_cbc_ctx *ctx = crypto_blkcipher_ctx(tfm);
 190        struct crypto_cipher *child = ctx->child;
 191        void (*xor)(u8 *, const u8 *, unsigned int bs) = ctx->xor;
 192        int err;
 193
 194        blkcipher_walk_init(&walk, dst, src, nbytes);
 195        err = blkcipher_walk_virt(desc, &walk);
 196
 197        while ((nbytes = walk.nbytes)) {
 198                if (walk.src.virt.addr == walk.dst.virt.addr)
 199                        nbytes = crypto_cbc_decrypt_inplace(desc, &walk, child,
 200                                                            xor);
 201                else
 202                        nbytes = crypto_cbc_decrypt_segment(desc, &walk, child,
 203                                                            xor);
 204                err = blkcipher_walk_done(desc, &walk, nbytes);
 205        }
 206
 207        return err;
 208}
 209
 210static void xor_byte(u8 *a, const u8 *b, unsigned int bs)
 211{
 212        do {
 213                *a++ ^= *b++;
 214        } while (--bs);
 215}
 216
 217static void xor_quad(u8 *dst, const u8 *src, unsigned int bs)
 218{
 219        u32 *a = (u32 *)dst;
 220        u32 *b = (u32 *)src;
 221
 222        do {
 223                *a++ ^= *b++;
 224        } while ((bs -= 4));
 225}
 226
 227static void xor_64(u8 *a, const u8 *b, unsigned int bs)
 228{
 229        ((u32 *)a)[0] ^= ((u32 *)b)[0];
 230        ((u32 *)a)[1] ^= ((u32 *)b)[1];
 231}
 232
 233static void xor_128(u8 *a, const u8 *b, unsigned int bs)
 234{
 235        ((u32 *)a)[0] ^= ((u32 *)b)[0];
 236        ((u32 *)a)[1] ^= ((u32 *)b)[1];
 237        ((u32 *)a)[2] ^= ((u32 *)b)[2];
 238        ((u32 *)a)[3] ^= ((u32 *)b)[3];
 239}
 240
 241static int crypto_cbc_init_tfm(struct crypto_tfm *tfm)
 242{
 243        struct crypto_instance *inst = (void *)tfm->__crt_alg;
 244        struct crypto_spawn *spawn = crypto_instance_ctx(inst);
 245        struct crypto_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
 246        struct crypto_cipher *cipher;
 247
 248        switch (crypto_tfm_alg_blocksize(tfm)) {
 249        case 8:
 250                ctx->xor = xor_64;
 251                break;
 252
 253        case 16:
 254                ctx->xor = xor_128;
 255                break;
 256
 257        default:
 258                if (crypto_tfm_alg_blocksize(tfm) % 4)
 259                        ctx->xor = xor_byte;
 260                else
 261                        ctx->xor = xor_quad;
 262        }
 263
 264        cipher = crypto_spawn_cipher(spawn);
 265        if (IS_ERR(cipher))
 266                return PTR_ERR(cipher);
 267
 268        ctx->child = cipher;
 269        return 0;
 270}
 271
 272static void crypto_cbc_exit_tfm(struct crypto_tfm *tfm)
 273{
 274        struct crypto_cbc_ctx *ctx = crypto_tfm_ctx(tfm);
 275        crypto_free_cipher(ctx->child);
 276}
 277
 278static struct crypto_instance *crypto_cbc_alloc(struct rtattr **tb)
 279{
 280        struct crypto_instance *inst;
 281        struct crypto_alg *alg;
 282        int err;
 283
 284        err = crypto_check_attr_type(tb, CRYPTO_ALG_TYPE_BLKCIPHER);
 285        if (err)
 286                return ERR_PTR(err);
 287
 288        alg = crypto_get_attr_alg(tb, CRYPTO_ALG_TYPE_CIPHER,
 289                                  CRYPTO_ALG_TYPE_MASK);
 290        if (IS_ERR(alg))
 291                return ERR_PTR(PTR_ERR(alg));
 292
 293        inst = crypto_alloc_instance("cbc", alg);
 294        if (IS_ERR(inst))
 295                goto out_put_alg;
 296
 297        inst->alg.cra_flags = CRYPTO_ALG_TYPE_BLKCIPHER;
 298        inst->alg.cra_priority = alg->cra_priority;
 299        inst->alg.cra_blocksize = alg->cra_blocksize;
 300        inst->alg.cra_alignmask = alg->cra_alignmask;
 301        inst->alg.cra_type = &crypto_blkcipher_type;
 302
 303        if (!(alg->cra_blocksize % 4))
 304                inst->alg.cra_alignmask |= 3;
 305        inst->alg.cra_blkcipher.ivsize = alg->cra_blocksize;
 306        inst->alg.cra_blkcipher.min_keysize = alg->cra_cipher.cia_min_keysize;
 307        inst->alg.cra_blkcipher.max_keysize = alg->cra_cipher.cia_max_keysize;
 308
 309        inst->alg.cra_ctxsize = sizeof(struct crypto_cbc_ctx);
 310
 311        inst->alg.cra_init = crypto_cbc_init_tfm;
 312        inst->alg.cra_exit = crypto_cbc_exit_tfm;
 313
 314        inst->alg.cra_blkcipher.setkey = crypto_cbc_setkey;
 315        inst->alg.cra_blkcipher.encrypt = crypto_cbc_encrypt;
 316        inst->alg.cra_blkcipher.decrypt = crypto_cbc_decrypt;
 317
 318out_put_alg:
 319        crypto_mod_put(alg);
 320        return inst;
 321}
 322
 323static void crypto_cbc_free(struct crypto_instance *inst)
 324{
 325        crypto_drop_spawn(crypto_instance_ctx(inst));
 326        kfree(inst);
 327}
 328
 329static struct crypto_template crypto_cbc_tmpl = {
 330        .name = "cbc",
 331        .alloc = crypto_cbc_alloc,
 332        .free = crypto_cbc_free,
 333        .module = THIS_MODULE,
 334};
 335
 336static int __init crypto_cbc_module_init(void)
 337{
 338        return crypto_register_template(&crypto_cbc_tmpl);
 339}
 340
 341static void __exit crypto_cbc_module_exit(void)
 342{
 343        crypto_unregister_template(&crypto_cbc_tmpl);
 344}
 345
 346module_init(crypto_cbc_module_init);
 347module_exit(crypto_cbc_module_exit);
 348
 349MODULE_LICENSE("GPL");
 350MODULE_DESCRIPTION("CBC block cipher algorithm");
 351