linux/net/xfrm/xfrm_ipcomp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-or-later
   2/*
   3 * IP Payload Compression Protocol (IPComp) - RFC3173.
   4 *
   5 * Copyright (c) 2003 James Morris <jmorris@intercode.com.au>
   6 * Copyright (c) 2003-2008 Herbert Xu <herbert@gondor.apana.org.au>
   7 *
   8 * Todo:
   9 *   - Tunable compression parameters.
  10 *   - Compression stats.
  11 *   - Adaptive compression.
  12 */
  13
  14#include <linux/crypto.h>
  15#include <linux/err.h>
  16#include <linux/list.h>
  17#include <linux/module.h>
  18#include <linux/mutex.h>
  19#include <linux/percpu.h>
  20#include <linux/slab.h>
  21#include <linux/smp.h>
  22#include <linux/vmalloc.h>
  23#include <net/ip.h>
  24#include <net/ipcomp.h>
  25#include <net/xfrm.h>
  26
  27struct ipcomp_tfms {
  28        struct list_head list;
  29        struct crypto_comp * __percpu *tfms;
  30        int users;
  31};
  32
  33static DEFINE_MUTEX(ipcomp_resource_mutex);
  34static void * __percpu *ipcomp_scratches;
  35static int ipcomp_scratch_users;
  36static LIST_HEAD(ipcomp_tfms_list);
  37
  38static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb)
  39{
  40        struct ipcomp_data *ipcd = x->data;
  41        const int plen = skb->len;
  42        int dlen = IPCOMP_SCRATCH_SIZE;
  43        const u8 *start = skb->data;
  44        u8 *scratch = *this_cpu_ptr(ipcomp_scratches);
  45        struct crypto_comp *tfm = *this_cpu_ptr(ipcd->tfms);
  46        int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen);
  47        int len;
  48
  49        if (err)
  50                return err;
  51
  52        if (dlen < (plen + sizeof(struct ip_comp_hdr)))
  53                return -EINVAL;
  54
  55        len = dlen - plen;
  56        if (len > skb_tailroom(skb))
  57                len = skb_tailroom(skb);
  58
  59        __skb_put(skb, len);
  60
  61        len += plen;
  62        skb_copy_to_linear_data(skb, scratch, len);
  63
  64        while ((scratch += len, dlen -= len) > 0) {
  65                skb_frag_t *frag;
  66                struct page *page;
  67
  68                if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
  69                        return -EMSGSIZE;
  70
  71                frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags;
  72                page = alloc_page(GFP_ATOMIC);
  73
  74                if (!page)
  75                        return -ENOMEM;
  76
  77                __skb_frag_set_page(frag, page);
  78
  79                len = PAGE_SIZE;
  80                if (dlen < len)
  81                        len = dlen;
  82
  83                skb_frag_off_set(frag, 0);
  84                skb_frag_size_set(frag, len);
  85                memcpy(skb_frag_address(frag), scratch, len);
  86
  87                skb->truesize += len;
  88                skb->data_len += len;
  89                skb->len += len;
  90
  91                skb_shinfo(skb)->nr_frags++;
  92        }
  93
  94        return 0;
  95}
  96
  97int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb)
  98{
  99        int nexthdr;
 100        int err = -ENOMEM;
 101        struct ip_comp_hdr *ipch;
 102
 103        if (skb_linearize_cow(skb))
 104                goto out;
 105
 106        skb->ip_summed = CHECKSUM_NONE;
 107
 108        /* Remove ipcomp header and decompress original payload */
 109        ipch = (void *)skb->data;
 110        nexthdr = ipch->nexthdr;
 111
 112        skb->transport_header = skb->network_header + sizeof(*ipch);
 113        __skb_pull(skb, sizeof(*ipch));
 114        err = ipcomp_decompress(x, skb);
 115        if (err)
 116                goto out;
 117
 118        err = nexthdr;
 119
 120out:
 121        return err;
 122}
 123EXPORT_SYMBOL_GPL(ipcomp_input);
 124
 125static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb)
 126{
 127        struct ipcomp_data *ipcd = x->data;
 128        const int plen = skb->len;
 129        int dlen = IPCOMP_SCRATCH_SIZE;
 130        u8 *start = skb->data;
 131        struct crypto_comp *tfm;
 132        u8 *scratch;
 133        int err;
 134
 135        local_bh_disable();
 136        scratch = *this_cpu_ptr(ipcomp_scratches);
 137        tfm = *this_cpu_ptr(ipcd->tfms);
 138        err = crypto_comp_compress(tfm, start, plen, scratch, &dlen);
 139        if (err)
 140                goto out;
 141
 142        if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) {
 143                err = -EMSGSIZE;
 144                goto out;
 145        }
 146
 147        memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen);
 148        local_bh_enable();
 149
 150        pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr));
 151        return 0;
 152
 153out:
 154        local_bh_enable();
 155        return err;
 156}
 157
 158int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb)
 159{
 160        int err;
 161        struct ip_comp_hdr *ipch;
 162        struct ipcomp_data *ipcd = x->data;
 163
 164        if (skb->len < ipcd->threshold) {
 165                /* Don't bother compressing */
 166                goto out_ok;
 167        }
 168
 169        if (skb_linearize_cow(skb))
 170                goto out_ok;
 171
 172        err = ipcomp_compress(x, skb);
 173
 174        if (err) {
 175                goto out_ok;
 176        }
 177
 178        /* Install ipcomp header, convert into ipcomp datagram. */
 179        ipch = ip_comp_hdr(skb);
 180        ipch->nexthdr = *skb_mac_header(skb);
 181        ipch->flags = 0;
 182        ipch->cpi = htons((u16 )ntohl(x->id.spi));
 183        *skb_mac_header(skb) = IPPROTO_COMP;
 184out_ok:
 185        skb_push(skb, -skb_network_offset(skb));
 186        return 0;
 187}
 188EXPORT_SYMBOL_GPL(ipcomp_output);
 189
 190static void ipcomp_free_scratches(void)
 191{
 192        int i;
 193        void * __percpu *scratches;
 194
 195        if (--ipcomp_scratch_users)
 196                return;
 197
 198        scratches = ipcomp_scratches;
 199        if (!scratches)
 200                return;
 201
 202        for_each_possible_cpu(i)
 203                vfree(*per_cpu_ptr(scratches, i));
 204
 205        free_percpu(scratches);
 206}
 207
 208static void * __percpu *ipcomp_alloc_scratches(void)
 209{
 210        void * __percpu *scratches;
 211        int i;
 212
 213        if (ipcomp_scratch_users++)
 214                return ipcomp_scratches;
 215
 216        scratches = alloc_percpu(void *);
 217        if (!scratches)
 218                return NULL;
 219
 220        ipcomp_scratches = scratches;
 221
 222        for_each_possible_cpu(i) {
 223                void *scratch;
 224
 225                scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i));
 226                if (!scratch)
 227                        return NULL;
 228                *per_cpu_ptr(scratches, i) = scratch;
 229        }
 230
 231        return scratches;
 232}
 233
 234static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms)
 235{
 236        struct ipcomp_tfms *pos;
 237        int cpu;
 238
 239        list_for_each_entry(pos, &ipcomp_tfms_list, list) {
 240                if (pos->tfms == tfms)
 241                        break;
 242        }
 243
 244        WARN_ON(list_entry_is_head(pos, &ipcomp_tfms_list, list));
 245
 246        if (--pos->users)
 247                return;
 248
 249        list_del(&pos->list);
 250        kfree(pos);
 251
 252        if (!tfms)
 253                return;
 254
 255        for_each_possible_cpu(cpu) {
 256                struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu);
 257                crypto_free_comp(tfm);
 258        }
 259        free_percpu(tfms);
 260}
 261
 262static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name)
 263{
 264        struct ipcomp_tfms *pos;
 265        struct crypto_comp * __percpu *tfms;
 266        int cpu;
 267
 268
 269        list_for_each_entry(pos, &ipcomp_tfms_list, list) {
 270                struct crypto_comp *tfm;
 271
 272                /* This can be any valid CPU ID so we don't need locking. */
 273                tfm = this_cpu_read(*pos->tfms);
 274
 275                if (!strcmp(crypto_comp_name(tfm), alg_name)) {
 276                        pos->users++;
 277                        return pos->tfms;
 278                }
 279        }
 280
 281        pos = kmalloc(sizeof(*pos), GFP_KERNEL);
 282        if (!pos)
 283                return NULL;
 284
 285        pos->users = 1;
 286        INIT_LIST_HEAD(&pos->list);
 287        list_add(&pos->list, &ipcomp_tfms_list);
 288
 289        pos->tfms = tfms = alloc_percpu(struct crypto_comp *);
 290        if (!tfms)
 291                goto error;
 292
 293        for_each_possible_cpu(cpu) {
 294                struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0,
 295                                                            CRYPTO_ALG_ASYNC);
 296                if (IS_ERR(tfm))
 297                        goto error;
 298                *per_cpu_ptr(tfms, cpu) = tfm;
 299        }
 300
 301        return tfms;
 302
 303error:
 304        ipcomp_free_tfms(tfms);
 305        return NULL;
 306}
 307
 308static void ipcomp_free_data(struct ipcomp_data *ipcd)
 309{
 310        if (ipcd->tfms)
 311                ipcomp_free_tfms(ipcd->tfms);
 312        ipcomp_free_scratches();
 313}
 314
 315void ipcomp_destroy(struct xfrm_state *x)
 316{
 317        struct ipcomp_data *ipcd = x->data;
 318        if (!ipcd)
 319                return;
 320        xfrm_state_delete_tunnel(x);
 321        mutex_lock(&ipcomp_resource_mutex);
 322        ipcomp_free_data(ipcd);
 323        mutex_unlock(&ipcomp_resource_mutex);
 324        kfree(ipcd);
 325}
 326EXPORT_SYMBOL_GPL(ipcomp_destroy);
 327
 328int ipcomp_init_state(struct xfrm_state *x)
 329{
 330        int err;
 331        struct ipcomp_data *ipcd;
 332        struct xfrm_algo_desc *calg_desc;
 333
 334        err = -EINVAL;
 335        if (!x->calg)
 336                goto out;
 337
 338        if (x->encap)
 339                goto out;
 340
 341        err = -ENOMEM;
 342        ipcd = kzalloc(sizeof(*ipcd), GFP_KERNEL);
 343        if (!ipcd)
 344                goto out;
 345
 346        mutex_lock(&ipcomp_resource_mutex);
 347        if (!ipcomp_alloc_scratches())
 348                goto error;
 349
 350        ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name);
 351        if (!ipcd->tfms)
 352                goto error;
 353        mutex_unlock(&ipcomp_resource_mutex);
 354
 355        calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0);
 356        BUG_ON(!calg_desc);
 357        ipcd->threshold = calg_desc->uinfo.comp.threshold;
 358        x->data = ipcd;
 359        err = 0;
 360out:
 361        return err;
 362
 363error:
 364        ipcomp_free_data(ipcd);
 365        mutex_unlock(&ipcomp_resource_mutex);
 366        kfree(ipcd);
 367        goto out;
 368}
 369EXPORT_SYMBOL_GPL(ipcomp_init_state);
 370
 371MODULE_LICENSE("GPL");
 372MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173");
 373MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>");
 374