linux/net/mpls/af_mpls.c
<<
>>
Prefs
   1#include <linux/types.h>
   2#include <linux/skbuff.h>
   3#include <linux/socket.h>
   4#include <linux/sysctl.h>
   5#include <linux/net.h>
   6#include <linux/module.h>
   7#include <linux/if_arp.h>
   8#include <linux/ipv6.h>
   9#include <linux/mpls.h>
  10#include <linux/vmalloc.h>
  11#include <net/ip.h>
  12#include <net/dst.h>
  13#include <net/sock.h>
  14#include <net/arp.h>
  15#include <net/ip_fib.h>
  16#include <net/netevent.h>
  17#include <net/netns/generic.h>
  18#include "internal.h"
  19
  20#define LABEL_NOT_SPECIFIED (1<<20)
  21#define MAX_NEW_LABELS 2
  22
  23/* This maximum ha length copied from the definition of struct neighbour */
  24#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
  25
  26struct mpls_route { /* next hop label forwarding entry */
  27        struct net_device __rcu *rt_dev;
  28        struct rcu_head         rt_rcu;
  29        u32                     rt_label[MAX_NEW_LABELS];
  30        u8                      rt_protocol; /* routing protocol that set this entry */
  31        u8                      rt_labels;
  32        u8                      rt_via_alen;
  33        u8                      rt_via_table;
  34        u8                      rt_via[0];
  35};
  36
  37static int zero = 0;
  38static int label_limit = (1 << 20) - 1;
  39
  40static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
  41                       struct nlmsghdr *nlh, struct net *net, u32 portid,
  42                       unsigned int nlm_flags);
  43
  44static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
  45{
  46        struct mpls_route *rt = NULL;
  47
  48        if (index < net->mpls.platform_labels) {
  49                struct mpls_route __rcu **platform_label =
  50                        rcu_dereference(net->mpls.platform_label);
  51                rt = rcu_dereference(platform_label[index]);
  52        }
  53        return rt;
  54}
  55
  56static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev)
  57{
  58        return rcu_dereference_rtnl(dev->mpls_ptr);
  59}
  60
  61static bool mpls_output_possible(const struct net_device *dev)
  62{
  63        return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
  64}
  65
  66static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
  67{
  68        /* The size of the layer 2.5 labels to be added for this route */
  69        return rt->rt_labels * sizeof(struct mpls_shim_hdr);
  70}
  71
  72static unsigned int mpls_dev_mtu(const struct net_device *dev)
  73{
  74        /* The amount of data the layer 2 frame can hold */
  75        return dev->mtu;
  76}
  77
  78static bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
  79{
  80        if (skb->len <= mtu)
  81                return false;
  82
  83        if (skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu)
  84                return false;
  85
  86        return true;
  87}
  88
  89static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
  90                        struct mpls_entry_decoded dec)
  91{
  92        /* RFC4385 and RFC5586 encode other packets in mpls such that
  93         * they don't conflict with the ip version number, making
  94         * decoding by examining the ip version correct in everything
  95         * except for the strangest cases.
  96         *
  97         * The strange cases if we choose to support them will require
  98         * manual configuration.
  99         */
 100        struct iphdr *hdr4;
 101        bool success = true;
 102
 103        /* The IPv4 code below accesses through the IPv4 header
 104         * checksum, which is 12 bytes into the packet.
 105         * The IPv6 code below accesses through the IPv6 hop limit
 106         * which is 8 bytes into the packet.
 107         *
 108         * For all supported cases there should always be at least 12
 109         * bytes of packet data present.  The IPv4 header is 20 bytes
 110         * without options and the IPv6 header is always 40 bytes
 111         * long.
 112         */
 113        if (!pskb_may_pull(skb, 12))
 114                return false;
 115
 116        /* Use ip_hdr to find the ip protocol version */
 117        hdr4 = ip_hdr(skb);
 118        if (hdr4->version == 4) {
 119                skb->protocol = htons(ETH_P_IP);
 120                csum_replace2(&hdr4->check,
 121                              htons(hdr4->ttl << 8),
 122                              htons(dec.ttl << 8));
 123                hdr4->ttl = dec.ttl;
 124        }
 125        else if (hdr4->version == 6) {
 126                struct ipv6hdr *hdr6 = ipv6_hdr(skb);
 127                skb->protocol = htons(ETH_P_IPV6);
 128                hdr6->hop_limit = dec.ttl;
 129        }
 130        else
 131                /* version 0 and version 1 are used by pseudo wires */
 132                success = false;
 133        return success;
 134}
 135
 136static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
 137                        struct packet_type *pt, struct net_device *orig_dev)
 138{
 139        struct net *net = dev_net(dev);
 140        struct mpls_shim_hdr *hdr;
 141        struct mpls_route *rt;
 142        struct mpls_entry_decoded dec;
 143        struct net_device *out_dev;
 144        struct mpls_dev *mdev;
 145        unsigned int hh_len;
 146        unsigned int new_header_size;
 147        unsigned int mtu;
 148        int err;
 149
 150        /* Careful this entire function runs inside of an rcu critical section */
 151
 152        mdev = mpls_dev_get(dev);
 153        if (!mdev || !mdev->input_enabled)
 154                goto drop;
 155
 156        if (skb->pkt_type != PACKET_HOST)
 157                goto drop;
 158
 159        if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
 160                goto drop;
 161
 162        if (!pskb_may_pull(skb, sizeof(*hdr)))
 163                goto drop;
 164
 165        /* Read and decode the label */
 166        hdr = mpls_hdr(skb);
 167        dec = mpls_entry_decode(hdr);
 168
 169        /* Pop the label */
 170        skb_pull(skb, sizeof(*hdr));
 171        skb_reset_network_header(skb);
 172
 173        skb_orphan(skb);
 174
 175        rt = mpls_route_input_rcu(net, dec.label);
 176        if (!rt)
 177                goto drop;
 178
 179        /* Find the output device */
 180        out_dev = rcu_dereference(rt->rt_dev);
 181        if (!mpls_output_possible(out_dev))
 182                goto drop;
 183
 184        if (skb_warn_if_lro(skb))
 185                goto drop;
 186
 187        skb_forward_csum(skb);
 188
 189        /* Verify ttl is valid */
 190        if (dec.ttl <= 1)
 191                goto drop;
 192        dec.ttl -= 1;
 193
 194        /* Verify the destination can hold the packet */
 195        new_header_size = mpls_rt_header_size(rt);
 196        mtu = mpls_dev_mtu(out_dev);
 197        if (mpls_pkt_too_big(skb, mtu - new_header_size))
 198                goto drop;
 199
 200        hh_len = LL_RESERVED_SPACE(out_dev);
 201        if (!out_dev->header_ops)
 202                hh_len = 0;
 203
 204        /* Ensure there is enough space for the headers in the skb */
 205        if (skb_cow(skb, hh_len + new_header_size))
 206                goto drop;
 207
 208        skb->dev = out_dev;
 209        skb->protocol = htons(ETH_P_MPLS_UC);
 210
 211        if (unlikely(!new_header_size && dec.bos)) {
 212                /* Penultimate hop popping */
 213                if (!mpls_egress(rt, skb, dec))
 214                        goto drop;
 215        } else {
 216                bool bos;
 217                int i;
 218                skb_push(skb, new_header_size);
 219                skb_reset_network_header(skb);
 220                /* Push the new labels */
 221                hdr = mpls_hdr(skb);
 222                bos = dec.bos;
 223                for (i = rt->rt_labels - 1; i >= 0; i--) {
 224                        hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
 225                        bos = false;
 226                }
 227        }
 228
 229        err = neigh_xmit(rt->rt_via_table, out_dev, rt->rt_via, skb);
 230        if (err)
 231                net_dbg_ratelimited("%s: packet transmission failed: %d\n",
 232                                    __func__, err);
 233        return 0;
 234
 235drop:
 236        kfree_skb(skb);
 237        return NET_RX_DROP;
 238}
 239
 240static struct packet_type mpls_packet_type __read_mostly = {
 241        .type = cpu_to_be16(ETH_P_MPLS_UC),
 242        .func = mpls_forward,
 243};
 244
 245static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
 246        [RTA_DST]               = { .type = NLA_U32 },
 247        [RTA_OIF]               = { .type = NLA_U32 },
 248};
 249
