linux/crypto/zstd.c
<<
>>
Prefs
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Cryptographic API.
 *
 * Copyright (c) 2017-present, Facebook, Inc.
 */
   7#include <linux/crypto.h>
   8#include <linux/init.h>
   9#include <linux/interrupt.h>
  10#include <linux/mm.h>
  11#include <linux/module.h>
  12#include <linux/net.h>
  13#include <linux/vmalloc.h>
  14#include <linux/zstd.h>
  15#include <crypto/internal/scompress.h>
  16
  17
  18#define ZSTD_DEF_LEVEL  3
  19
  20struct zstd_ctx {
  21        ZSTD_CCtx *cctx;
  22        ZSTD_DCtx *dctx;
  23        void *cwksp;
  24        void *dwksp;
  25};
  26
  27static ZSTD_parameters zstd_params(void)
  28{
  29        return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
  30}
  31
  32static int zstd_comp_init(struct zstd_ctx *ctx)
  33{
  34        int ret = 0;
  35        const ZSTD_parameters params = zstd_params();
  36        const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
  37
  38        ctx->cwksp = vzalloc(wksp_size);
  39        if (!ctx->cwksp) {
  40                ret = -ENOMEM;
  41                goto out;
  42        }
  43
  44        ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
  45        if (!ctx->cctx) {
  46                ret = -EINVAL;
  47                goto out_free;
  48        }
  49out:
  50        return ret;
  51out_free:
  52        vfree(ctx->cwksp);
  53        goto out;
  54}
  55
  56static int zstd_decomp_init(struct zstd_ctx *ctx)
  57{
  58        int ret = 0;
  59        const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
  60
  61        ctx->dwksp = vzalloc(wksp_size);
  62        if (!ctx->dwksp) {
  63                ret = -ENOMEM;
  64                goto out;
  65        }
  66
  67        ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
  68        if (!ctx->dctx) {
  69                ret = -EINVAL;
  70                goto out_free;
  71        }
  72out:
  73        return ret;
  74out_free:
  75        vfree(ctx->dwksp);
  76        goto out;
  77}
  78
  79static void zstd_comp_exit(struct zstd_ctx *ctx)
  80{
  81        vfree(ctx->cwksp);
  82        ctx->cwksp = NULL;
  83        ctx->cctx = NULL;
  84}
  85
  86static void zstd_decomp_exit(struct zstd_ctx *ctx)
  87{
  88        vfree(ctx->dwksp);
  89        ctx->dwksp = NULL;
  90        ctx->dctx = NULL;
  91}
  92
  93static int __zstd_init(void *ctx)
  94{
  95        int ret;
  96
  97        ret = zstd_comp_init(ctx);
  98        if (ret)
  99                return ret;
 100        ret = zstd_decomp_init(ctx);
 101        if (ret)
 102                zstd_comp_exit(ctx);
 103        return ret;
 104}
 105
 106static void *zstd_alloc_ctx(struct crypto_scomp *tfm)
 107{
 108        int ret;
 109        struct zstd_ctx *ctx;
 110
 111        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
 112        if (!ctx)
 113                return ERR_PTR(-ENOMEM);
 114
 115        ret = __zstd_init(ctx);
 116        if (ret) {
 117                kfree(ctx);
 118                return ERR_PTR(ret);
 119        }
 120
 121        return ctx;
 122}
 123
 124static int zstd_init(struct crypto_tfm *tfm)
 125{
 126        struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
 127
 128        return __zstd_init(ctx);
 129}
 130
 131static void __zstd_exit(void *ctx)
 132{
 133        zstd_comp_exit(ctx);
 134        zstd_decomp_exit(ctx);
 135}
 136
 137static void zstd_free_ctx(struct crypto_scomp *tfm, void *ctx)
 138{
 139        __zstd_exit(ctx);
 140        kfree_sensitive(ctx);
 141}
 142
 143static void zstd_exit(struct crypto_tfm *tfm)
 144{
 145        struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
 146
 147        __zstd_exit(ctx);
 148}
 149
 150static int __zstd_compress(const u8 *src, unsigned int slen,
 151                           u8 *dst, unsigned int *dlen, void *ctx)
 152{
 153        size_t out_len;
 154        struct zstd_ctx *zctx = ctx;
 155        const ZSTD_parameters params = zstd_params();
 156
 157        out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params);
 158        if (ZSTD_isError(out_len))
 159                return -EINVAL;
 160        *dlen = out_len;
 161        return 0;
 162}
 163
 164static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
 165                         unsigned int slen, u8 *dst, unsigned int *dlen)
 166{
 167        struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
 168
 169        return __zstd_compress(src, slen, dst, dlen, ctx);
 170}
 171
 172static int zstd_scompress(struct crypto_scomp *tfm, const u8 *src,
 173                          unsigned int slen, u8 *dst, unsigned int *dlen,
 174                          void *ctx)
 175{
 176        return __zstd_compress(src, slen, dst, dlen, ctx);
 177}
 178
 179static int __zstd_decompress(const u8 *src, unsigned int slen,
 180                             u8 *dst, unsigned int *dlen, void *ctx)
 181{
 182        size_t out_len;
 183        struct zstd_ctx *zctx = ctx;
 184
 185        out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen);
 186        if (ZSTD_isError(out_len))
 187                return -EINVAL;
 188        *dlen = out_len;
 189        return 0;
 190}
 191
 192static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
 193                           unsigned int slen, u8 *dst, unsigned int *dlen)
 194{
 195        struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
 196
 197        return __zstd_decompress(src, slen, dst, dlen, ctx);
 198}
 199
 200static int zstd_sdecompress(struct crypto_scomp *tfm, const u8 *src,
 201                            unsigned int slen, u8 *dst, unsigned int *dlen,
 202                            void *ctx)
 203{
 204        return __zstd_decompress(src, slen, dst, dlen, ctx);
 205}
 206
 207static struct crypto_alg alg = {
 208        .cra_name               = "zstd",
 209        .cra_driver_name        = "zstd-generic",
 210        .cra_flags              = CRYPTO_ALG_TYPE_COMPRESS,
 211        .cra_ctxsize            = sizeof(struct zstd_ctx),
 212        .cra_module             = THIS_MODULE,
 213        .cra_init               = zstd_init,
 214        .cra_exit               = zstd_exit,
 215        .cra_u                  = { .compress = {
 216        .coa_compress           = zstd_compress,
 217        .coa_decompress         = zstd_decompress } }
 218};
 219
 220static struct scomp_alg scomp = {
 221        .alloc_ctx              = zstd_alloc_ctx,
 222        .free_ctx               = zstd_free_ctx,
 223        .compress               = zstd_scompress,
 224        .decompress             = zstd_sdecompress,
 225        .base                   = {
 226                .cra_name       = "zstd",
 227                .cra_driver_name = "zstd-scomp",
 228                .cra_module      = THIS_MODULE,
 229        }
 230};
 231
 232static int __init zstd_mod_init(void)
 233{
 234        int ret;
 235
 236        ret = crypto_register_alg(&alg);
 237        if (ret)
 238                return ret;
 239
 240        ret = crypto_register_scomp(&scomp);
 241        if (ret)
 242                crypto_unregister_alg(&alg);
 243
 244        return ret;
 245}
 246
 247static void __exit zstd_mod_fini(void)
 248{
 249        crypto_unregister_alg(&alg);
 250        crypto_unregister_scomp(&scomp);
 251}
 252
 253subsys_initcall(zstd_mod_init);
 254module_exit(zstd_mod_fini);
 255
 256MODULE_LICENSE("GPL");
 257MODULE_DESCRIPTION("Zstd Compression Algorithm");
 258MODULE_ALIAS_CRYPTO("zstd");
 259