linux/net/mpls/af_mpls.c
<<
>>
Prefs
   1#include <linux/types.h>
   2#include <linux/skbuff.h>
   3#include <linux/socket.h>
   4#include <linux/sysctl.h>
   5#include <linux/net.h>
   6#include <linux/module.h>
   7#include <linux/if_arp.h>
   8#include <linux/ipv6.h>
   9#include <linux/mpls.h>
  10#include <linux/vmalloc.h>
  11#include <net/ip.h>
  12#include <net/dst.h>
  13#include <net/sock.h>
  14#include <net/arp.h>
  15#include <net/ip_fib.h>
  16#include <net/netevent.h>
  17#include <net/netns/generic.h>
  18#if IS_ENABLED(CONFIG_IPV6)
  19#include <net/ipv6.h>
  20#include <net/addrconf.h>
  21#endif
  22#include "internal.h"
  23
/* Sentinel for "no incoming label given": 1 << 20 is one past the largest
 * value of the 20 bit label field, so it never collides with a real label.
 */
#define LABEL_NOT_SPECIFIED (1<<20)
/* Maximum number of labels a single route may push on forwarded packets. */
#define MAX_NEW_LABELS 2

/* This maximum hardware address length is copied from the definition of
 * struct neighbour; it sizes the via address buffer in route configs.
 */
#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, sizeof(unsigned long)))
  29
/* Payload carried below the label stack of a route, consulted by
 * mpls_egress() when the last label is popped.  MPT_IPV4/MPT_IPV6 equal the
 * IP header version numbers, so mpls_egress() can compare an IP header's
 * version field directly against these values.
 */
enum mpls_payload_type {
	MPT_UNSPEC, /* IPv4 or IPv6, decided per packet from the IP version */
	MPT_IPV4 = 4,
	MPT_IPV6 = 6,

	/* Other types not implemented:
	 *  - Pseudo-wire with or without control word (RFC4385)
	 *  - GAL (RFC5586)
	 */
};
  40
  41struct mpls_route { /* next hop label forwarding entry */
  42        struct net_device __rcu *rt_dev;
  43        struct rcu_head         rt_rcu;
  44        u32                     rt_label[MAX_NEW_LABELS];
  45        u8                      rt_protocol; /* routing protocol that set this entry */
  46        u8                      rt_payload_type;
  47        u8                      rt_labels;
  48        u8                      rt_via_alen;
  49        u8                      rt_via_table;
  50        u8                      rt_via[0];
  51};
  52
  53static int zero = 0;
  54static int label_limit = (1 << 20) - 1;
  55
/* Forward declaration; the definition is presumably later in this file
 * (outside this view).  mpls_notify_route() calls it with RTM_NEWROUTE or
 * RTM_DELROUTE to announce a change of the route for @label.
 */
