linux/fs/squashfs/zstd_wrapper.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-or-later
   2/*
   3 * Squashfs - a compressed read only filesystem for Linux
   4 *
   5 * Copyright (c) 2016-present, Facebook, Inc.
   6 * All rights reserved.
   7 *
   8 * zstd_wrapper.c
   9 */
  10
  11#include <linux/mutex.h>
  12#include <linux/bio.h>
  13#include <linux/slab.h>
  14#include <linux/zstd.h>
  15#include <linux/vmalloc.h>
  16
  17#include "squashfs_fs.h"
  18#include "squashfs_fs_sb.h"
  19#include "squashfs.h"
  20#include "decompressor.h"
  21#include "page_actor.h"
  22
  23struct workspace {
  24        void *mem;
  25        size_t mem_size;
  26        size_t window_size;
  27};
  28
  29static void *zstd_init(struct squashfs_sb_info *msblk, void *buff)
  30{
  31        struct workspace *wksp = kmalloc(sizeof(*wksp), GFP_KERNEL);
  32
  33        if (wksp == NULL)
  34                goto failed;
  35        wksp->window_size = max_t(size_t,
  36                        msblk->block_size, SQUASHFS_METADATA_SIZE);
  37        wksp->mem_size = ZSTD_DStreamWorkspaceBound(wksp->window_size);
  38        wksp->mem = vmalloc(wksp->mem_size);
  39        if (wksp->mem == NULL)
  40                goto failed;
  41
  42        return wksp;
  43
  44failed:
  45        ERROR("Failed to allocate zstd workspace\n");
  46        kfree(wksp);
  47        return ERR_PTR(-ENOMEM);
  48}
  49
  50
  51static void zstd_free(void *strm)
  52{
  53        struct workspace *wksp = strm;
  54
  55        if (wksp)
  56                vfree(wksp->mem);
  57        kfree(wksp);
  58}
  59
  60
  61static int zstd_uncompress(struct squashfs_sb_info *msblk, void *strm,
  62        struct bio *bio, int offset, int length,
  63        struct squashfs_page_actor *output)
  64{
  65        struct workspace *wksp = strm;
  66        ZSTD_DStream *stream;
  67        size_t total_out = 0;
  68        int error = 0;
  69        ZSTD_inBuffer in_buf = { NULL, 0, 0 };
  70        ZSTD_outBuffer out_buf = { NULL, 0, 0 };
  71        struct bvec_iter_all iter_all = {};
  72        struct bio_vec *bvec = bvec_init_iter_all(&iter_all);
  73
  74        stream = ZSTD_initDStream(wksp->window_size, wksp->mem, wksp->mem_size);
  75
  76        if (!stream) {
  77                ERROR("Failed to initialize zstd decompressor\n");
  78                return -EIO;
  79        }
  80
  81        out_buf.size = PAGE_SIZE;
  82        out_buf.dst = squashfs_first_page(output);
  83
  84        for (;;) {
  85                size_t zstd_err;
  86
  87                if (in_buf.pos == in_buf.size) {
  88                        const void *data;
  89                        int avail;
  90
  91                        if (!bio_next_segment(bio, &iter_all)) {
  92                                error = -EIO;
  93                                break;
  94                        }
  95
  96                        avail = min(length, ((int)bvec->bv_len) - offset);
  97                        data = bvec_virt(bvec);
  98                        length -= avail;
  99                        in_buf.src = data + offset;
 100                        in_buf.size = avail;
 101                        in_buf.pos = 0;
 102                        offset = 0;
 103                }
 104
 105                if (out_buf.pos == out_buf.size) {
 106                        out_buf.dst = squashfs_next_page(output);
 107                        if (out_buf.dst == NULL) {
 108                                /* Shouldn't run out of pages
 109                                 * before stream is done.
 110                                 */
 111                                error = -EIO;
 112                                break;
 113                        }
 114                        out_buf.pos = 0;
 115                        out_buf.size = PAGE_SIZE;
 116                }
 117
 118                total_out -= out_buf.pos;
 119                zstd_err = ZSTD_decompressStream(stream, &out_buf, &in_buf);
 120                total_out += out_buf.pos; /* add the additional data produced */
 121                if (zstd_err == 0)
 122                        break;
 123
 124                if (ZSTD_isError(zstd_err)) {
 125                        ERROR("zstd decompression error: %d\n",
 126                                        (int)ZSTD_getErrorCode(zstd_err));
 127                        error = -EIO;
 128                        break;
 129                }
 130        }
 131
 132        squashfs_finish_page(output);
 133
 134        return error ? error : total_out;
 135}
 136
 137const struct squashfs_decompressor squashfs_zstd_comp_ops = {
 138        .init = zstd_init,
 139        .free = zstd_free,
 140        .decompress = zstd_uncompress,
 141        .id = ZSTD_COMPRESSION,
 142        .name = "zstd",
 143        .supported = 1
 144};
 145