linux/net/xfrm/xfrm_ipcomp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-or-later
   2/*
   3 * IP Payload Compression Protocol (IPComp) - RFC3173.
   4 *
   5 * Copyright (c) 2003 James Morris <jmorris@intercode.com.au>
   6 * Copyright (c) 2003-2008 Herbert Xu <herbert@gondor.apana.org.au>
   7 *
   8 * Todo:
   9 *   - Tunable compression parameters.
  10 *   - Compression stats.
  11 *   - Adaptive compression.
  12 */
  13
  14#include <linux/crypto.h>
  15#include <linux/err.h>
  16#include <linux/list.h>
  17#include <linux/module.h>
  18#include <linux/mutex.h>
  19#include <linux/percpu.h>
  20#include <linux/slab.h>
  21#include <linux/smp.h>
  22#include <linux/vmalloc.h>
  23#include <net/ip.h>
  24#include <net/ipcomp.h>
  25#include <net/xfrm.h>
  26
/*
 * One entry per compression algorithm currently in use.  States whose
 * algorithm name matches (compared via crypto_comp_name()) share a single
 * set of per-CPU transforms; @users counts the states holding a reference.
 */
struct ipcomp_tfms {
	struct list_head list;			/* link in ipcomp_tfms_list */
	struct crypto_comp * __percpu *tfms;	/* one transform per possible CPU */
	int users;				/* reference count */
};