static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
		       struct nlmsghdr *nlh, struct net *net, u32 portid,
		       unsigned int nlm_flags);
  59
  60static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
  61{
  62        struct mpls_route *rt = NULL;
  63
  64        if (index < net->mpls.platform_labels) {
  65                struct mpls_route __rcu **platform_label =
  66                        rcu_dereference(net->mpls.platform_label);
  67                rt = rcu_dereference(platform_label[index]);
  68        }
  69        return rt;
  70}
  71
  72static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev)
  73{
  74        return rcu_dereference_rtnl(dev->mpls_ptr);
  75}
  76
  77bool mpls_output_possible(const struct net_device *dev)
  78{
  79        return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
  80}
  81EXPORT_SYMBOL_GPL(mpls_output_possible);
  82
  83static unsigned int mpls_rt_header_size(const struct mpls_route *rt)
  84{
  85        /* The size of the layer 2.5 labels to be added for this route */
  86        return rt->rt_labels * sizeof(struct mpls_shim_hdr);
  87}
  88
  89unsigned int mpls_dev_mtu(const struct net_device *dev)
  90{
  91        /* The amount of data the layer 2 frame can hold */
  92        return dev->mtu;
  93}
  94EXPORT_SYMBOL_GPL(mpls_dev_mtu);
  95
  96bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
  97{
  98        if (skb->len <= mtu)
  99                return false;
 100
 101        if (skb_is_gso(skb) && skb_gso_network_seglen(skb) <= mtu)
 102                return false;
 103
 104        return true;
 105}
 106EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
 107
 108static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
 109                        struct mpls_entry_decoded dec)
 110{
 111        enum mpls_payload_type payload_type;
 112        bool success = false;
 113
 114        /* The IPv4 code below accesses through the IPv4 header
 115         * checksum, which is 12 bytes into the packet.
 116         * The IPv6 code below accesses through the IPv6 hop limit
 117         * which is 8 bytes into the packet.
 118         *
 119         * For all supported cases there should always be at least 12
 120         * bytes of packet data present.  The IPv4 header is 20 bytes
 121         * without options and the IPv6 header is always 40 bytes
 122         * long.
 123         */
 124        if (!pskb_may_pull(skb, 12))
 125                return false;
 126
 127        payload_type = rt->rt_payload_type;
 128        if (payload_type == MPT_UNSPEC)
 129                payload_type = ip_hdr(skb)->version;
 130
 131        switch (payload_type) {
 132        case MPT_IPV4: {
 133                struct iphdr *hdr4 = ip_hdr(skb);
 134                skb->protocol = htons(ETH_P_IP);
 135                csum_replace2(&hdr4->check,
 136                              htons(hdr4->ttl << 8),
 137                              htons(dec.ttl << 8));
 138                hdr4->ttl = dec.ttl;
 139                success = true;
 140                break;
 141        }
 142        case MPT_IPV6: {
 143                struct ipv6hdr *hdr6 = ipv6_hdr(skb);
 144                skb->protocol = htons(ETH_P_IPV6);
 145                hdr6->hop_limit = dec.ttl;
 146                success = true;
 147                break;
 148        }
 149        case MPT_UNSPEC:
 150                break;
 151        }
 152
 153        return success;
 154}
 155
 156static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
 157                        struct packet_type *pt, struct net_device *orig_dev)
 158{
 159        struct net *net = dev_net(dev);
 160        struct mpls_shim_hdr *hdr;
 161        struct mpls_route *rt;
 162        struct mpls_entry_decoded dec;
 163        struct net_device *out_dev;
 164        struct mpls_dev *mdev;
 165        unsigned int hh_len;
 166        unsigned int new_header_size;
 167        unsigned int mtu;
 168        int err;
 169
 170        /* Careful this entire function runs inside of an rcu critical section */
 171
 172        mdev = mpls_dev_get(dev);
 173        if (!mdev || !mdev->input_enabled)
 174                goto drop;
 175
 176        if (skb->pkt_type != PACKET_HOST)
 177                goto drop;
 178
 179        if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
 180                goto drop;
 181
 182        if (!pskb_may_pull(skb, sizeof(*hdr)))
 183                goto drop;
 184
 185        /* Read and decode the label */
 186        hdr = mpls_hdr(skb);
 187        dec = mpls_entry_decode(hdr);
 188
 189        /* Pop the label */
 190        skb_pull(skb, sizeof(*hdr));
 191        skb_reset_network_header(skb);
 192
 193        skb_orphan(skb);
 194
 195        rt = mpls_route_input_rcu(net, dec.label);
 196        if (!rt)
 197                goto drop;
 198
 199        /* Find the output device */
 200        out_dev = rcu_dereference(rt->rt_dev);
 201        if (!mpls_output_possible(out_dev))
 202                goto drop;
 203
 204        if (skb_warn_if_lro(skb))
 205                goto drop;
 206
 207        skb_forward_csum(skb);
 208
 209        /* Verify ttl is valid */
 210        if (dec.ttl <= 1)
 211                goto drop;
 212        dec.ttl -= 1;
 213
 214        /* Verify the destination can hold the packet */
 215        new_header_size = mpls_rt_header_size(rt);
 216        mtu = mpls_dev_mtu(out_dev);
 217        if (mpls_pkt_too_big(skb, mtu - new_header_size))
 218                goto drop;
 219
 220        hh_len = LL_RESERVED_SPACE(out_dev);
 221        if (!out_dev->header_ops)
 222                hh_len = 0;
 223
 224        /* Ensure there is enough space for the headers in the skb */
 225        if (skb_cow(skb, hh_len + new_header_size))
 226                goto drop;
 227
 228        skb->dev = out_dev;
 229        skb->protocol = htons(ETH_P_MPLS_UC);
 230
 231        if (unlikely(!new_header_size && dec.bos)) {
 232                /* Penultimate hop popping */
 233                if (!mpls_egress(rt, skb, dec))
 234                        goto drop;
 235        } else {
 236                bool bos;
 237                int i;
 238                skb_push(skb, new_header_size);
 239                skb_reset_network_header(skb);
 240                /* Push the new labels */
 241                hdr = mpls_hdr(skb);
 242                bos = dec.bos;
 243                for (i = rt->rt_labels - 1; i >= 0; i--) {
 244                        hdr[i] = mpls_entry_encode(rt->rt_label[i], dec.ttl, 0, bos);
 245                        bos = false;
 246                }
 247        }
 248
 249        err = neigh_xmit(rt->rt_via_table, out_dev, rt->rt_via, skb);
 250        if (err)
 251                net_dbg_ratelimited("%s: packet transmission failed: %d\n",
 252                                    __func__, err);
 253        return 0;
 254
 255drop:
 256        kfree_skb(skb);
 257        return NET_RX_DROP;
 258}
 259
 260static struct packet_type mpls_packet_type __read_mostly = {
 261        .type = cpu_to_be16(ETH_P_MPLS_UC),
 262        .func = mpls_forward,
 263};
 264
 265static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
 266        [RTA_DST]               = { .type = NLA_U32 },
 267        [RTA_OIF]               = { .type = NLA_U32 },
 268};
 269
/* Route request decoded by rtm_to_route_config() and consumed by
 * mpls_route_add()/mpls_route_del().  Fields absent from the request are
 * left zero by rtm_to_route_config(), except rc_label.
 */