/* Parsed form of a netlink route request, filled in by
 * rtm_to_route_config() and consumed by mpls_route_add() and
 * mpls_route_del().
 */
struct mpls_route_config {
	u32		rc_protocol;	/* rtm_protocol of the request */
	u32		rc_ifindex;	/* RTA_OIF: output interface index */
	u16		rc_via_table;	/* neighbour table chosen from the RTA_VIA family */
	u16		rc_via_alen;	/* valid bytes in rc_via[] */
	u8		rc_via[MAX_VIA_ALEN];	/* RTA_VIA: next-hop address */
	u32		rc_label;	/* RTA_DST: incoming label, or LABEL_NOT_SPECIFIED */
	u32		rc_output_labels;	/* valid entries in rc_output_label[] */
	u32		rc_output_label[MAX_NEW_LABELS];	/* RTA_NEWDST: labels to push */
	u32		rc_nlflags;	/* nlmsg_flags of the request */
	struct nl_info	rc_nlinfo;	/* requester identity, used for notifications */
};
 262
 263static struct mpls_route *mpls_rt_alloc(size_t alen)
 264{
 265        struct mpls_route *rt;
 266
 267        rt = kzalloc(sizeof(*rt) + alen, GFP_KERNEL);
 268        if (rt)
 269                rt->rt_via_alen = alen;
 270        return rt;
 271}
 272
 273static void mpls_rt_free(struct mpls_route *rt)
 274{
 275        if (rt)
 276                kfree_rcu(rt, rt_rcu);
 277}
 278
 279static void mpls_notify_route(struct net *net, unsigned index,
 280                              struct mpls_route *old, struct mpls_route *new,
 281                              const struct nl_info *info)
 282{
 283        struct nlmsghdr *nlh = info ? info->nlh : NULL;
 284        unsigned portid = info ? info->portid : 0;
 285        int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
 286        struct mpls_route *rt = new ? new : old;
 287        unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
 288        /* Ignore reserved labels for now */
 289        if (rt && (index >= 16))
 290                rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
 291}
 292
 293static void mpls_route_update(struct net *net, unsigned index,
 294                              struct net_device *dev, struct mpls_route *new,
 295                              const struct nl_info *info)
 296{
 297        struct mpls_route __rcu **platform_label;
 298        struct mpls_route *rt, *old = NULL;
 299
 300        ASSERT_RTNL();
 301
 302        platform_label = rtnl_dereference(net->mpls.platform_label);
 303        rt = rtnl_dereference(platform_label[index]);
 304        if (!dev || (rt && (rtnl_dereference(rt->rt_dev) == dev))) {
 305                rcu_assign_pointer(platform_label[index], new);
 306                old = rt;
 307        }
 308
 309        mpls_notify_route(net, index, old, new, info);
 310
 311        /* If we removed a route free it now */
 312        mpls_rt_free(old);
 313}
 314
 315static unsigned find_free_label(struct net *net)
 316{
 317        struct mpls_route __rcu **platform_label;
 318        size_t platform_labels;
 319        unsigned index;
 320
 321        platform_label = rtnl_dereference(net->mpls.platform_label);
 322        platform_labels = net->mpls.platform_labels;
 323        for (index = 16; index < platform_labels; index++) {
 324                if (!rtnl_dereference(platform_label[index]))
 325                        return index;
 326        }
 327        return LABEL_NOT_SPECIFIED;
 328}
 329
 330static int mpls_route_add(struct mpls_route_config *cfg)
 331{
 332        struct mpls_route __rcu **platform_label;
 333        struct net *net = cfg->rc_nlinfo.nl_net;
 334        struct net_device *dev = NULL;
 335        struct mpls_route *rt, *old;
 336        unsigned index;
 337        int i;
 338        int err = -EINVAL;
 339
 340        index = cfg->rc_label;
 341
 342        /* If a label was not specified during insert pick one */
 343        if ((index == LABEL_NOT_SPECIFIED) &&
 344            (cfg->rc_nlflags & NLM_F_CREATE)) {
 345                index = find_free_label(net);
 346        }
 347
 348        /* The first 16 labels are reserved, and may not be set */
 349        if (index < 16)
 350                goto errout;
 351
 352        /* The full 20 bit range may not be supported. */
 353        if (index >= net->mpls.platform_labels)
 354                goto errout;
 355
 356        /* Ensure only a supported number of labels are present */
 357        if (cfg->rc_output_labels > MAX_NEW_LABELS)
 358                goto errout;
 359
 360        err = -ENODEV;
 361        dev = dev_get_by_index(net, cfg->rc_ifindex);
 362        if (!dev)
 363                goto errout;
 364
 365        /* Ensure this is a supported device */
 366        err = -EINVAL;
 367        if (!mpls_dev_get(dev))
 368                goto errout;
 369
 370        err = -EINVAL;
 371        if ((cfg->rc_via_table == NEIGH_LINK_TABLE) &&
 372            (dev->addr_len != cfg->rc_via_alen))
 373                goto errout;
 374
 375        /* Append makes no sense with mpls */
 376        err = -EOPNOTSUPP;
 377        if (cfg->rc_nlflags & NLM_F_APPEND)
 378                goto errout;
 379
 380        err = -EEXIST;
 381        platform_label = rtnl_dereference(net->mpls.platform_label);
 382        old = rtnl_dereference(platform_label[index]);
 383        if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
 384                goto errout;
 385
 386        err = -EEXIST;
 387        if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
 388                goto errout;
 389
 390        err = -ENOENT;
 391        if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
 392                goto errout;
 393
 394        err = -ENOMEM;
 395        rt = mpls_rt_alloc(cfg->rc_via_alen);
 396        if (!rt)
 397                goto errout;
 398
 399        rt->rt_labels = cfg->rc_output_labels;
 400        for (i = 0; i < rt->rt_labels; i++)
 401                rt->rt_label[i] = cfg->rc_output_label[i];
 402        rt->rt_protocol = cfg->rc_protocol;
 403        RCU_INIT_POINTER(rt->rt_dev, dev);
 404        rt->rt_via_table = cfg->rc_via_table;
 405        memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
 406
 407        mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);
 408
 409        dev_put(dev);
 410        return 0;
 411
 412errout:
 413        if (dev)
 414                dev_put(dev);
 415        return err;
 416}
 417
 418static int mpls_route_del(struct mpls_route_config *cfg)
 419{
 420        struct net *net = cfg->rc_nlinfo.nl_net;
 421        unsigned index;
 422        int err = -EINVAL;
 423
 424        index = cfg->rc_label;
 425
 426        /* The first 16 labels are reserved, and may not be removed */
 427        if (index < 16)
 428                goto errout;
 429
 430        /* The full 20 bit range may not be supported */
 431        if (index >= net->mpls.platform_labels)
 432                goto errout;
 433
 434        mpls_route_update(net, index, NULL, NULL, &cfg->rc_nlinfo);
 435
 436        err = 0;
 437errout:
 438        return err;
 439}
 440
 441#define MPLS_PERDEV_SYSCTL_OFFSET(field)        \
 442        (&((struct mpls_dev *)0)->field)
 443
 444static const struct ctl_table mpls_dev_table[] = {
 445        {
 446                .procname       = "input",
 447                .maxlen         = sizeof(int),
 448                .mode           = 0644,
 449                .proc_handler   = proc_dointvec,
 450                .data           = MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
 451        },
 452        { }
 453};
 454
 455static int mpls_dev_sysctl_register(struct net_device *dev,
 456                                    struct mpls_dev *mdev)
 457{
 458        char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
 459        struct ctl_table *table;
 460        int i;
 461
 462        table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL);
 463        if (!table)
 464                goto out;
 465
 466        /* Table data contains only offsets relative to the base of
 467         * the mdev at this point, so make them absolute.
 468         */
 469        for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++)
 470                table[i].data = (char *)mdev + (uintptr_t)table[i].data;
 471
 472        snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
 473
 474        mdev->sysctl = register_net_sysctl(dev_net(dev), path, table);
 475        if (!mdev->sysctl)
 476                goto free;
 477
 478        return 0;
 479
 480free:
 481        kfree(table);
 482out:
 483        return -ENOBUFS;
 484}
 485
 486static void mpls_dev_sysctl_unregister(struct mpls_dev *mdev)
 487{
 488        struct ctl_table *table;
 489
 490        table = mdev->sysctl->ctl_table_arg;
 491        unregister_net_sysctl_table(mdev->sysctl);
 492        kfree(table);
 493}
 494
 495static struct mpls_dev *mpls_add_dev(struct net_device *dev)
 496{
 497        struct mpls_dev *mdev;
 498        int err = -ENOMEM;
 499
 500        ASSERT_RTNL();
 501
 502        mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
 503        if (!mdev)
 504                return ERR_PTR(err);
 505
 506        err = mpls_dev_sysctl_register(dev, mdev);
 507        if (err)
 508                goto free;
 509
 510        rcu_assign_pointer(dev->mpls_ptr, mdev);
 511
 512        return mdev;
 513
 514free:
 515        kfree(mdev);
 516        return ERR_PTR(err);
 517}
 518
 519static void mpls_ifdown(struct net_device *dev)
 520{
 521        struct mpls_route __rcu **platform_label;
 522        struct net *net = dev_net(dev);
 523        struct mpls_dev *mdev;
 524        unsigned index;
 525
 526        platform_label = rtnl_dereference(net->mpls.platform_label);
 527        for (index = 0; index < net->mpls.platform_labels; index++) {
 528                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
 529                if (!rt)
 530                        continue;
 531                if (rtnl_dereference(rt->rt_dev) != dev)
 532                        continue;
 533                rt->rt_dev = NULL;
 534        }
 535
 536        mdev = mpls_dev_get(dev);
 537        if (!mdev)
 538                return;
 539
 540        mpls_dev_sysctl_unregister(mdev);
 541
 542        RCU_INIT_POINTER(dev->mpls_ptr, NULL);
 543
 544        kfree_rcu(mdev, rcu);
 545}
 546
 547static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
 548                           void *ptr)
 549{
 550        struct net_device *dev = netdev_notifier_info_to_dev(ptr);
 551        struct mpls_dev *mdev;
 552
 553        switch(event) {
 554        case NETDEV_REGISTER:
 555                /* For now just support ethernet devices */
 556                if ((dev->type == ARPHRD_ETHER) ||
 557                    (dev->type == ARPHRD_LOOPBACK)) {
 558                        mdev = mpls_add_dev(dev);
 559                        if (IS_ERR(mdev))
 560                                return notifier_from_errno(PTR_ERR(mdev));
 561                }
 562                break;
 563
 564        case NETDEV_UNREGISTER:
 565                mpls_ifdown(dev);
 566                break;
 567        case NETDEV_CHANGENAME:
 568                mdev = mpls_dev_get(dev);
 569                if (mdev) {
 570                        int err;
 571
 572                        mpls_dev_sysctl_unregister(mdev);
 573                        err = mpls_dev_sysctl_register(dev, mdev);
 574                        if (err)
 575                                return notifier_from_errno(err);
 576                }
 577                break;
 578        }
 579        return NOTIFY_OK;
 580}
 581
 582static struct notifier_block mpls_dev_notifier = {
 583        .notifier_call = mpls_dev_notify,
 584};
 585
 586static int nla_put_via(struct sk_buff *skb,
 587                       u8 table, const void *addr, int alen)
 588{
 589        static const int table_to_family[NEIGH_NR_TABLES + 1] = {
 590                AF_INET, AF_INET6, AF_DECnet, AF_PACKET,
 591        };
 592        struct nlattr *nla;
 593        struct rtvia *via;
 594        int family = AF_UNSPEC;
 595
 596        nla = nla_reserve(skb, RTA_VIA, alen + 2);
 597        if (!nla)
 598                return -EMSGSIZE;
 599
 600        if (table <= NEIGH_NR_TABLES)
 601                family = table_to_family[table];
 602
 603        via = nla_data(nla);
 604        via->rtvia_family = family;
 605        memcpy(via->rtvia_addr, addr, alen);
 606        return 0;
 607}
 608
 609int nla_put_labels(struct sk_buff *skb, int attrtype,
 610                   u8 labels, const u32 label[])
 611{
 612        struct nlattr *nla;
 613        struct mpls_shim_hdr *nla_label;
 614        bool bos;
 615        int i;
 616        nla = nla_reserve(skb, attrtype, labels*4);
 617        if (!nla)
 618                return -EMSGSIZE;
 619
 620        nla_label = nla_data(nla);
 621        bos = true;
 622        for (i = labels - 1; i >= 0; i--) {
 623                nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
 624                bos = false;
 625        }
 626
 627        return 0;
 628}
 629
 630int nla_get_labels(const struct nlattr *nla,
 631                   u32 max_labels, u32 *labels, u32 label[])
 632{
 633        unsigned len = nla_len(nla);
 634        unsigned nla_labels;
 635        struct mpls_shim_hdr *nla_label;
 636        bool bos;
 637        int i;
 638
 639        /* len needs to be an even multiple of 4 (the label size) */
 640        if (len & 3)
 641                return -EINVAL;
 642
 643        /* Limit the number of new labels allowed */
 644        nla_labels = len/4;
 645        if (nla_labels > max_labels)
 646                return -EINVAL;
 647
 648        nla_label = nla_data(nla);
 649        bos = true;
 650        for (i = nla_labels - 1; i >= 0; i--, bos = false) {
 651                struct mpls_entry_decoded dec;
 652                dec = mpls_entry_decode(nla_label + i);
 653
 654                /* Ensure the bottom of stack flag is properly set
 655                 * and ttl and tc are both clear.
 656                 */
 657                if ((dec.bos != bos) || dec.ttl || dec.tc)
 658                        return -EINVAL;
 659
 660                switch (dec.label) {
 661                case MPLS_LABEL_IMPLNULL:
 662                        /* RFC3032: This is a label that an LSR may
 663                         * assign and distribute, but which never
 664                         * actually appears in the encapsulation.
 665                         */
 666                        return -EINVAL;
 667                }
 668
 669                label[i] = dec.label;
 670        }
 671        *labels = nla_labels;
 672        return 0;
 673}
 674
 675static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
 676                               struct mpls_route_config *cfg)
 677{
 678        struct rtmsg *rtm;
 679        struct nlattr *tb[RTA_MAX+1];
 680        int index;
 681        int err;
 682
 683        err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_mpls_policy);
 684        if (err < 0)
 685                goto errout;
 686
 687        err = -EINVAL;
 688        rtm = nlmsg_data(nlh);
 689        memset(cfg, 0, sizeof(*cfg));
 690
 691        if (rtm->rtm_family != AF_MPLS)
 692                goto errout;
 693        if (rtm->rtm_dst_len != 20)
 694                goto errout;
 695        if (rtm->rtm_src_len != 0)
 696                goto errout;
 697        if (rtm->rtm_tos != 0)
 698                goto errout;
 699        if (rtm->rtm_table != RT_TABLE_MAIN)
 700                goto errout;
 701        /* Any value is acceptable for rtm_protocol */
 702
 703        /* As mpls uses destination specific addresses
 704         * (or source specific address in the case of multicast)
 705         * all addresses have universal scope.
 706         */
 707        if (rtm->rtm_scope != RT_SCOPE_UNIVERSE)
 708                goto errout;
 709        if (rtm->rtm_type != RTN_UNICAST)
 710                goto errout;
 711        if (rtm->rtm_flags != 0)
 712                goto errout;
 713
 714        cfg->rc_label           = LABEL_NOT_SPECIFIED;
 715        cfg->rc_protocol        = rtm->rtm_protocol;
 716        cfg->rc_nlflags         = nlh->nlmsg_flags;
 717        cfg->rc_nlinfo.portid   = NETLINK_CB(skb).portid;
 718        cfg->rc_nlinfo.nlh      = nlh;
 719        cfg->rc_nlinfo.nl_net   = sock_net(skb->sk);
 720
 721        for (index = 0; index <= RTA_MAX; index++) {
 722                struct nlattr *nla = tb[index];
 723                if (!nla)
 724                        continue;
 725
 726                switch(index) {
 727                case RTA_OIF:
 728                        cfg->rc_ifindex = nla_get_u32(nla);
 729                        break;
 730                case RTA_NEWDST:
 731                        if (nla_get_labels(nla, MAX_NEW_LABELS,
 732                                           &cfg->rc_output_labels,
 733                                           cfg->rc_output_label))
 734                                goto errout;
 735                        break;
 736                case RTA_DST:
 737                {
 738                        u32 label_count;
 739                        if (nla_get_labels(nla, 1, &label_count,
 740                                           &cfg->rc_label))
 741                                goto errout;
 742
 743                        /* The first 16 labels are reserved, and may not be set */
 744                        if (cfg->rc_label < 16)
 745                                goto errout;
 746
 747                        break;
 748                }
 749                case RTA_VIA:
 750                {
 751                        struct rtvia *via = nla_data(nla);
 752                        if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
 753                                goto errout;
 754                        cfg->rc_via_alen   = nla_len(nla) -
 755                                offsetof(struct rtvia, rtvia_addr);
 756                        if (cfg->rc_via_alen > MAX_VIA_ALEN)
 757                                goto errout;
 758
 759                        /* Validate the address family */
 760                        switch(via->rtvia_family) {
 761                        case AF_PACKET:
 762                                cfg->rc_via_table = NEIGH_LINK_TABLE;
 763                                break;
 764                        case AF_INET:
 765                                cfg->rc_via_table = NEIGH_ARP_TABLE;
 766                                if (cfg->rc_via_alen != 4)
 767                                        goto errout;
 768                                break;
 769                        case AF_INET6:
 770                                cfg->rc_via_table = NEIGH_ND_TABLE;
 771                                if (cfg->rc_via_alen != 16)
 772                                        goto errout;
 773                                break;
 774                        default:
 775                                /* Unsupported address family */
 776                                goto errout;
 777                        }
 778
 779                        memcpy(cfg->rc_via, via->rtvia_addr, cfg->rc_via_alen);
 780                        break;
 781                }
 782                default:
 783                        /* Unsupported attribute */
 784                        goto errout;
 785                }
 786        }
 787
 788        err = 0;
 789errout:
 790        return err;
 791}
 792
 793static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
 794{
 795        struct mpls_route_config cfg;
 796        int err;
 797
 798        err = rtm_to_route_config(skb, nlh, &cfg);
 799        if (err < 0)
 800                return err;
 801
 802        return mpls_route_del(&cfg);
 803}
 804
 805
 806static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
 807{
 808        struct mpls_route_config cfg;
 809        int err;
 810
 811        err = rtm_to_route_config(skb, nlh, &cfg);
 812        if (err < 0)
 813                return err;
 814
 815        return mpls_route_add(&cfg);
 816}
 817
 818static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
 819                           u32 label, struct mpls_route *rt, int flags)
 820{
 821        struct net_device *dev;
 822        struct nlmsghdr *nlh;
 823        struct rtmsg *rtm;
 824
 825        nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
 826        if (nlh == NULL)
 827                return -EMSGSIZE;
 828
 829        rtm = nlmsg_data(nlh);
 830        rtm->rtm_family = AF_MPLS;
 831        rtm->rtm_dst_len = 20;
 832        rtm->rtm_src_len = 0;
 833        rtm->rtm_tos = 0;
 834        rtm->rtm_table = RT_TABLE_MAIN;
 835        rtm->rtm_protocol = rt->rt_protocol;
 836        rtm->rtm_scope = RT_SCOPE_UNIVERSE;
 837        rtm->rtm_type = RTN_UNICAST;
 838        rtm->rtm_flags = 0;
 839
 840        if (rt->rt_labels &&
 841            nla_put_labels(skb, RTA_NEWDST, rt->rt_labels, rt->rt_label))
 842                goto nla_put_failure;
 843        if (nla_put_via(skb, rt->rt_via_table, rt->rt_via, rt->rt_via_alen))
 844                goto nla_put_failure;
 845        dev = rtnl_dereference(rt->rt_dev);
 846        if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
 847                goto nla_put_failure;
 848        if (nla_put_labels(skb, RTA_DST, 1, &label))
 849                goto nla_put_failure;
 850
 851        nlmsg_end(skb, nlh);
 852        return 0;
 853
 854nla_put_failure:
 855        nlmsg_cancel(skb, nlh);
 856        return -EMSGSIZE;
 857}
 858
 859static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
 860{
 861        struct net *net = sock_net(skb->sk);
 862        struct mpls_route __rcu **platform_label;
 863        size_t platform_labels;
 864        unsigned int index;
 865
 866        ASSERT_RTNL();
 867
 868        index = cb->args[0];
 869        if (index < 16)
 870                index = 16;
 871
 872        platform_label = rtnl_dereference(net->mpls.platform_label);
 873        platform_labels = net->mpls.platform_labels;
 874        for (; index < platform_labels; index++) {
 875                struct mpls_route *rt;
 876                rt = rtnl_dereference(platform_label[index]);
 877                if (!rt)
 878                        continue;
 879
 880                if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
 881                                    cb->nlh->nlmsg_seq, RTM_NEWROUTE,
 882                                    index, rt, NLM_F_MULTI) < 0)
 883                        break;
 884        }
 885        cb->args[0] = index;
 886
 887        return skb->len;
 888}
 889
 890static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
 891{
 892        size_t payload =
 893                NLMSG_ALIGN(sizeof(struct rtmsg))
 894                + nla_total_size(2 + rt->rt_via_alen)   /* RTA_VIA */
 895                + nla_total_size(4);                    /* RTA_DST */
 896        if (rt->rt_labels)                              /* RTA_NEWDST */
 897                payload += nla_total_size(rt->rt_labels * 4);
 898        if (rt->rt_dev)                                 /* RTA_OIF */
 899                payload += nla_total_size(4);
 900        return payload;
 901}
 902
 903static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
 904                       struct nlmsghdr *nlh, struct net *net, u32 portid,
 905                       unsigned int nlm_flags)
 906{
 907        struct sk_buff *skb;
 908        u32 seq = nlh ? nlh->nlmsg_seq : 0;
 909        int err = -ENOBUFS;
 910
 911        skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
 912        if (skb == NULL)
 913                goto errout;
 914
 915        err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
 916        if (err < 0) {
 917                /* -EMSGSIZE implies BUG in lfib_nlmsg_size */
 918                WARN_ON(err == -EMSGSIZE);
 919                kfree_skb(skb);
 920                goto errout;
 921        }
 922        rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
 923
 924        return;
 925errout:
 926        if (err < 0)
 927                rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
 928}
 929
 930static int resize_platform_label_table(struct net *net, size_t limit)
 931{
 932        size_t size = sizeof(struct mpls_route *) * limit;
 933        size_t old_limit;
 934        size_t cp_size;
 935        struct mpls_route __rcu **labels = NULL, **old;
 936        struct mpls_route *rt0 = NULL, *rt2 = NULL;
 937        unsigned index;
 938
 939        if (size) {
 940                labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
 941                if (!labels)
 942                        labels = vzalloc(size);
 943
 944                if (!labels)
 945                        goto nolabels;
 946        }
 947
 948        /* In case the predefined labels need to be populated */
 949        if (limit > MPLS_LABEL_IPV4NULL) {
 950                struct net_device *lo = net->loopback_dev;
 951                rt0 = mpls_rt_alloc(lo->addr_len);
 952                if (!rt0)
 953                        goto nort0;
 954                RCU_INIT_POINTER(rt0->rt_dev, lo);
 955                rt0->rt_protocol = RTPROT_KERNEL;
 956                rt0->rt_via_table = NEIGH_LINK_TABLE;
 957                memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
 958        }
 959        if (limit > MPLS_LABEL_IPV6NULL) {
 960                struct net_device *lo = net->loopback_dev;
 961                rt2 = mpls_rt_alloc(lo->addr_len);
 962                if (!rt2)
 963                        goto nort2;
 964                RCU_INIT_POINTER(rt2->rt_dev, lo);
 965                rt2->rt_protocol = RTPROT_KERNEL;
 966                rt2->rt_via_table = NEIGH_LINK_TABLE;
 967                memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
 968        }
 969
 970        rtnl_lock();
 971        /* Remember the original table */
 972        old = rtnl_dereference(net->mpls.platform_label);
 973        old_limit = net->mpls.platform_labels;
 974
 975        /* Free any labels beyond the new table */
 976        for (index = limit; index < old_limit; index++)
 977                mpls_route_update(net, index, NULL, NULL, NULL);
 978
 979        /* Copy over the old labels */
 980        cp_size = size;
 981        if (old_limit < limit)
 982                cp_size = old_limit * sizeof(struct mpls_route *);
 983
 984        memcpy(labels, old, cp_size);
 985
 986        /* If needed set the predefined labels */
 987        if ((old_limit <= MPLS_LABEL_IPV6NULL) &&
 988            (limit > MPLS_LABEL_IPV6NULL)) {
 989                RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2);
 990                rt2 = NULL;
 991        }
 992
 993        if ((old_limit <= MPLS_LABEL_IPV4NULL) &&
 994            (limit > MPLS_LABEL_IPV4NULL)) {
 995                RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0);
 996                rt0 = NULL;
 997        }
 998
 999        /* Update the global pointers */