/* Held around every allocation/release of the scratch buffers and tfms list. */
static DEFINE_MUTEX(ipcomp_resource_mutex);
/* Per-CPU scratch buffers (IPCOMP_SCRATCH_SIZE bytes each) shared by all states. */
static void * __percpu *ipcomp_scratches;
/* Number of outstanding ipcomp_alloc_scratches() references. */
static int ipcomp_scratch_users;
/* All struct ipcomp_tfms entries; only touched under ipcomp_resource_mutex. */
static LIST_HEAD(ipcomp_tfms_list);
  37
  38static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb)
  39{
  40        struct ipcomp_data *ipcd = x->data;
  41        const int plen = skb->len;
  42        int dlen = IPCOMP_SCRATCH_SIZE;
  43        const u8 *start = skb->data;
  44        const int cpu = get_cpu();
  45        u8 *scratch = *per_cpu_ptr(ipcomp_scratches, cpu);
  46        struct crypto_comp *tfm = *per_cpu_ptr(ipcd->tfms, cpu);
  47        int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen);
  48        int len;
  49
  50        if (err)
  51                goto out;
  52
  53        if (dlen < (plen + sizeof(struct ip_comp_hdr))) {
  54                err = -EINVAL;
  55                goto out;
  56        }
  57
  58        len = dlen - plen;
  59        if (len > skb_tailroom(skb))
  60                len = skb_tailroom(skb);
  61
  62        __skb_put(skb, len);
  63
  64        len += plen;
  65        skb_copy_to_linear_data(skb, scratch, len);
  66
  67        while ((scratch += len, dlen -= len) > 0) {
  68                skb_frag_t *frag;
  69                struct page *page;
  70
  71                err = -EMSGSIZE;
  72                if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS))
  73                        goto out;
  74
  75                frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags;
  76                page = alloc_page(GFP_ATOMIC);
  77
  78                err = -ENOMEM;
  79                if (!page)
  80                        goto out;
  81
  82                __skb_frag_set_page(frag, page);
  83
  84                len = PAGE_SIZE;
  85                if (dlen < len)
  86                        len = dlen;
  87
  88                frag->page_offset = 0;
  89                skb_frag_size_set(frag, len);
  90                memcpy(skb_frag_address(frag), scratch, len);
  91
  92                skb->truesize += len;
  93                skb->data_len += len;
  94                skb->len += len;
  95
  96                skb_shinfo(skb)->nr_frags++;
  97        }
  98
  99        err = 0;
 100
 101out:
 102        put_cpu();
 103        return err;
 104}
 105
 106int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb)
 107{
 108        int nexthdr;
 109        int err = -ENOMEM;
 110        struct ip_comp_hdr *ipch;
 111
 112        if (skb_linearize_cow(skb))
 113                goto out;
 114
 115        skb->ip_summed = CHECKSUM_NONE;
 116
 117        /* Remove ipcomp header and decompress original payload */
 118        ipch = (void *)skb->data;
 119        nexthdr = ipch->nexthdr;
 120
 121        skb->transport_header = skb->network_header + sizeof(*ipch);
 122        __skb_pull(skb, sizeof(*ipch));
 123        err = ipcomp_decompress(x, skb);
 124        if (err)
 125                goto out;
 126
 127        err = nexthdr;
 128
 129out:
 130        return err;
 131}
 132EXPORT_SYMBOL_GPL(ipcomp_input);
 133
 134static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb)
 135{
 136        struct ipcomp_data *ipcd = x->data;
 137        const int plen = skb->len;
 138        int dlen = IPCOMP_SCRATCH_SIZE;
 139        u8 *start = skb->data;
 140        struct crypto_comp *tfm;
 141        u8 *scratch;
 142        int err;
 143
 144        local_bh_disable();
 145        scratch = *this_cpu_ptr(ipcomp_scratches);
 146        tfm = *this_cpu_ptr(ipcd->tfms);
 147        err = crypto_comp_compress(tfm, start, plen, scratch, &dlen);
 148        if (err)
 149                goto out;
 150
 151        if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) {
 152                err = -EMSGSIZE;
 153                goto out;
 154        }
 155
 156        memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen);
 157        local_bh_enable();
 158
 159        pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr));
 160        return 0;
 161
 162out:
 163        local_bh_enable();
 164        return err;
 165}
 166
 167int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb)
 168{
 169        int err;
 170        struct ip_comp_hdr *ipch;
 171        struct ipcomp_data *ipcd = x->data;
 172
 173        if (skb->len < ipcd->threshold) {
 174                /* Don't bother compressing */
 175                goto out_ok;
 176        }
 177
 178        if (skb_linearize_cow(skb))
 179                goto out_ok;
 180
 181        err = ipcomp_compress(x, skb);
 182
 183        if (err) {
 184                goto out_ok;
 185        }
 186
 187        /* Install ipcomp header, convert into ipcomp datagram. */
 188        ipch = ip_comp_hdr(skb);
 189        ipch->nexthdr = *skb_mac_header(skb);
 190        ipch->flags = 0;
 191        ipch->cpi = htons((u16 )ntohl(x->id.spi));
 192        *skb_mac_header(skb) = IPPROTO_COMP;
 193out_ok:
 194        skb_push(skb, -skb_network_offset(skb));
 195        return 0;
 196}
 197EXPORT_SYMBOL_GPL(ipcomp_output);
 198
 199static void ipcomp_free_scratches(void)
 200{
 201        int i;
 202        void * __percpu *scratches;
 203
 204        if (--ipcomp_scratch_users)
 205                return;
 206
 207        scratches = ipcomp_scratches;
 208        if (!scratches)
 209                return;
 210
 211        for_each_possible_cpu(i)
 212                vfree(*per_cpu_ptr(scratches, i));
 213
 214        free_percpu(scratches);
 215}
 216
 217static void * __percpu *ipcomp_alloc_scratches(void)
 218{
 219        void * __percpu *scratches;
 220        int i;
 221
 222        if (ipcomp_scratch_users++)
 223                return ipcomp_scratches;
 224
 225        scratches = alloc_percpu(void *);
 226        if (!scratches)
 227                return NULL;
 228
 229        ipcomp_scratches = scratches;
 230
 231        for_each_possible_cpu(i) {
 232                void *scratch;
 233
 234                scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i));
 235                if (!scratch)
 236                        return NULL;
 237                *per_cpu_ptr(scratches, i) = scratch;
 238        }
 239
 240        return scratches;
 241}
 242
 243static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms)
 244{
 245        struct ipcomp_tfms *pos;
 246        int cpu;
 247
 248        list_for_each_entry(pos, &ipcomp_tfms_list, list) {
 249                if (pos->tfms == tfms)
 250                        break;
 251        }
 252
 253        WARN_ON(!pos);
 254
 255        if (--pos->users)
 256                return;
 257
 258        list_del(&pos->list);
 259        kfree(pos);
 260
 261        if (!tfms)
 262                return;
 263
 264        for_each_possible_cpu(cpu) {
 265                struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu);
 266                crypto_free_comp(tfm);
 267        }
 268        free_percpu(tfms);
 269}
 270
 271static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name)
 272{
 273        struct ipcomp_tfms *pos;
 274        struct crypto_comp * __percpu *tfms;
 275        int cpu;
 276
 277
 278        list_for_each_entry(pos, &ipcomp_tfms_list, list) {
 279                struct crypto_comp *tfm;
 280
 281                /* This can be any valid CPU ID so we don't need locking. */
 282                tfm = this_cpu_read(*pos->tfms);
 283
 284                if (!strcmp(crypto_comp_name(tfm), alg_name)) {
 285                        pos->users++;
 286                        return pos->tfms;
 287                }
 288        }
 289
 290        pos = kmalloc(sizeof(*pos), GFP_KERNEL);
 291        if (!pos)
 292                return NULL;
 293
 294        pos->users = 1;
 295        INIT_LIST_HEAD(&pos->list);
 296        list_add(&pos->list, &ipcomp_tfms_list);
 297
 298        pos->tfms = tfms = alloc_percpu(struct crypto_comp *);
 299        if (!tfms)
 300                goto error;
 301
 302        for_each_possible_cpu(cpu) {
 303                struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0,
 304                                                            CRYPTO_ALG_ASYNC);
 305                if (IS_ERR(tfm))
 306                        goto error;
 307                *per_cpu_ptr(tfms, cpu) = tfm;
 308        }
 309
 310        return tfms;
 311
 312error:
 313        ipcomp_free_tfms(tfms);
 314        return NULL;
 315}
 316
 317static void ipcomp_free_data(struct ipcomp_data *ipcd)
 318{
 319        if (ipcd->tfms)
 320                ipcomp_free_tfms(ipcd->tfms);
 321        ipcomp_free_scratches();
 322}
 323
 324void ipcomp_destroy(struct xfrm_state *x)
 325{
 326        struct ipcomp_data *ipcd = x->data;
 327        if (!ipcd)
 328                return;
 329        xfrm_state_delete_tunnel(x);
 330        mutex_lock(&ipcomp_resource_mutex);
 331        ipcomp_free_data(ipcd);
 332        mutex_unlock(&ipcomp_resource_mutex);
 333        kfree(ipcd);
 334}
 335EXPORT_SYMBOL_GPL(ipcomp_destroy);
 336
 337int ipcomp_init_state(struct xfrm_state *x)
 338{
 339        int err;
 340        struct ipcomp_data *ipcd;
 341        struct xfrm_algo_desc *calg_desc;
 342
 343        err = -EINVAL;
 344        if (!x->calg)
 345                goto out;
 346
 347        if (x->encap)
 348                goto out;
 349
 350        err = -ENOMEM;
 351        ipcd = kzalloc(sizeof(*ipcd), GFP_KERNEL);
 352        if (!ipcd)
 353                goto out;
 354
 355        mutex_lock(&ipcomp_resource_mutex);
 356        if (!ipcomp_alloc_scratches())
 357                goto error;
 358
 359        ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name);
 360        if (!ipcd->tfms)
 361                goto error;
 362        mutex_unlock(&ipcomp_resource_mutex);
 363
 364        calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0);
 365        BUG_ON(!calg_desc);
 366        ipcd->threshold = calg_desc->uinfo.comp.threshold;
 367        x->data = ipcd;
 368        err = 0;
 369out:
 370        return err;
 371
 372error:
 373        ipcomp_free_data(ipcd);
 374        mutex_unlock(&ipcomp_resource_mutex);
 375        kfree(ipcd);
 376        goto out;
 377}
 378EXPORT_SYMBOL_GPL(ipcomp_init_state);
 379
/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173");
MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>");
 383