struct mpls_route_config {
	u32			rc_protocol;	/* rtm_protocol of the request */
	u32			rc_ifindex;	/* RTA_OIF, 0 if not given */
	u16			rc_via_table;	/* NEIGH_*_TABLE from the RTA_VIA family */
	u16			rc_via_alen;	/* bytes used in rc_via[] */
	u8			rc_via[MAX_VIA_ALEN];	/* next hop address (RTA_VIA) */
	u32			rc_label;	/* incoming label, LABEL_NOT_SPECIFIED if absent */
	u32			rc_output_labels;	/* entries used in rc_output_label[] */
	u32			rc_output_label[MAX_NEW_LABELS];	/* RTA_NEWDST labels to push */
	u32			rc_nlflags;	/* nlmsg_flags of the request (NLM_F_*) */
	enum mpls_payload_type	rc_payload_type;	/* not set by rtm_to_route_config(), stays MPT_UNSPEC */
	struct nl_info		rc_nlinfo;	/* netns, portid and header for notifications */
};
 283
 284static struct mpls_route *mpls_rt_alloc(size_t alen)
 285{
 286        struct mpls_route *rt;
 287
 288        rt = kzalloc(sizeof(*rt) + alen, GFP_KERNEL);
 289        if (rt)
 290                rt->rt_via_alen = alen;
 291        return rt;
 292}
 293
 294static void mpls_rt_free(struct mpls_route *rt)
 295{
 296        if (rt)
 297                kfree_rcu(rt, rt_rcu);
 298}
 299
 300static void mpls_notify_route(struct net *net, unsigned index,
 301                              struct mpls_route *old, struct mpls_route *new,
 302                              const struct nl_info *info)
 303{
 304        struct nlmsghdr *nlh = info ? info->nlh : NULL;
 305        unsigned portid = info ? info->portid : 0;
 306        int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
 307        struct mpls_route *rt = new ? new : old;
 308        unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
 309        /* Ignore reserved labels for now */
 310        if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
 311                rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
 312}
 313
 314static void mpls_route_update(struct net *net, unsigned index,
 315                              struct net_device *dev, struct mpls_route *new,
 316                              const struct nl_info *info)
 317{
 318        struct mpls_route __rcu **platform_label;
 319        struct mpls_route *rt, *old = NULL;
 320
 321        ASSERT_RTNL();
 322
 323        platform_label = rtnl_dereference(net->mpls.platform_label);
 324        rt = rtnl_dereference(platform_label[index]);
 325        if (!dev || (rt && (rtnl_dereference(rt->rt_dev) == dev))) {
 326                rcu_assign_pointer(platform_label[index], new);
 327                old = rt;
 328        }
 329
 330        mpls_notify_route(net, index, old, new, info);
 331
 332        /* If we removed a route free it now */
 333        mpls_rt_free(old);
 334}
 335
 336static unsigned find_free_label(struct net *net)
 337{
 338        struct mpls_route __rcu **platform_label;
 339        size_t platform_labels;
 340        unsigned index;
 341
 342        platform_label = rtnl_dereference(net->mpls.platform_label);
 343        platform_labels = net->mpls.platform_labels;
 344        for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
 345             index++) {
 346                if (!rtnl_dereference(platform_label[index]))
 347                        return index;
 348        }
 349        return LABEL_NOT_SPECIFIED;
 350}
 351
 352#if IS_ENABLED(CONFIG_INET)
 353static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
 354{
 355        struct net_device *dev;
 356        struct rtable *rt;
 357        struct in_addr daddr;
 358
 359        memcpy(&daddr, addr, sizeof(struct in_addr));
 360        rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
 361        if (IS_ERR(rt))
 362                return ERR_CAST(rt);
 363
 364        dev = rt->dst.dev;
 365        dev_hold(dev);
 366
 367        ip_rt_put(rt);
 368
 369        return dev;
 370}
 371#else
 372static struct net_device *inet_fib_lookup_dev(struct net *net, void *addr)
 373{
 374        return ERR_PTR(-EAFNOSUPPORT);
 375}
 376#endif
 377
 378#if IS_ENABLED(CONFIG_IPV6)
 379static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
 380{
 381        struct net_device *dev;
 382        struct dst_entry *dst;
 383        struct flowi6 fl6;
 384        int err;
 385
 386        if (!ipv6_stub)
 387                return ERR_PTR(-EAFNOSUPPORT);
 388
 389        memset(&fl6, 0, sizeof(fl6));
 390        memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
 391        err = ipv6_stub->ipv6_dst_lookup(net, NULL, &dst, &fl6);
 392        if (err)
 393                return ERR_PTR(err);
 394
 395        dev = dst->dev;
 396        dev_hold(dev);
 397        dst_release(dst);
 398
 399        return dev;
 400}
 401#else
 402static struct net_device *inet6_fib_lookup_dev(struct net *net, void *addr)
 403{
 404        return ERR_PTR(-EAFNOSUPPORT);
 405}
 406#endif
 407
 408static struct net_device *find_outdev(struct net *net,
 409                                      struct mpls_route_config *cfg)
 410{
 411        struct net_device *dev = NULL;
 412
 413        if (!cfg->rc_ifindex) {
 414                switch (cfg->rc_via_table) {
 415                case NEIGH_ARP_TABLE:
 416                        dev = inet_fib_lookup_dev(net, cfg->rc_via);
 417                        break;
 418                case NEIGH_ND_TABLE:
 419                        dev = inet6_fib_lookup_dev(net, cfg->rc_via);
 420                        break;
 421                case NEIGH_LINK_TABLE:
 422                        break;
 423                }
 424        } else {
 425                dev = dev_get_by_index(net, cfg->rc_ifindex);
 426        }
 427
 428        if (!dev)
 429                return ERR_PTR(-ENODEV);
 430
 431        return dev;
 432}
 433
 434static int mpls_route_add(struct mpls_route_config *cfg)
 435{
 436        struct mpls_route __rcu **platform_label;
 437        struct net *net = cfg->rc_nlinfo.nl_net;
 438        struct net_device *dev = NULL;
 439        struct mpls_route *rt, *old;
 440        unsigned index;
 441        int i;
 442        int err = -EINVAL;
 443
 444        index = cfg->rc_label;
 445
 446        /* If a label was not specified during insert pick one */
 447        if ((index == LABEL_NOT_SPECIFIED) &&
 448            (cfg->rc_nlflags & NLM_F_CREATE)) {
 449                index = find_free_label(net);
 450        }
 451
 452        /* Reserved labels may not be set */
 453        if (index < MPLS_LABEL_FIRST_UNRESERVED)
 454                goto errout;
 455
 456        /* The full 20 bit range may not be supported. */
 457        if (index >= net->mpls.platform_labels)
 458                goto errout;
 459
 460        /* Ensure only a supported number of labels are present */
 461        if (cfg->rc_output_labels > MAX_NEW_LABELS)
 462                goto errout;
 463
 464        dev = find_outdev(net, cfg);
 465        if (IS_ERR(dev)) {
 466                err = PTR_ERR(dev);
 467                dev = NULL;
 468                goto errout;
 469        }
 470
 471        /* Ensure this is a supported device */
 472        err = -EINVAL;
 473        if (!mpls_dev_get(dev))
 474                goto errout;
 475
 476        err = -EINVAL;
 477        if ((cfg->rc_via_table == NEIGH_LINK_TABLE) &&
 478            (dev->addr_len != cfg->rc_via_alen))
 479                goto errout;
 480
 481        /* Append makes no sense with mpls */
 482        err = -EOPNOTSUPP;
 483        if (cfg->rc_nlflags & NLM_F_APPEND)
 484                goto errout;
 485
 486        err = -EEXIST;
 487        platform_label = rtnl_dereference(net->mpls.platform_label);
 488        old = rtnl_dereference(platform_label[index]);
 489        if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
 490                goto errout;
 491
 492        err = -EEXIST;
 493        if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
 494                goto errout;
 495
 496        err = -ENOENT;
 497        if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
 498                goto errout;
 499
 500        err = -ENOMEM;
 501        rt = mpls_rt_alloc(cfg->rc_via_alen);
 502        if (!rt)
 503                goto errout;
 504
 505        rt->rt_labels = cfg->rc_output_labels;
 506        for (i = 0; i < rt->rt_labels; i++)
 507                rt->rt_label[i] = cfg->rc_output_label[i];
 508        rt->rt_protocol = cfg->rc_protocol;
 509        RCU_INIT_POINTER(rt->rt_dev, dev);
 510        rt->rt_payload_type = cfg->rc_payload_type;
 511        rt->rt_via_table = cfg->rc_via_table;
 512        memcpy(rt->rt_via, cfg->rc_via, cfg->rc_via_alen);
 513
 514        mpls_route_update(net, index, NULL, rt, &cfg->rc_nlinfo);
 515
 516        dev_put(dev);
 517        return 0;
 518
 519errout:
 520        if (dev)
 521                dev_put(dev);
 522        return err;
 523}
 524
 525static int mpls_route_del(struct mpls_route_config *cfg)
 526{
 527        struct net *net = cfg->rc_nlinfo.nl_net;
 528        unsigned index;
 529        int err = -EINVAL;
 530
 531        index = cfg->rc_label;
 532
 533        /* Reserved labels may not be removed */
 534        if (index < MPLS_LABEL_FIRST_UNRESERVED)
 535                goto errout;
 536
 537        /* The full 20 bit range may not be supported */
 538        if (index >= net->mpls.platform_labels)
 539                goto errout;
 540
 541        mpls_route_update(net, index, NULL, NULL, &cfg->rc_nlinfo);
 542
 543        err = 0;
 544errout:
 545        return err;
 546}
 547
 548#define MPLS_PERDEV_SYSCTL_OFFSET(field)        \
 549        (&((struct mpls_dev *)0)->field)
 550
 551static const struct ctl_table mpls_dev_table[] = {
 552        {
 553                .procname       = "input",
 554                .maxlen         = sizeof(int),
 555                .mode           = 0644,
 556                .proc_handler   = proc_dointvec,
 557                .data           = MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
 558        },
 559        { }
 560};
 561
 562static int mpls_dev_sysctl_register(struct net_device *dev,
 563                                    struct mpls_dev *mdev)
 564{
 565        char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
 566        struct ctl_table *table;
 567        int i;
 568
 569        table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL);
 570        if (!table)
 571                goto out;
 572
 573        /* Table data contains only offsets relative to the base of
 574         * the mdev at this point, so make them absolute.
 575         */
 576        for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++)
 577                table[i].data = (char *)mdev + (uintptr_t)table[i].data;
 578
 579        snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
 580
 581        mdev->sysctl = register_net_sysctl(dev_net(dev), path, table);
 582        if (!mdev->sysctl)
 583                goto free;
 584
 585        return 0;
 586
 587free:
 588        kfree(table);
 589out:
 590        return -ENOBUFS;
 591}
 592
 593static void mpls_dev_sysctl_unregister(struct mpls_dev *mdev)
 594{
 595        struct ctl_table *table;
 596
 597        table = mdev->sysctl->ctl_table_arg;
 598        unregister_net_sysctl_table(mdev->sysctl);
 599        kfree(table);
 600}
 601
 602static struct mpls_dev *mpls_add_dev(struct net_device *dev)
 603{
 604        struct mpls_dev *mdev;
 605        int err = -ENOMEM;
 606
 607        ASSERT_RTNL();
 608
 609        mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
 610        if (!mdev)
 611                return ERR_PTR(err);
 612
 613        err = mpls_dev_sysctl_register(dev, mdev);
 614        if (err)
 615                goto free;
 616
 617        rcu_assign_pointer(dev->mpls_ptr, mdev);
 618
 619        return mdev;
 620
 621free:
 622        kfree(mdev);
 623        return ERR_PTR(err);
 624}
 625
 626static void mpls_ifdown(struct net_device *dev)
 627{
 628        struct mpls_route __rcu **platform_label;
 629        struct net *net = dev_net(dev);
 630        struct mpls_dev *mdev;
 631        unsigned index;
 632
 633        platform_label = rtnl_dereference(net->mpls.platform_label);
 634        for (index = 0; index < net->mpls.platform_labels; index++) {
 635                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
 636                if (!rt)
 637                        continue;
 638                if (rtnl_dereference(rt->rt_dev) != dev)
 639                        continue;
 640                rt->rt_dev = NULL;
 641        }
 642
 643        mdev = mpls_dev_get(dev);
 644        if (!mdev)
 645                return;
 646
 647        mpls_dev_sysctl_unregister(mdev);
 648
 649        RCU_INIT_POINTER(dev->mpls_ptr, NULL);
 650
 651        kfree_rcu(mdev, rcu);
 652}
 653
 654static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
 655                           void *ptr)
 656{
 657        struct net_device *dev = netdev_notifier_info_to_dev(ptr);
 658        struct mpls_dev *mdev;
 659
 660        switch(event) {
 661        case NETDEV_REGISTER:
 662                /* For now just support ethernet devices */
 663                if ((dev->type == ARPHRD_ETHER) ||
 664                    (dev->type == ARPHRD_LOOPBACK)) {
 665                        mdev = mpls_add_dev(dev);
 666                        if (IS_ERR(mdev))
 667                                return notifier_from_errno(PTR_ERR(mdev));
 668                }
 669                break;
 670
 671        case NETDEV_UNREGISTER:
 672                mpls_ifdown(dev);
 673                break;
 674        case NETDEV_CHANGENAME:
 675                mdev = mpls_dev_get(dev);
 676                if (mdev) {
 677                        int err;
 678
 679                        mpls_dev_sysctl_unregister(mdev);
 680                        err = mpls_dev_sysctl_register(dev, mdev);
 681                        if (err)
 682                                return notifier_from_errno(err);
 683                }
 684                break;
 685        }
 686        return NOTIFY_OK;
 687}
 688