1000        net->mpls.platform_labels = limit;
1001        rcu_assign_pointer(net->mpls.platform_label, labels);
1002
1003        rtnl_unlock();
1004
1005        mpls_rt_free(rt2);
1006        mpls_rt_free(rt0);
1007
1008        if (old) {
1009                synchronize_rcu();
1010                kvfree(old);
1011        }
1012        return 0;
1013
1014nort2:
1015        mpls_rt_free(rt0);
1016nort0:
1017        kvfree(labels);
1018nolabels:
1019        return -ENOMEM;
1020}
1021
1022static int mpls_platform_labels(struct ctl_table *table, int write,
1023                                void __user *buffer, size_t *lenp, loff_t *ppos)
1024{
1025        struct net *net = table->data;
1026        int platform_labels = net->mpls.platform_labels;
1027        int ret;
1028        struct ctl_table tmp = {
1029                .procname       = table->procname,
1030                .data           = &platform_labels,
1031                .maxlen         = sizeof(int),
1032                .mode           = table->mode,
1033                .extra1         = &zero,
1034                .extra2         = &label_limit,
1035        };
1036
1037        ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
1038
1039        if (write && ret == 0)
1040                ret = resize_platform_label_table(net, platform_labels);
1041
1042        return ret;
1043}
1044
1045static const struct ctl_table mpls_table[] = {
1046        {
1047                .procname       = "platform_labels",
1048                .data           = NULL,
1049                .maxlen         = sizeof(int),
1050                .mode           = 0644,
1051                .proc_handler   = mpls_platform_labels,
1052        },
1053        { }
1054};
1055
1056static int mpls_net_init(struct net *net)
1057{
1058        struct ctl_table *table;
1059
1060        net->mpls.platform_labels = 0;
1061        net->mpls.platform_label = NULL;
1062
1063        table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
1064        if (table == NULL)
1065                return -ENOMEM;
1066
1067        table[0].data = net;
1068        net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
1069        if (net->mpls.ctl == NULL)
1070                return -ENOMEM;
1071
1072        return 0;
1073}
1074
1075static void mpls_net_exit(struct net *net)
1076{
1077        struct mpls_route __rcu **platform_label;
1078        size_t platform_labels;
1079        struct ctl_table *table;
1080        unsigned int index;
1081
1082        table = net->mpls.ctl->ctl_table_arg;
1083        unregister_net_sysctl_table(net->mpls.ctl);
1084        kfree(table);
1085
1086        /* An rcu grace period has passed since there was a device in
1087         * the network namespace (and thus the last in flight packet)
1088         * left this network namespace.  This is because
1089         * unregister_netdevice_many and netdev_run_todo has completed
1090         * for each network device that was in this network namespace.
1091         *
1092         * As such no additional rcu synchronization is necessary when
1093         * freeing the platform_label table.
1094         */
1095        rtnl_lock();
1096        platform_label = rtnl_dereference(net->mpls.platform_label);
1097        platform_labels = net->mpls.platform_labels;
1098        for (index = 0; index < platform_labels; index++) {
1099                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1100                RCU_INIT_POINTER(platform_label[index], NULL);
1101                mpls_rt_free(rt);
1102        }
1103        rtnl_unlock();
1104
1105        kvfree(platform_label);
1106}
1107
1108static struct pernet_operations mpls_net_ops = {
1109        .init = mpls_net_init,
1110        .exit = mpls_net_exit,
1111};
1112
1113static int __init mpls_init(void)
1114{
1115        int err;
1116
1117        BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
1118
1119        err = register_pernet_subsys(&mpls_net_ops);
1120        if (err)
1121                goto out;
1122
1123        err = register_netdevice_notifier(&mpls_dev_notifier);
1124        if (err)
1125                goto out_unregister_pernet;
1126
1127        dev_add_pack(&mpls_packet_type);
1128
1129        rtnl_register(PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, NULL);
1130        rtnl_register(PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, NULL);
1131        rtnl_register(PF_MPLS, RTM_GETROUTE, NULL, mpls_dump_routes, NULL);
1132        err = 0;
1133out:
1134        return err;
1135
1136out_unregister_pernet:
1137        unregister_pernet_subsys(&mpls_net_ops);
1138        goto out;
1139}
1140module_init(mpls_init);
1141
1142static void __exit mpls_exit(void)
1143{
1144        rtnl_unregister_all(PF_MPLS);
1145        dev_remove_pack(&mpls_packet_type);
1146        unregister_netdevice_notifier(&mpls_dev_notifier);
1147        unregister_pernet_subsys(&mpls_net_ops);
1148}
1149module_exit(mpls_exit);
1150
1151MODULE_DESCRIPTION("MultiProtocol Label Switching");
1152MODULE_LICENSE("GPL v2");
1153MODULE_ALIAS_NETPROTO(PF_MPLS);
1154