linux/include/crypto/sm3_base.h
<<
>>
Prefs
   1/* SPDX-License-Identifier: GPL-2.0-only */
   2/*
   3 * sm3_base.h - core logic for SM3 implementations
   4 *
   5 * Copyright (C) 2017 ARM Limited or its affiliates.
   6 * Written by Gilad Ben-Yossef <gilad@benyossef.com>
   7 */
   8
   9#ifndef _CRYPTO_SM3_BASE_H
  10#define _CRYPTO_SM3_BASE_H
  11
  12#include <crypto/internal/hash.h>
  13#include <crypto/sm3.h>
  14#include <linux/crypto.h>
  15#include <linux/module.h>
  16#include <linux/string.h>
  17#include <asm/unaligned.h>
  18
  19typedef void (sm3_block_fn)(struct sm3_state *sst, u8 const *src, int blocks);
  20
  21static inline int sm3_base_init(struct shash_desc *desc)
  22{
  23        struct sm3_state *sctx = shash_desc_ctx(desc);
  24
  25        sctx->state[0] = SM3_IVA;
  26        sctx->state[1] = SM3_IVB;
  27        sctx->state[2] = SM3_IVC;
  28        sctx->state[3] = SM3_IVD;
  29        sctx->state[4] = SM3_IVE;
  30        sctx->state[5] = SM3_IVF;
  31        sctx->state[6] = SM3_IVG;
  32        sctx->state[7] = SM3_IVH;
  33        sctx->count = 0;
  34
  35        return 0;
  36}
  37
  38static inline int sm3_base_do_update(struct shash_desc *desc,
  39                                      const u8 *data,
  40                                      unsigned int len,
  41                                      sm3_block_fn *block_fn)
  42{
  43        struct sm3_state *sctx = shash_desc_ctx(desc);
  44        unsigned int partial = sctx->count % SM3_BLOCK_SIZE;
  45
  46        sctx->count += len;
  47
  48        if (unlikely((partial + len) >= SM3_BLOCK_SIZE)) {
  49                int blocks;
  50
  51                if (partial) {
  52                        int p = SM3_BLOCK_SIZE - partial;
  53
  54                        memcpy(sctx->buffer + partial, data, p);
  55                        data += p;
  56                        len -= p;
  57
  58                        block_fn(sctx, sctx->buffer, 1);
  59                }
  60
  61                blocks = len / SM3_BLOCK_SIZE;
  62                len %= SM3_BLOCK_SIZE;
  63
  64                if (blocks) {
  65                        block_fn(sctx, data, blocks);
  66                        data += blocks * SM3_BLOCK_SIZE;
  67                }
  68                partial = 0;
  69        }
  70        if (len)
  71                memcpy(sctx->buffer + partial, data, len);
  72
  73        return 0;
  74}
  75
  76static inline int sm3_base_do_finalize(struct shash_desc *desc,
  77                                        sm3_block_fn *block_fn)
  78{
  79        const int bit_offset = SM3_BLOCK_SIZE - sizeof(__be64);
  80        struct sm3_state *sctx = shash_desc_ctx(desc);
  81        __be64 *bits = (__be64 *)(sctx->buffer + bit_offset);
  82        unsigned int partial = sctx->count % SM3_BLOCK_SIZE;
  83
  84        sctx->buffer[partial++] = 0x80;
  85        if (partial > bit_offset) {
  86                memset(sctx->buffer + partial, 0x0, SM3_BLOCK_SIZE - partial);
  87                partial = 0;
  88
  89                block_fn(sctx, sctx->buffer, 1);
  90        }
  91
  92        memset(sctx->buffer + partial, 0x0, bit_offset - partial);
  93        *bits = cpu_to_be64(sctx->count << 3);
  94        block_fn(sctx, sctx->buffer, 1);
  95
  96        return 0;
  97}
  98
  99static inline int sm3_base_finish(struct shash_desc *desc, u8 *out)
 100{
 101        struct sm3_state *sctx = shash_desc_ctx(desc);
 102        __be32 *digest = (__be32 *)out;
 103        int i;
 104
 105        for (i = 0; i < SM3_DIGEST_SIZE / sizeof(__be32); i++)
 106                put_unaligned_be32(sctx->state[i], digest++);
 107
 108        memzero_explicit(sctx, sizeof(*sctx));
 109        return 0;
 110}
 111
 112#endif /* _CRYPTO_SM3_BASE_H */
 113