/* Notifier block that routes netdevice events to mpls_dev_notify(). */
static struct notifier_block mpls_dev_notifier = {
	.notifier_call = mpls_dev_notify,
};
 692
 693static int nla_put_via(struct sk_buff *skb,
 694                       u8 table, const void *addr, int alen)
 695{
 696        static const int table_to_family[NEIGH_NR_TABLES + 1] = {
 697                AF_INET, AF_INET6, AF_DECnet, AF_PACKET,
 698        };
 699        struct nlattr *nla;
 700        struct rtvia *via;
 701        int family = AF_UNSPEC;
 702
 703        nla = nla_reserve(skb, RTA_VIA, alen + 2);
 704        if (!nla)
 705                return -EMSGSIZE;
 706
 707        if (table <= NEIGH_NR_TABLES)
 708                family = table_to_family[table];
 709
 710        via = nla_data(nla);
 711        via->rtvia_family = family;
 712        memcpy(via->rtvia_addr, addr, alen);
 713        return 0;
 714}
 715
 716int nla_put_labels(struct sk_buff *skb, int attrtype,
 717                   u8 labels, const u32 label[])
 718{
 719        struct nlattr *nla;
 720        struct mpls_shim_hdr *nla_label;
 721        bool bos;
 722        int i;
 723        nla = nla_reserve(skb, attrtype, labels*4);
 724        if (!nla)
 725                return -EMSGSIZE;
 726
 727        nla_label = nla_data(nla);
 728        bos = true;
 729        for (i = labels - 1; i >= 0; i--) {
 730                nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
 731                bos = false;
 732        }
 733
 734        return 0;
 735}
 736EXPORT_SYMBOL_GPL(nla_put_labels);
 737
 738int nla_get_labels(const struct nlattr *nla,
 739                   u32 max_labels, u32 *labels, u32 label[])
 740{
 741        unsigned len = nla_len(nla);
 742        unsigned nla_labels;
 743        struct mpls_shim_hdr *nla_label;
 744        bool bos;
 745        int i;
 746
 747        /* len needs to be an even multiple of 4 (the label size) */
 748        if (len & 3)
 749                return -EINVAL;
 750
 751        /* Limit the number of new labels allowed */
 752        nla_labels = len/4;
 753        if (nla_labels > max_labels)
 754                return -EINVAL;
 755
 756        nla_label = nla_data(nla);
 757        bos = true;
 758        for (i = nla_labels - 1; i >= 0; i--, bos = false) {
 759                struct mpls_entry_decoded dec;
 760                dec = mpls_entry_decode(nla_label + i);
 761
 762                /* Ensure the bottom of stack flag is properly set
 763                 * and ttl and tc are both clear.
 764                 */
 765                if ((dec.bos != bos) || dec.ttl || dec.tc)
 766                        return -EINVAL;
 767
 768                switch (dec.label) {
 769                case MPLS_LABEL_IMPLNULL:
 770                        /* RFC3032: This is a label that an LSR may
 771                         * assign and distribute, but which never
 772                         * actually appears in the encapsulation.
 773                         */
 774                        return -EINVAL;
 775                }
 776
 777                label[i] = dec.label;
 778        }
 779        *labels = nla_labels;
 780        return 0;
 781}
 782EXPORT_SYMBOL_GPL(nla_get_labels);
 783
/* Translate an AF_MPLS rtnetlink route request into @cfg.
 *
 * The rtmsg header must describe a 20 bit, universe scope, unicast route
 * in RT_TABLE_MAIN with no flags.  Only RTA_OIF, RTA_NEWDST, RTA_DST and
 * RTA_VIA are accepted; any other attribute is rejected.  On success,
 * fields absent from the request are zero, except rc_label, which defaults
 * to LABEL_NOT_SPECIFIED.
 *
 * NOTE(review): RTA_VIA is optional here, so rc_via_table/rc_via_alen may
 * stay zero.
 *
 * Returns 0, the error from nlmsg_parse(), or -EINVAL.
 */
static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
			       struct mpls_route_config *cfg)
{
	struct rtmsg *rtm;
	struct nlattr *tb[RTA_MAX+1];
	int index;
	int err;

	err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_mpls_policy);
	if (err < 0)
		goto errout;

	err = -EINVAL;
	rtm = nlmsg_data(nlh);
	memset(cfg, 0, sizeof(*cfg));

	if (rtm->rtm_family != AF_MPLS)
		goto errout;
	/* Each route matches exactly one full 20 bit label */
	if (rtm->rtm_dst_len != 20)
		goto errout;
	if (rtm->rtm_src_len != 0)
		goto errout;
	if (rtm->rtm_tos != 0)
		goto errout;
	if (rtm->rtm_table != RT_TABLE_MAIN)
		goto errout;
	/* Any value is acceptable for rtm_protocol */

	/* As mpls uses destination specific addresses
	 * (or source specific address in the case of multicast)
	 * all addresses have universal scope.
	 */
	if (rtm->rtm_scope != RT_SCOPE_UNIVERSE)
		goto errout;
	if (rtm->rtm_type != RTN_UNICAST)
		goto errout;
	if (rtm->rtm_flags != 0)
		goto errout;

	cfg->rc_label		= LABEL_NOT_SPECIFIED;
	cfg->rc_protocol	= rtm->rtm_protocol;
	cfg->rc_nlflags		= nlh->nlmsg_flags;
	cfg->rc_nlinfo.portid	= NETLINK_CB(skb).portid;
	cfg->rc_nlinfo.nlh	= nlh;
	cfg->rc_nlinfo.nl_net	= sock_net(skb->sk);

	for (index = 0; index <= RTA_MAX; index++) {
		struct nlattr *nla = tb[index];
		if (!nla)
			continue;

		switch(index) {
		case RTA_OIF:
			cfg->rc_ifindex = nla_get_u32(nla);
			break;
		case RTA_NEWDST:
			/* Labels to push on forwarded packets */
			if (nla_get_labels(nla, MAX_NEW_LABELS,
					   &cfg->rc_output_labels,
					   cfg->rc_output_label))
				goto errout;
			break;
		case RTA_DST:
		{
			/* The incoming label this route is for */
			u32 label_count;
			if (nla_get_labels(nla, 1, &label_count,
					   &cfg->rc_label))
				goto errout;

			/* Reserved labels may not be set */
			if (cfg->rc_label < MPLS_LABEL_FIRST_UNRESERVED)
				goto errout;

			break;
		}
		case RTA_VIA:
		{
			/* struct rtvia: address family, then the next hop
			 * address filling the rest of the attribute.
			 */
			struct rtvia *via = nla_data(nla);
			if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
				goto errout;
			cfg->rc_via_alen   = nla_len(nla) -
				offsetof(struct rtvia, rtvia_addr);
			if (cfg->rc_via_alen > MAX_VIA_ALEN)
				goto errout;

			/* Validate the address family */
			switch(via->rtvia_family) {
			case AF_PACKET:
				cfg->rc_via_table = NEIGH_LINK_TABLE;
				break;
			case AF_INET:
				cfg->rc_via_table = NEIGH_ARP_TABLE;
				if (cfg->rc_via_alen != 4)
					goto errout;
				break;
			case AF_INET6:
				cfg->rc_via_table = NEIGH_ND_TABLE;
				if (cfg->rc_via_alen != 16)
					goto errout;
				break;
			default:
				/* Unsupported address family */
				goto errout;
			}

			memcpy(cfg->rc_via, via->rtvia_addr, cfg->rc_via_alen);
			break;
		}
		default:
			/* Unsupported attribute */
			goto errout;
		}
	}

	err = 0;
errout:
	return err;
}
 901
 902static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
 903{
 904        struct mpls_route_config cfg;
 905        int err;
 906
 907        err = rtm_to_route_config(skb, nlh, &cfg);
 908        if (err < 0)
 909                return err;
 910
 911        return mpls_route_del(&cfg);
 912}
 913
 914
 915static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
 916{
 917        struct mpls_route_config cfg;
 918        int err;
 919
 920        err = rtm_to_route_config(skb, nlh, &cfg);
 921        if (err < 0)
 922                return err;
 923
 924        return mpls_route_add(&cfg);
 925}
 926
 927static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
 928                           u32 label, struct mpls_route *rt, int flags)
 929{
 930        struct net_device *dev;
 931        struct nlmsghdr *nlh;
 932        struct rtmsg *rtm;
 933
 934        nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
 935        if (nlh == NULL)
 936                return -EMSGSIZE;
 937
 938        rtm = nlmsg_data(nlh);
 939        rtm->rtm_family = AF_MPLS;
 940        rtm->rtm_dst_len = 20;
 941        rtm->rtm_src_len = 0;
 942        rtm->rtm_tos = 0;
 943        rtm->rtm_table = RT_TABLE_MAIN;
 944        rtm->rtm_protocol = rt->rt_protocol;
 945        rtm->rtm_scope = RT_SCOPE_UNIVERSE;
 946        rtm->rtm_type = RTN_UNICAST;
 947        rtm->rtm_flags = 0;
 948
 949        if (rt->rt_labels &&
 950            nla_put_labels(skb, RTA_NEWDST, rt->rt_labels, rt->rt_label))
 951                goto nla_put_failure;
 952        if (nla_put_via(skb, rt->rt_via_table, rt->rt_via, rt->rt_via_alen))
 953                goto nla_put_failure;
 954        dev = rtnl_dereference(rt->rt_dev);
 955        if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
 956                goto nla_put_failure;
 957        if (nla_put_labels(skb, RTA_DST, 1, &label))
 958                goto nla_put_failure;
 959
 960        nlmsg_end(skb, nlh);
 961        return 0;
 962
 963nla_put_failure:
 964        nlmsg_cancel(skb, nlh);
 965        return -EMSGSIZE;
 966}
 967
 968static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
 969{
 970        struct net *net = sock_net(skb->sk);
 971        struct mpls_route __rcu **platform_label;
 972        size_t platform_labels;
 973        unsigned int index;
 974
 975        ASSERT_RTNL();
 976
 977        index = cb->args[0];
 978        if (index < MPLS_LABEL_FIRST_UNRESERVED)
 979                index = MPLS_LABEL_FIRST_UNRESERVED;
 980
 981        platform_label = rtnl_dereference(net->mpls.platform_label);
 982        platform_labels = net->mpls.platform_labels;
 983        for (; index < platform_labels; index++) {
 984                struct mpls_route *rt;
 985                rt = rtnl_dereference(platform_label[index]);
 986                if (!rt)
 987                        continue;
 988
 989                if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
 990                                    cb->nlh->nlmsg_seq, RTM_NEWROUTE,
 991                                    index, rt, NLM_F_MULTI) < 0)
 992                        break;
 993        }
 994        cb->args[0] = index;
 995
 996        return skb->len;
 997}
 998
 999static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
1000{
1001        size_t payload =
1002                NLMSG_ALIGN(sizeof(struct rtmsg))
1003                + nla_total_size(2 + rt->rt_via_alen)   /* RTA_VIA */
1004                + nla_total_size(4);                    /* RTA_DST */
1005        if (rt->rt_labels)                              /* RTA_NEWDST */
1006                payload += nla_total_size(rt->rt_labels * 4);
1007        if (rt->rt_dev)                                 /* RTA_OIF */
1008                payload += nla_total_size(4);
1009        return payload;
1010}
1011
1012static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
1013                       struct nlmsghdr *nlh, struct net *net, u32 portid,
1014                       unsigned int nlm_flags)
1015{
1016        struct sk_buff *skb;
1017        u32 seq = nlh ? nlh->nlmsg_seq : 0;
1018        int err = -ENOBUFS;
1019
1020        skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
1021        if (skb == NULL)
1022                goto errout;
1023
1024        err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
1025        if (err < 0) {
1026                /* -EMSGSIZE implies BUG in lfib_nlmsg_size */
1027                WARN_ON(err == -EMSGSIZE);
1028                kfree_skb(skb);
1029                goto errout;
1030        }
1031        rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
1032
1033        return;
1034errout:
1035        if (err < 0)
1036                rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
1037}
1038
1039static int resize_platform_label_table(struct net *net, size_t limit)
1040{
1041        size_t size = sizeof(struct mpls_route *) * limit;
1042        size_t old_limit;
1043        size_t cp_size;
1044        struct mpls_route __rcu **labels = NULL, **old;
1045        struct mpls_route *rt0 = NULL, *rt2 = NULL;
1046        unsigned index;
1047
1048        if (size) {
1049                labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
1050                if (!labels)
1051                        labels = vzalloc(size);
1052
1053                if (!labels)
1054                        goto nolabels;
1055        }
1056
1057        /* In case the predefined labels need to be populated */
1058        if (limit > MPLS_LABEL_IPV4NULL) {
1059                struct net_device *lo = net->loopback_dev;
1060                rt0 = mpls_rt_alloc(lo->addr_len);
1061                if (!rt0)
1062                        goto nort0;
1063                RCU_INIT_POINTER(rt0->rt_dev, lo);
1064                rt0->rt_protocol = RTPROT_KERNEL;
1065                rt0->rt_payload_type = MPT_IPV4;
1066                rt0->rt_via_table = NEIGH_LINK_TABLE;
1067                memcpy(rt0->rt_via, lo->dev_addr, lo->addr_len);
1068        }
1069        if (limit > MPLS_LABEL_IPV6NULL) {
1070                struct net_device *lo = net->loopback_dev;
1071                rt2 = mpls_rt_alloc(lo->addr_len);
1072                if (!rt2)
1073                        goto nort2;
1074                RCU_INIT_POINTER(rt2->rt_dev, lo);
1075                rt2->rt_protocol = RTPROT_KERNEL;
1076                rt2->rt_payload_type = MPT_IPV6;
1077                rt2->rt_via_table = NEIGH_LINK_TABLE;
1078                memcpy(rt2->rt_via, lo->dev_addr, lo->addr_len);
1079        }
1080
1081        rtnl_lock();
1082        /* Remember the original table */
1083        old = rtnl_dereference(net->mpls.platform_label);
1084        old_limit = net->mpls.platform_labels;
1085
1086        /* Free any labels beyond the new table */
1087        for (index = limit; index < old_limit; index++)
1088                mpls_route_update(net, index, NULL, NULL, NULL);
1089
1090        /* Copy over the old labels */
1091        cp_size = size;
1092        if (old_limit < limit)
1093                cp_size = old_limit * sizeof(struct mpls_route *);
1094
1095        memcpy(labels, old, cp_size);
1096
1097        /* If needed set the predefined labels */
1098        if ((old_limit <= MPLS_LABEL_IPV6NULL) &&
1099            (limit > MPLS_LABEL_IPV6NULL)) {
1100                RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2);
1101                rt2 = NULL;
1102        }
1103
1104        if ((old_limit <= MPLS_LABEL_IPV4NULL) &&
1105            (limit > MPLS_LABEL_IPV4NULL)) {
1106                RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0);
1107                rt0 = NULL;
1108        }
1109
1110        /* Update the global pointers */
1111        net->mpls.platform_labels = limit;
1112        rcu_assign_pointer(net->mpls.platform_label, labels);
1113
1114        rtnl_unlock();
1115
1116        mpls_rt_free(rt2);
1117        mpls_rt_free(rt0);
1118
1119        if (old) {
1120                synchronize_rcu();
1121                kvfree(old);
1122        }
1123        return 0;
1124
1125nort2:
1126        mpls_rt_free(rt0);
1127nort0:
1128        kvfree(labels);
1129nolabels:
1130        return -ENOMEM;
1131}
1132
1133static int mpls_platform_labels(struct ctl_table *table, int write,
1134                                void __user *buffer, size_t *lenp, loff_t *ppos)
1135{
1136        struct net *net = table->data;
1137        int platform_labels = net->mpls.platform_labels;
1138        int ret;
1139        struct ctl_table tmp = {
1140                .procname       = table->procname,
1141                .data           = &platform_labels,
1142                .maxlen         = sizeof(int),
1143                .mode           = table->mode,
1144                .extra1         = &zero,
1145                .extra2         = &label_limit,
1146        };
1147
1148        ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
1149
1150        if (write && ret == 0)
1151                ret = resize_platform_label_table(net, platform_labels);
1152
1153        return ret;
1154}
1155
1156static const struct ctl_table mpls_table[] = {
1157        {
1158                .procname       = "platform_labels",
1159                .data           = NULL,
1160                .maxlen         = sizeof(int),
1161                .mode           = 0644,
1162                .proc_handler   = mpls_platform_labels,
1163        },
1164        { }
1165};
1166
1167static int mpls_net_init(struct net *net)
1168{
1169        struct ctl_table *table;
1170
1171        net->mpls.platform_labels = 0;
1172        net->mpls.platform_label = NULL;
1173
1174        table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
1175        if (table == NULL)
1176                return -ENOMEM;
1177
1178        table[0].data = net;
1179        net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
1180        if (net->mpls.ctl == NULL) {
1181                kfree(table);
1182                return -ENOMEM;
1183        }
1184
1185        return 0;
1186}
1187
1188static void mpls_net_exit(struct net *net)
1189{
1190        struct mpls_route __rcu **platform_label;
1191        size_t platform_labels;
1192        struct ctl_table *table;
1193        unsigned int index;
1194
1195        table = net->mpls.ctl->ctl_table_arg;
1196        unregister_net_sysctl_table(net->mpls.ctl);
1197        kfree(table);
1198
1199        /* An rcu grace period has passed since there was a device in
1200         * the network namespace (and thus the last in flight packet)
1201         * left this network namespace.  This is because
1202         * unregister_netdevice_many and netdev_run_todo has completed
1203         * for each network device that was in this network namespace.
1204         *
1205         * As such no additional rcu synchronization is necessary when
1206         * freeing the platform_label table.
1207         */
1208        rtnl_lock();
1209        platform_label = rtnl_dereference(net->mpls.platform_label);
1210        platform_labels = net->mpls.platform_labels;
1211        for (index = 0; index < platform_labels; index++) {
1212                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1213                RCU_INIT_POINTER(platform_label[index], NULL);
1214                mpls_rt_free(rt);
1215        }
1216        rtnl_unlock();
1217
1218        kvfree(platform_label);
1219}
1220
1221static struct pernet_operations mpls_net_ops = {
1222        .init = mpls_net_init,
1223        .exit = mpls_net_exit,
1224};
1225
1226static int __init mpls_init(void)
1227{
1228        int err;
1229
1230        BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
1231
1232        err = register_pernet_subsys(&mpls_net_ops);
1233        if (err)
1234                goto out;
1235
1236        err = register_netdevice_notifier(&mpls_dev_notifier);
1237        if (err)
1238                goto out_unregister_pernet;
1239
1240        dev_add_pack(&mpls_packet_type);
1241
1242        rtnl_register(PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, NULL);
1243        rtnl_register(PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, NULL);
1244        rtnl_register(PF_MPLS, RTM_GETROUTE, NULL, mpls_dump_routes, NULL);
1245        err = 0;
1246out:
1247        return err;
1248
1249out_unregister_pernet:
1250        unregister_pernet_subsys(&mpls_net_ops);
1251        goto out;
1252}
1253module_init(mpls_init);
1254
/*
 * Module unload: undo mpls_init() in exactly the reverse order.
 *   1. Drop the PF_MPLS rtnetlink handlers.
 *   2. Remove the packet handler.
 *   3. Unregister the netdevice notifier.
 *   4. Tear down the per-namespace state.
 */
static void __exit mpls_exit(void)
{
	rtnl_unregister_all(PF_MPLS);
	dev_remove_pack(&mpls_packet_type);
	unregister_netdevice_notifier(&mpls_dev_notifier);
	unregister_pernet_subsys(&mpls_net_ops);
}
module_exit(mpls_exit);
1263
/* Module metadata. MODULE_ALIAS_NETPROTO() adds a module alias for
 * protocol family PF_MPLS.
 */
MODULE_DESCRIPTION("MultiProtocol Label Switching");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_NETPROTO(PF_MPLS);
1267