linux/net/mpls/af_mpls.c
<<
>>
Prefs
   1#include <linux/types.h>
   2#include <linux/skbuff.h>
   3#include <linux/socket.h>
   4#include <linux/sysctl.h>
   5#include <linux/net.h>
   6#include <linux/module.h>
   7#include <linux/if_arp.h>
   8#include <linux/ipv6.h>
   9#include <linux/mpls.h>
  10#include <linux/vmalloc.h>
  11#include <net/ip.h>
  12#include <net/dst.h>
  13#include <net/sock.h>
  14#include <net/arp.h>
  15#include <net/ip_fib.h>
  16#include <net/netevent.h>
  17#include <net/netns/generic.h>
  18#if IS_ENABLED(CONFIG_IPV6)
  19#include <net/ipv6.h>
  20#include <net/addrconf.h>
  21#endif
  22#include <net/nexthop.h>
  23#include "internal.h"
  24
  25/* Maximum number of labels to look ahead at when selecting a path of
  26 * a multipath route
  27 */
  28#define MAX_MP_SELECT_LABELS 4
  29
  30#define MPLS_NEIGH_TABLE_UNSPEC (NEIGH_LINK_TABLE + 1)
  31
  32static int zero = 0;
  33static int label_limit = (1 << 20) - 1;
  34
  35static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
  36                       struct nlmsghdr *nlh, struct net *net, u32 portid,
  37                       unsigned int nlm_flags);
  38
  39static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
  40{
  41        struct mpls_route *rt = NULL;
  42
  43        if (index < net->mpls.platform_labels) {
  44                struct mpls_route __rcu **platform_label =
  45                        rcu_dereference(net->mpls.platform_label);
  46                rt = rcu_dereference(platform_label[index]);
  47        }
  48        return rt;
  49}
  50
  51static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev)
  52{
  53        return rcu_dereference_rtnl(dev->mpls_ptr);
  54}
  55
  56bool mpls_output_possible(const struct net_device *dev)
  57{
  58        return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev);
  59}
  60EXPORT_SYMBOL_GPL(mpls_output_possible);
  61
  62static u8 *__mpls_nh_via(struct mpls_route *rt, struct mpls_nh *nh)
  63{
  64        u8 *nh0_via = PTR_ALIGN((u8 *)&rt->rt_nh[rt->rt_nhn], VIA_ALEN_ALIGN);
  65        int nh_index = nh - rt->rt_nh;
  66
  67        return nh0_via + rt->rt_max_alen * nh_index;
  68}
  69
  70static const u8 *mpls_nh_via(const struct mpls_route *rt,
  71                             const struct mpls_nh *nh)
  72{
  73        return __mpls_nh_via((struct mpls_route *)rt, (struct mpls_nh *)nh);
  74}
  75
  76static unsigned int mpls_nh_header_size(const struct mpls_nh *nh)
  77{
  78        /* The size of the layer 2.5 labels to be added for this route */
  79        return nh->nh_labels * sizeof(struct mpls_shim_hdr);
  80}
  81
  82unsigned int mpls_dev_mtu(const struct net_device *dev)
  83{
  84        /* The amount of data the layer 2 frame can hold */
  85        return dev->mtu;
  86}
  87EXPORT_SYMBOL_GPL(mpls_dev_mtu);
  88
  89bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
  90{
  91        if (skb->len <= mtu)
  92                return false;
  93
  94        if (skb_is_gso(skb) && skb_gso_validate_mtu(skb, mtu))
  95                return false;
  96
  97        return true;
  98}
  99EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
 100
 101static u32 mpls_multipath_hash(struct mpls_route *rt,
 102                               struct sk_buff *skb, bool bos)
 103{
 104        struct mpls_entry_decoded dec;
 105        struct mpls_shim_hdr *hdr;
 106        bool eli_seen = false;
 107        int label_index;
 108        u32 hash = 0;
 109
 110        for (label_index = 0; label_index < MAX_MP_SELECT_LABELS && !bos;
 111             label_index++) {
 112                if (!pskb_may_pull(skb, sizeof(*hdr) * label_index))
 113                        break;
 114
 115                /* Read and decode the current label */
 116                hdr = mpls_hdr(skb) + label_index;
 117                dec = mpls_entry_decode(hdr);
 118
 119                /* RFC6790 - reserved labels MUST NOT be used as keys
 120                 * for the load-balancing function
 121                 */
 122                if (likely(dec.label >= MPLS_LABEL_FIRST_UNRESERVED)) {
 123                        hash = jhash_1word(dec.label, hash);
 124
 125                        /* The entropy label follows the entropy label
 126                         * indicator, so this means that the entropy
 127                         * label was just added to the hash - no need to
 128                         * go any deeper either in the label stack or in the
 129                         * payload
 130                         */
 131                        if (eli_seen)
 132                                break;
 133                } else if (dec.label == MPLS_LABEL_ENTROPY) {
 134                        eli_seen = true;
 135                }
 136
 137                bos = dec.bos;
 138                if (bos && pskb_may_pull(skb, sizeof(*hdr) * label_index +
 139                                         sizeof(struct iphdr))) {
 140                        const struct iphdr *v4hdr;
 141
 142                        v4hdr = (const struct iphdr *)(mpls_hdr(skb) +
 143                                                       label_index);
 144                        if (v4hdr->version == 4) {
 145                                hash = jhash_3words(ntohl(v4hdr->saddr),
 146                                                    ntohl(v4hdr->daddr),
 147                                                    v4hdr->protocol, hash);
 148                        } else if (v4hdr->version == 6 &&
 149                                pskb_may_pull(skb, sizeof(*hdr) * label_index +
 150                                              sizeof(struct ipv6hdr))) {
 151                                const struct ipv6hdr *v6hdr;
 152
 153                                v6hdr = (const struct ipv6hdr *)(mpls_hdr(skb) +
 154                                                                label_index);
 155
 156                                hash = __ipv6_addr_jhash(&v6hdr->saddr, hash);
 157                                hash = __ipv6_addr_jhash(&v6hdr->daddr, hash);
 158                                hash = jhash_1word(v6hdr->nexthdr, hash);
 159                        }
 160                }
 161        }
 162
 163        return hash;
 164}
 165
 166static struct mpls_nh *mpls_select_multipath(struct mpls_route *rt,
 167                                             struct sk_buff *skb, bool bos)
 168{
 169        int alive = ACCESS_ONCE(rt->rt_nhn_alive);
 170        u32 hash = 0;
 171        int nh_index = 0;
 172        int n = 0;
 173
 174        /* No need to look further into packet if there's only
 175         * one path
 176         */
 177        if (rt->rt_nhn == 1)
 178                goto out;
 179
 180        if (alive <= 0)
 181                return NULL;
 182
 183        hash = mpls_multipath_hash(rt, skb, bos);
 184        nh_index = hash % alive;
 185        if (alive == rt->rt_nhn)
 186                goto out;
 187        for_nexthops(rt) {
 188                if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
 189                        continue;
 190                if (n == nh_index)
 191                        return nh;
 192                n++;
 193        } endfor_nexthops(rt);
 194
 195out:
 196        return &rt->rt_nh[nh_index];
 197}
 198
 199static bool mpls_egress(struct mpls_route *rt, struct sk_buff *skb,
 200                        struct mpls_entry_decoded dec)
 201{
 202        enum mpls_payload_type payload_type;
 203        bool success = false;
 204
 205        /* The IPv4 code below accesses through the IPv4 header
 206         * checksum, which is 12 bytes into the packet.
 207         * The IPv6 code below accesses through the IPv6 hop limit
 208         * which is 8 bytes into the packet.
 209         *
 210         * For all supported cases there should always be at least 12
 211         * bytes of packet data present.  The IPv4 header is 20 bytes
 212         * without options and the IPv6 header is always 40 bytes
 213         * long.
 214         */
 215        if (!pskb_may_pull(skb, 12))
 216                return false;
 217
 218        payload_type = rt->rt_payload_type;
 219        if (payload_type == MPT_UNSPEC)
 220                payload_type = ip_hdr(skb)->version;
 221
 222        switch (payload_type) {
 223        case MPT_IPV4: {
 224                struct iphdr *hdr4 = ip_hdr(skb);
 225                skb->protocol = htons(ETH_P_IP);
 226                csum_replace2(&hdr4->check,
 227                              htons(hdr4->ttl << 8),
 228                              htons(dec.ttl << 8));
 229                hdr4->ttl = dec.ttl;
 230                success = true;
 231                break;
 232        }
 233        case MPT_IPV6: {
 234                struct ipv6hdr *hdr6 = ipv6_hdr(skb);
 235                skb->protocol = htons(ETH_P_IPV6);
 236                hdr6->hop_limit = dec.ttl;
 237                success = true;
 238                break;
 239        }
 240        case MPT_UNSPEC:
 241                break;
 242        }
 243
 244        return success;
 245}
 246
 247static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
 248                        struct packet_type *pt, struct net_device *orig_dev)
 249{
 250        struct net *net = dev_net(dev);
 251        struct mpls_shim_hdr *hdr;
 252        struct mpls_route *rt;
 253        struct mpls_nh *nh;
 254        struct mpls_entry_decoded dec;
 255        struct net_device *out_dev;
 256        struct mpls_dev *mdev;
 257        unsigned int hh_len;
 258        unsigned int new_header_size;
 259        unsigned int mtu;
 260        int err;
 261
 262        /* Careful this entire function runs inside of an rcu critical section */
 263
 264        mdev = mpls_dev_get(dev);
 265        if (!mdev || !mdev->input_enabled)
 266                goto drop;
 267
 268        if (skb->pkt_type != PACKET_HOST)
 269                goto drop;
 270
 271        if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL)
 272                goto drop;
 273
 274        if (!pskb_may_pull(skb, sizeof(*hdr)))
 275                goto drop;
 276
 277        /* Read and decode the label */
 278        hdr = mpls_hdr(skb);
 279        dec = mpls_entry_decode(hdr);
 280
 281        /* Pop the label */
 282        skb_pull(skb, sizeof(*hdr));
 283        skb_reset_network_header(skb);
 284
 285        skb_orphan(skb);
 286
 287        rt = mpls_route_input_rcu(net, dec.label);
 288        if (!rt)
 289                goto drop;
 290
 291        nh = mpls_select_multipath(rt, skb, dec.bos);
 292        if (!nh)
 293                goto drop;
 294
 295        /* Find the output device */
 296        out_dev = rcu_dereference(nh->nh_dev);
 297        if (!mpls_output_possible(out_dev))
 298                goto drop;
 299
 300        if (skb_warn_if_lro(skb))
 301                goto drop;
 302
 303        skb_forward_csum(skb);
 304
 305        /* Verify ttl is valid */
 306        if (dec.ttl <= 1)
 307                goto drop;
 308        dec.ttl -= 1;
 309
 310        /* Verify the destination can hold the packet */
 311        new_header_size = mpls_nh_header_size(nh);
 312        mtu = mpls_dev_mtu(out_dev);
 313        if (mpls_pkt_too_big(skb, mtu - new_header_size))
 314                goto drop;
 315
 316        hh_len = LL_RESERVED_SPACE(out_dev);
 317        if (!out_dev->header_ops)
 318                hh_len = 0;
 319
 320        /* Ensure there is enough space for the headers in the skb */
 321        if (skb_cow(skb, hh_len + new_header_size))
 322                goto drop;
 323
 324        skb->dev = out_dev;
 325        skb->protocol = htons(ETH_P_MPLS_UC);
 326
 327        if (unlikely(!new_header_size && dec.bos)) {
 328                /* Penultimate hop popping */
 329                if (!mpls_egress(rt, skb, dec))
 330                        goto drop;
 331        } else {
 332                bool bos;
 333                int i;
 334                skb_push(skb, new_header_size);
 335                skb_reset_network_header(skb);
 336                /* Push the new labels */
 337                hdr = mpls_hdr(skb);
 338                bos = dec.bos;
 339                for (i = nh->nh_labels - 1; i >= 0; i--) {
 340                        hdr[i] = mpls_entry_encode(nh->nh_label[i],
 341                                                   dec.ttl, 0, bos);
 342                        bos = false;
 343                }
 344        }
 345
 346        /* If via wasn't specified then send out using device address */
 347        if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
 348                err = neigh_xmit(NEIGH_LINK_TABLE, out_dev,
 349                                 out_dev->dev_addr, skb);
 350        else
 351                err = neigh_xmit(nh->nh_via_table, out_dev,
 352                                 mpls_nh_via(rt, nh), skb);
 353        if (err)
 354                net_dbg_ratelimited("%s: packet transmission failed: %d\n",
 355                                    __func__, err);
 356        return 0;
 357
 358drop:
 359        kfree_skb(skb);
 360        return NET_RX_DROP;
 361}
 362
 363static struct packet_type mpls_packet_type __read_mostly = {
 364        .type = cpu_to_be16(ETH_P_MPLS_UC),
 365        .func = mpls_forward,
 366};
 367
 368static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = {
 369        [RTA_DST]               = { .type = NLA_U32 },
 370        [RTA_OIF]               = { .type = NLA_U32 },
 371};
 372
 373struct mpls_route_config {
 374        u32                     rc_protocol;
 375        u32                     rc_ifindex;
 376        u8                      rc_via_table;
 377        u8                      rc_via_alen;
 378        u8                      rc_via[MAX_VIA_ALEN];
 379        u32                     rc_label;
 380        u8                      rc_output_labels;
 381        u32                     rc_output_label[MAX_NEW_LABELS];
 382        u32                     rc_nlflags;
 383        enum mpls_payload_type  rc_payload_type;
 384        struct nl_info          rc_nlinfo;
 385        struct rtnexthop        *rc_mp;
 386        int                     rc_mp_len;
 387};
 388
 389static struct mpls_route *mpls_rt_alloc(int num_nh, u8 max_alen)
 390{
 391        u8 max_alen_aligned = ALIGN(max_alen, VIA_ALEN_ALIGN);
 392        struct mpls_route *rt;
 393
 394        rt = kzalloc(ALIGN(sizeof(*rt) + num_nh * sizeof(*rt->rt_nh),
 395                           VIA_ALEN_ALIGN) +
 396                     num_nh * max_alen_aligned,
 397                     GFP_KERNEL);
 398        if (rt) {
 399                rt->rt_nhn = num_nh;
 400                rt->rt_nhn_alive = num_nh;
 401                rt->rt_max_alen = max_alen_aligned;
 402        }
 403
 404        return rt;
 405}
 406
 407static void mpls_rt_free(struct mpls_route *rt)
 408{
 409        if (rt)
 410                kfree_rcu(rt, rt_rcu);
 411}
 412
 413static void mpls_notify_route(struct net *net, unsigned index,
 414                              struct mpls_route *old, struct mpls_route *new,
 415                              const struct nl_info *info)
 416{
 417        struct nlmsghdr *nlh = info ? info->nlh : NULL;
 418        unsigned portid = info ? info->portid : 0;
 419        int event = new ? RTM_NEWROUTE : RTM_DELROUTE;
 420        struct mpls_route *rt = new ? new : old;
 421        unsigned nlm_flags = (old && new) ? NLM_F_REPLACE : 0;
 422        /* Ignore reserved labels for now */
 423        if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED))
 424                rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags);
 425}
 426
 427static void mpls_route_update(struct net *net, unsigned index,
 428                              struct mpls_route *new,
 429                              const struct nl_info *info)
 430{
 431        struct mpls_route __rcu **platform_label;
 432        struct mpls_route *rt;
 433
 434        ASSERT_RTNL();
 435
 436        platform_label = rtnl_dereference(net->mpls.platform_label);
 437        rt = rtnl_dereference(platform_label[index]);
 438        rcu_assign_pointer(platform_label[index], new);
 439
 440        mpls_notify_route(net, index, rt, new, info);
 441
 442        /* If we removed a route free it now */
 443        mpls_rt_free(rt);
 444}
 445
 446static unsigned find_free_label(struct net *net)
 447{
 448        struct mpls_route __rcu **platform_label;
 449        size_t platform_labels;
 450        unsigned index;
 451
 452        platform_label = rtnl_dereference(net->mpls.platform_label);
 453        platform_labels = net->mpls.platform_labels;
 454        for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
 455             index++) {
 456                if (!rtnl_dereference(platform_label[index]))
 457                        return index;
 458        }
 459        return LABEL_NOT_SPECIFIED;
 460}
 461
 462#if IS_ENABLED(CONFIG_INET)
 463static struct net_device *inet_fib_lookup_dev(struct net *net,
 464                                              const void *addr)
 465{
 466        struct net_device *dev;
 467        struct rtable *rt;
 468        struct in_addr daddr;
 469
 470        memcpy(&daddr, addr, sizeof(struct in_addr));
 471        rt = ip_route_output(net, daddr.s_addr, 0, 0, 0);
 472        if (IS_ERR(rt))
 473                return ERR_CAST(rt);
 474
 475        dev = rt->dst.dev;
 476        dev_hold(dev);
 477
 478        ip_rt_put(rt);
 479
 480        return dev;
 481}
 482#else
 483static struct net_device *inet_fib_lookup_dev(struct net *net,
 484                                              const void *addr)
 485{
 486        return ERR_PTR(-EAFNOSUPPORT);
 487}
 488#endif
 489
 490#if IS_ENABLED(CONFIG_IPV6)
 491static struct net_device *inet6_fib_lookup_dev(struct net *net,
 492                                               const void *addr)
 493{
 494        struct net_device *dev;
 495        struct dst_entry *dst;
 496        struct flowi6 fl6;
 497        int err;
 498
 499        if (!ipv6_stub)
 500                return ERR_PTR(-EAFNOSUPPORT);
 501
 502        memset(&fl6, 0, sizeof(fl6));
 503        memcpy(&fl6.daddr, addr, sizeof(struct in6_addr));
 504        err = ipv6_stub->ipv6_dst_lookup(net, NULL, &dst, &fl6);
 505        if (err)
 506                return ERR_PTR(err);
 507
 508        dev = dst->dev;
 509        dev_hold(dev);
 510        dst_release(dst);
 511
 512        return dev;
 513}
 514#else
 515static struct net_device *inet6_fib_lookup_dev(struct net *net,
 516                                               const void *addr)
 517{
 518        return ERR_PTR(-EAFNOSUPPORT);
 519}
 520#endif
 521
 522static struct net_device *find_outdev(struct net *net,
 523                                      struct mpls_route *rt,
 524                                      struct mpls_nh *nh, int oif)
 525{
 526        struct net_device *dev = NULL;
 527
 528        if (!oif) {
 529                switch (nh->nh_via_table) {
 530                case NEIGH_ARP_TABLE:
 531                        dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh));
 532                        break;
 533                case NEIGH_ND_TABLE:
 534                        dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh));
 535                        break;
 536                case NEIGH_LINK_TABLE:
 537                        break;
 538                }
 539        } else {
 540                dev = dev_get_by_index(net, oif);
 541        }
 542
 543        if (!dev)
 544                return ERR_PTR(-ENODEV);
 545
 546        if (IS_ERR(dev))
 547                return dev;
 548
 549        /* The caller is holding rtnl anyways, so release the dev reference */
 550        dev_put(dev);
 551
 552        return dev;
 553}
 554
 555static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
 556                              struct mpls_nh *nh, int oif)
 557{
 558        struct net_device *dev = NULL;
 559        int err = -ENODEV;
 560
 561        dev = find_outdev(net, rt, nh, oif);
 562        if (IS_ERR(dev)) {
 563                err = PTR_ERR(dev);
 564                dev = NULL;
 565                goto errout;
 566        }
 567
 568        /* Ensure this is a supported device */
 569        err = -EINVAL;
 570        if (!mpls_dev_get(dev))
 571                goto errout;
 572
 573        if ((nh->nh_via_table == NEIGH_LINK_TABLE) &&
 574            (dev->addr_len != nh->nh_via_alen))
 575                goto errout;
 576
 577        RCU_INIT_POINTER(nh->nh_dev, dev);
 578
 579        if (!(dev->flags & IFF_UP)) {
 580                nh->nh_flags |= RTNH_F_DEAD;
 581        } else {
 582                unsigned int flags;
 583
 584                flags = dev_get_flags(dev);
 585                if (!(flags & (IFF_RUNNING | IFF_LOWER_UP)))
 586                        nh->nh_flags |= RTNH_F_LINKDOWN;
 587        }
 588
 589        return 0;
 590
 591errout:
 592        return err;
 593}
 594
 595static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg,
 596                                  struct mpls_route *rt)
 597{
 598        struct net *net = cfg->rc_nlinfo.nl_net;
 599        struct mpls_nh *nh = rt->rt_nh;
 600        int err;
 601        int i;
 602
 603        if (!nh)
 604                return -ENOMEM;
 605
 606        err = -EINVAL;
 607        /* Ensure only a supported number of labels are present */
 608        if (cfg->rc_output_labels > MAX_NEW_LABELS)
 609                goto errout;
 610
 611        nh->nh_labels = cfg->rc_output_labels;
 612        for (i = 0; i < nh->nh_labels; i++)
 613                nh->nh_label[i] = cfg->rc_output_label[i];
 614
 615        nh->nh_via_table = cfg->rc_via_table;
 616        memcpy(__mpls_nh_via(rt, nh), cfg->rc_via, cfg->rc_via_alen);
 617        nh->nh_via_alen = cfg->rc_via_alen;
 618
 619        err = mpls_nh_assign_dev(net, rt, nh, cfg->rc_ifindex);
 620        if (err)
 621                goto errout;
 622
 623        if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
 624                rt->rt_nhn_alive--;
 625
 626        return 0;
 627
 628errout:
 629        return err;
 630}
 631
 632static int mpls_nh_build(struct net *net, struct mpls_route *rt,
 633                         struct mpls_nh *nh, int oif, struct nlattr *via,
 634                         struct nlattr *newdst)
 635{
 636        int err = -ENOMEM;
 637
 638        if (!nh)
 639                goto errout;
 640
 641        if (newdst) {
 642                err = nla_get_labels(newdst, MAX_NEW_LABELS,
 643                                     &nh->nh_labels, nh->nh_label);
 644                if (err)
 645                        goto errout;
 646        }
 647
 648        if (via) {
 649                err = nla_get_via(via, &nh->nh_via_alen, &nh->nh_via_table,
 650                                  __mpls_nh_via(rt, nh));
 651                if (err)
 652                        goto errout;
 653        } else {
 654                nh->nh_via_table = MPLS_NEIGH_TABLE_UNSPEC;
 655        }
 656
 657        err = mpls_nh_assign_dev(net, rt, nh, oif);
 658        if (err)
 659                goto errout;
 660
 661        return 0;
 662
 663errout:
 664        return err;
 665}
 666
 667static int mpls_count_nexthops(struct rtnexthop *rtnh, int len,
 668                               u8 cfg_via_alen, u8 *max_via_alen)
 669{
 670        int nhs = 0;
 671        int remaining = len;
 672
 673        if (!rtnh) {
 674                *max_via_alen = cfg_via_alen;
 675                return 1;
 676        }
 677
 678        *max_via_alen = 0;
 679
 680        while (rtnh_ok(rtnh, remaining)) {
 681                struct nlattr *nla, *attrs = rtnh_attrs(rtnh);
 682                int attrlen;
 683
 684                attrlen = rtnh_attrlen(rtnh);
 685                nla = nla_find(attrs, attrlen, RTA_VIA);
 686                if (nla && nla_len(nla) >=
 687                    offsetof(struct rtvia, rtvia_addr)) {
 688                        int via_alen = nla_len(nla) -
 689                                offsetof(struct rtvia, rtvia_addr);
 690
 691                        if (via_alen <= MAX_VIA_ALEN)
 692                                *max_via_alen = max_t(u16, *max_via_alen,
 693                                                      via_alen);
 694                }
 695
 696                nhs++;
 697                rtnh = rtnh_next(rtnh, &remaining);
 698        }
 699
 700        /* leftover implies invalid nexthop configuration, discard it */
 701        return remaining > 0 ? 0 : nhs;
 702}
 703
 704static int mpls_nh_build_multi(struct mpls_route_config *cfg,
 705                               struct mpls_route *rt)
 706{
 707        struct rtnexthop *rtnh = cfg->rc_mp;
 708        struct nlattr *nla_via, *nla_newdst;
 709        int remaining = cfg->rc_mp_len;
 710        int nhs = 0;
 711        int err = 0;
 712
 713        change_nexthops(rt) {
 714                int attrlen;
 715
 716                nla_via = NULL;
 717                nla_newdst = NULL;
 718
 719                err = -EINVAL;
 720                if (!rtnh_ok(rtnh, remaining))
 721                        goto errout;
 722
 723                /* neither weighted multipath nor any flags
 724                 * are supported
 725                 */
 726                if (rtnh->rtnh_hops || rtnh->rtnh_flags)
 727                        goto errout;
 728
 729                attrlen = rtnh_attrlen(rtnh);
 730                if (attrlen > 0) {
 731                        struct nlattr *attrs = rtnh_attrs(rtnh);
 732
 733                        nla_via = nla_find(attrs, attrlen, RTA_VIA);
 734                        nla_newdst = nla_find(attrs, attrlen, RTA_NEWDST);
 735                }
 736
 737                err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh,
 738                                    rtnh->rtnh_ifindex, nla_via, nla_newdst);
 739                if (err)
 740                        goto errout;
 741
 742                if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))
 743                        rt->rt_nhn_alive--;
 744
 745                rtnh = rtnh_next(rtnh, &remaining);
 746                nhs++;
 747        } endfor_nexthops(rt);
 748
 749        rt->rt_nhn = nhs;
 750
 751        return 0;
 752
 753errout:
 754        return err;
 755}
 756
 757static int mpls_route_add(struct mpls_route_config *cfg)
 758{
 759        struct mpls_route __rcu **platform_label;
 760        struct net *net = cfg->rc_nlinfo.nl_net;
 761        struct mpls_route *rt, *old;
 762        int err = -EINVAL;
 763        u8 max_via_alen;
 764        unsigned index;
 765        int nhs;
 766
 767        index = cfg->rc_label;
 768
 769        /* If a label was not specified during insert pick one */
 770        if ((index == LABEL_NOT_SPECIFIED) &&
 771            (cfg->rc_nlflags & NLM_F_CREATE)) {
 772                index = find_free_label(net);
 773        }
 774
 775        /* Reserved labels may not be set */
 776        if (index < MPLS_LABEL_FIRST_UNRESERVED)
 777                goto errout;
 778
 779        /* The full 20 bit range may not be supported. */
 780        if (index >= net->mpls.platform_labels)
 781                goto errout;
 782
 783        /* Append makes no sense with mpls */
 784        err = -EOPNOTSUPP;
 785        if (cfg->rc_nlflags & NLM_F_APPEND)
 786                goto errout;
 787
 788        err = -EEXIST;
 789        platform_label = rtnl_dereference(net->mpls.platform_label);
 790        old = rtnl_dereference(platform_label[index]);
 791        if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
 792                goto errout;
 793
 794        err = -EEXIST;
 795        if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old)
 796                goto errout;
 797
 798        err = -ENOENT;
 799        if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old)
 800                goto errout;
 801
 802        err = -EINVAL;
 803        nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len,
 804                                  cfg->rc_via_alen, &max_via_alen);
 805        if (nhs == 0)
 806                goto errout;
 807
 808        err = -ENOMEM;
 809        rt = mpls_rt_alloc(nhs, max_via_alen);
 810        if (!rt)
 811                goto errout;
 812
 813        rt->rt_protocol = cfg->rc_protocol;
 814        rt->rt_payload_type = cfg->rc_payload_type;
 815
 816        if (cfg->rc_mp)
 817                err = mpls_nh_build_multi(cfg, rt);
 818        else
 819                err = mpls_nh_build_from_cfg(cfg, rt);
 820        if (err)
 821                goto freert;
 822
 823        mpls_route_update(net, index, rt, &cfg->rc_nlinfo);
 824
 825        return 0;
 826
 827freert:
 828        mpls_rt_free(rt);
 829errout:
 830        return err;
 831}
 832
 833static int mpls_route_del(struct mpls_route_config *cfg)
 834{
 835        struct net *net = cfg->rc_nlinfo.nl_net;
 836        unsigned index;
 837        int err = -EINVAL;
 838
 839        index = cfg->rc_label;
 840
 841        /* Reserved labels may not be removed */
 842        if (index < MPLS_LABEL_FIRST_UNRESERVED)
 843                goto errout;
 844
 845        /* The full 20 bit range may not be supported */
 846        if (index >= net->mpls.platform_labels)
 847                goto errout;
 848
 849        mpls_route_update(net, index, NULL, &cfg->rc_nlinfo);
 850
 851        err = 0;
 852errout:
 853        return err;
 854}
 855
 856#define MPLS_PERDEV_SYSCTL_OFFSET(field)        \
 857        (&((struct mpls_dev *)0)->field)
 858
 859static const struct ctl_table mpls_dev_table[] = {
 860        {
 861                .procname       = "input",
 862                .maxlen         = sizeof(int),
 863                .mode           = 0644,
 864                .proc_handler   = proc_dointvec,
 865                .data           = MPLS_PERDEV_SYSCTL_OFFSET(input_enabled),
 866        },
 867        { }
 868};
 869
 870static int mpls_dev_sysctl_register(struct net_device *dev,
 871                                    struct mpls_dev *mdev)
 872{
 873        char path[sizeof("net/mpls/conf/") + IFNAMSIZ];
 874        struct ctl_table *table;
 875        int i;
 876
 877        table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL);
 878        if (!table)
 879                goto out;
 880
 881        /* Table data contains only offsets relative to the base of
 882         * the mdev at this point, so make them absolute.
 883         */
 884        for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++)
 885                table[i].data = (char *)mdev + (uintptr_t)table[i].data;
 886
 887        snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name);
 888
 889        mdev->sysctl = register_net_sysctl(dev_net(dev), path, table);
 890        if (!mdev->sysctl)
 891                goto free;
 892
 893        return 0;
 894
 895free:
 896        kfree(table);
 897out:
 898        return -ENOBUFS;
 899}
 900
 901static void mpls_dev_sysctl_unregister(struct mpls_dev *mdev)
 902{
 903        struct ctl_table *table;
 904
 905        table = mdev->sysctl->ctl_table_arg;
 906        unregister_net_sysctl_table(mdev->sysctl);
 907        kfree(table);
 908}
 909
 910static struct mpls_dev *mpls_add_dev(struct net_device *dev)
 911{
 912        struct mpls_dev *mdev;
 913        int err = -ENOMEM;
 914
 915        ASSERT_RTNL();
 916
 917        mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
 918        if (!mdev)
 919                return ERR_PTR(err);
 920
 921        err = mpls_dev_sysctl_register(dev, mdev);
 922        if (err)
 923                goto free;
 924
 925        rcu_assign_pointer(dev->mpls_ptr, mdev);
 926
 927        return mdev;
 928
 929free:
 930        kfree(mdev);
 931        return ERR_PTR(err);
 932}
 933
 934static void mpls_ifdown(struct net_device *dev, int event)
 935{
 936        struct mpls_route __rcu **platform_label;
 937        struct net *net = dev_net(dev);
 938        unsigned index;
 939
 940        platform_label = rtnl_dereference(net->mpls.platform_label);
 941        for (index = 0; index < net->mpls.platform_labels; index++) {
 942                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
 943
 944                if (!rt)
 945                        continue;
 946
 947                change_nexthops(rt) {
 948                        if (rtnl_dereference(nh->nh_dev) != dev)
 949                                continue;
 950                        switch (event) {
 951                        case NETDEV_DOWN:
 952                        case NETDEV_UNREGISTER:
 953                                nh->nh_flags |= RTNH_F_DEAD;
 954                                /* fall through */
 955                        case NETDEV_CHANGE:
 956                                nh->nh_flags |= RTNH_F_LINKDOWN;
 957                                ACCESS_ONCE(rt->rt_nhn_alive) = rt->rt_nhn_alive - 1;
 958                                break;
 959                        }
 960                        if (event == NETDEV_UNREGISTER)
 961                                RCU_INIT_POINTER(nh->nh_dev, NULL);
 962                } endfor_nexthops(rt);
 963        }
 964}
 965
 966static void mpls_ifup(struct net_device *dev, unsigned int nh_flags)
 967{
 968        struct mpls_route __rcu **platform_label;
 969        struct net *net = dev_net(dev);
 970        unsigned index;
 971        int alive;
 972
 973        platform_label = rtnl_dereference(net->mpls.platform_label);
 974        for (index = 0; index < net->mpls.platform_labels; index++) {
 975                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
 976
 977                if (!rt)
 978                        continue;
 979
 980                alive = 0;
 981                change_nexthops(rt) {
 982                        struct net_device *nh_dev =
 983                                rtnl_dereference(nh->nh_dev);
 984
 985                        if (!(nh->nh_flags & nh_flags)) {
 986                                alive++;
 987                                continue;
 988                        }
 989                        if (nh_dev != dev)
 990                                continue;
 991                        alive++;
 992                        nh->nh_flags &= ~nh_flags;
 993                } endfor_nexthops(rt);
 994
 995                ACCESS_ONCE(rt->rt_nhn_alive) = alive;
 996        }
 997}
 998
 999static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
1000                           void *ptr)
1001{
1002        struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1003        struct mpls_dev *mdev;
1004        unsigned int flags;
1005
1006        if (event == NETDEV_REGISTER) {
1007                /* For now just support Ethernet, IPGRE, SIT and IPIP devices */
1008                if (dev->type == ARPHRD_ETHER ||
1009                    dev->type == ARPHRD_LOOPBACK ||
1010                    dev->type == ARPHRD_IPGRE ||
1011                    dev->type == ARPHRD_SIT ||
1012                    dev->type == ARPHRD_TUNNEL) {
1013                        mdev = mpls_add_dev(dev);
1014                        if (IS_ERR(mdev))
1015                                return notifier_from_errno(PTR_ERR(mdev));
1016                }
1017                return NOTIFY_OK;
1018        }
1019
1020        mdev = mpls_dev_get(dev);
1021        if (!mdev)
1022                return NOTIFY_OK;
1023
1024        switch (event) {
1025        case NETDEV_DOWN:
1026                mpls_ifdown(dev, event);
1027                break;
1028        case NETDEV_UP:
1029                flags = dev_get_flags(dev);
1030                if (flags & (IFF_RUNNING | IFF_LOWER_UP))
1031                        mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN);
1032                else
1033                        mpls_ifup(dev, RTNH_F_DEAD);
1034                break;
1035        case NETDEV_CHANGE:
1036                flags = dev_get_flags(dev);
1037                if (flags & (IFF_RUNNING | IFF_LOWER_UP))
1038                        mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN);
1039                else
1040                        mpls_ifdown(dev, event);
1041                break;
1042        case NETDEV_UNREGISTER:
1043                mpls_ifdown(dev, event);
1044                mdev = mpls_dev_get(dev);
1045                if (mdev) {
1046                        mpls_dev_sysctl_unregister(mdev);
1047                        RCU_INIT_POINTER(dev->mpls_ptr, NULL);
1048                        kfree_rcu(mdev, rcu);
1049                }
1050                break;
1051        case NETDEV_CHANGENAME:
1052                mdev = mpls_dev_get(dev);
1053                if (mdev) {
1054                        int err;
1055
1056                        mpls_dev_sysctl_unregister(mdev);
1057                        err = mpls_dev_sysctl_register(dev, mdev);
1058                        if (err)
1059                                return notifier_from_errno(err);
1060                }
1061                break;
1062        }
1063        return NOTIFY_OK;
1064}
1065
1066static struct notifier_block mpls_dev_notifier = {
1067        .notifier_call = mpls_dev_notify,
1068};
1069
1070static int nla_put_via(struct sk_buff *skb,
1071                       u8 table, const void *addr, int alen)
1072{
1073        static const int table_to_family[NEIGH_NR_TABLES + 1] = {
1074                AF_INET, AF_INET6, AF_DECnet, AF_PACKET,
1075        };
1076        struct nlattr *nla;
1077        struct rtvia *via;
1078        int family = AF_UNSPEC;
1079
1080        nla = nla_reserve(skb, RTA_VIA, alen + 2);
1081        if (!nla)
1082                return -EMSGSIZE;
1083
1084        if (table <= NEIGH_NR_TABLES)
1085                family = table_to_family[table];
1086
1087        via = nla_data(nla);
1088        via->rtvia_family = family;
1089        memcpy(via->rtvia_addr, addr, alen);
1090        return 0;
1091}
1092
1093int nla_put_labels(struct sk_buff *skb, int attrtype,
1094                   u8 labels, const u32 label[])
1095{
1096        struct nlattr *nla;
1097        struct mpls_shim_hdr *nla_label;
1098        bool bos;
1099        int i;
1100        nla = nla_reserve(skb, attrtype, labels*4);
1101        if (!nla)
1102                return -EMSGSIZE;
1103
1104        nla_label = nla_data(nla);
1105        bos = true;
1106        for (i = labels - 1; i >= 0; i--) {
1107                nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos);
1108                bos = false;
1109        }
1110
1111        return 0;
1112}
1113EXPORT_SYMBOL_GPL(nla_put_labels);
1114
1115int nla_get_labels(const struct nlattr *nla,
1116                   u32 max_labels, u8 *labels, u32 label[])
1117{
1118        unsigned len = nla_len(nla);
1119        unsigned nla_labels;
1120        struct mpls_shim_hdr *nla_label;
1121        bool bos;
1122        int i;
1123
1124        /* len needs to be an even multiple of 4 (the label size) */
1125        if (len & 3)
1126                return -EINVAL;
1127
1128        /* Limit the number of new labels allowed */
1129        nla_labels = len/4;
1130        if (nla_labels > max_labels)
1131                return -EINVAL;
1132
1133        nla_label = nla_data(nla);
1134        bos = true;
1135        for (i = nla_labels - 1; i >= 0; i--, bos = false) {
1136                struct mpls_entry_decoded dec;
1137                dec = mpls_entry_decode(nla_label + i);
1138
1139                /* Ensure the bottom of stack flag is properly set
1140                 * and ttl and tc are both clear.
1141                 */
1142                if ((dec.bos != bos) || dec.ttl || dec.tc)
1143                        return -EINVAL;
1144
1145                switch (dec.label) {
1146                case MPLS_LABEL_IMPLNULL:
1147                        /* RFC3032: This is a label that an LSR may
1148                         * assign and distribute, but which never
1149                         * actually appears in the encapsulation.
1150                         */
1151                        return -EINVAL;
1152                }
1153
1154                label[i] = dec.label;
1155        }
1156        *labels = nla_labels;
1157        return 0;
1158}
1159EXPORT_SYMBOL_GPL(nla_get_labels);
1160
1161int nla_get_via(const struct nlattr *nla, u8 *via_alen,
1162                u8 *via_table, u8 via_addr[])
1163{
1164        struct rtvia *via = nla_data(nla);
1165        int err = -EINVAL;
1166        int alen;
1167
1168        if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr))
1169                goto errout;
1170        alen = nla_len(nla) -
1171                        offsetof(struct rtvia, rtvia_addr);
1172        if (alen > MAX_VIA_ALEN)
1173                goto errout;
1174
1175        /* Validate the address family */
1176        switch (via->rtvia_family) {
1177        case AF_PACKET:
1178                *via_table = NEIGH_LINK_TABLE;
1179                break;
1180        case AF_INET:
1181                *via_table = NEIGH_ARP_TABLE;
1182                if (alen != 4)
1183                        goto errout;
1184                break;
1185        case AF_INET6:
1186                *via_table = NEIGH_ND_TABLE;
1187                if (alen != 16)
1188                        goto errout;
1189                break;
1190        default:
1191                /* Unsupported address family */
1192                goto errout;
1193        }
1194
1195        memcpy(via_addr, via->rtvia_addr, alen);
1196        *via_alen = alen;
1197        err = 0;
1198
1199errout:
1200        return err;
1201}
1202
1203static int rtm_to_route_config(struct sk_buff *skb,  struct nlmsghdr *nlh,
1204                               struct mpls_route_config *cfg)
1205{
1206        struct rtmsg *rtm;
1207        struct nlattr *tb[RTA_MAX+1];
1208        int index;
1209        int err;
1210
1211        err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_mpls_policy);
1212        if (err < 0)
1213                goto errout;
1214
1215        err = -EINVAL;
1216        rtm = nlmsg_data(nlh);
1217        memset(cfg, 0, sizeof(*cfg));
1218
1219        if (rtm->rtm_family != AF_MPLS)
1220                goto errout;
1221        if (rtm->rtm_dst_len != 20)
1222                goto errout;
1223        if (rtm->rtm_src_len != 0)
1224                goto errout;
1225        if (rtm->rtm_tos != 0)
1226                goto errout;
1227        if (rtm->rtm_table != RT_TABLE_MAIN)
1228                goto errout;
1229        /* Any value is acceptable for rtm_protocol */
1230
1231        /* As mpls uses destination specific addresses
1232         * (or source specific address in the case of multicast)
1233         * all addresses have universal scope.
1234         */
1235        if (rtm->rtm_scope != RT_SCOPE_UNIVERSE)
1236                goto errout;
1237        if (rtm->rtm_type != RTN_UNICAST)
1238                goto errout;
1239        if (rtm->rtm_flags != 0)
1240                goto errout;
1241
1242        cfg->rc_label           = LABEL_NOT_SPECIFIED;
1243        cfg->rc_protocol        = rtm->rtm_protocol;
1244        cfg->rc_via_table       = MPLS_NEIGH_TABLE_UNSPEC;
1245        cfg->rc_nlflags         = nlh->nlmsg_flags;
1246        cfg->rc_nlinfo.portid   = NETLINK_CB(skb).portid;
1247        cfg->rc_nlinfo.nlh      = nlh;
1248        cfg->rc_nlinfo.nl_net   = sock_net(skb->sk);
1249
1250        for (index = 0; index <= RTA_MAX; index++) {
1251                struct nlattr *nla = tb[index];
1252                if (!nla)
1253                        continue;
1254
1255                switch (index) {
1256                case RTA_OIF:
1257                        cfg->rc_ifindex = nla_get_u32(nla);
1258                        break;
1259                case RTA_NEWDST:
1260                        if (nla_get_labels(nla, MAX_NEW_LABELS,
1261                                           &cfg->rc_output_labels,
1262                                           cfg->rc_output_label))
1263                                goto errout;
1264                        break;
1265                case RTA_DST:
1266                {
1267                        u8 label_count;
1268                        if (nla_get_labels(nla, 1, &label_count,
1269                                           &cfg->rc_label))
1270                                goto errout;
1271
1272                        /* Reserved labels may not be set */
1273                        if (cfg->rc_label < MPLS_LABEL_FIRST_UNRESERVED)
1274                                goto errout;
1275
1276                        break;
1277                }
1278                case RTA_VIA:
1279                {
1280                        if (nla_get_via(nla, &cfg->rc_via_alen,
1281                                        &cfg->rc_via_table, cfg->rc_via))
1282                                goto errout;
1283                        break;
1284                }
1285                case RTA_MULTIPATH:
1286                {
1287                        cfg->rc_mp = nla_data(nla);
1288                        cfg->rc_mp_len = nla_len(nla);
1289                        break;
1290                }
1291                default:
1292                        /* Unsupported attribute */
1293                        goto errout;
1294                }
1295        }
1296
1297        err = 0;
1298errout:
1299        return err;
1300}
1301
1302static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh)
1303{
1304        struct mpls_route_config cfg;
1305        int err;
1306
1307        err = rtm_to_route_config(skb, nlh, &cfg);
1308        if (err < 0)
1309                return err;
1310
1311        return mpls_route_del(&cfg);
1312}
1313
1314
1315static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh)
1316{
1317        struct mpls_route_config cfg;
1318        int err;
1319
1320        err = rtm_to_route_config(skb, nlh, &cfg);
1321        if (err < 0)
1322                return err;
1323
1324        return mpls_route_add(&cfg);
1325}
1326
1327static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event,
1328                           u32 label, struct mpls_route *rt, int flags)
1329{
1330        struct net_device *dev;
1331        struct nlmsghdr *nlh;
1332        struct rtmsg *rtm;
1333
1334        nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags);
1335        if (nlh == NULL)
1336                return -EMSGSIZE;
1337
1338        rtm = nlmsg_data(nlh);
1339        rtm->rtm_family = AF_MPLS;
1340        rtm->rtm_dst_len = 20;
1341        rtm->rtm_src_len = 0;
1342        rtm->rtm_tos = 0;
1343        rtm->rtm_table = RT_TABLE_MAIN;
1344        rtm->rtm_protocol = rt->rt_protocol;
1345        rtm->rtm_scope = RT_SCOPE_UNIVERSE;
1346        rtm->rtm_type = RTN_UNICAST;
1347        rtm->rtm_flags = 0;
1348
1349        if (nla_put_labels(skb, RTA_DST, 1, &label))
1350                goto nla_put_failure;
1351        if (rt->rt_nhn == 1) {
1352                const struct mpls_nh *nh = rt->rt_nh;
1353
1354                if (nh->nh_labels &&
1355                    nla_put_labels(skb, RTA_NEWDST, nh->nh_labels,
1356                                   nh->nh_label))
1357                        goto nla_put_failure;
1358                if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
1359                    nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh),
1360                                nh->nh_via_alen))
1361                        goto nla_put_failure;
1362                dev = rtnl_dereference(nh->nh_dev);
1363                if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex))
1364                        goto nla_put_failure;
1365                if (nh->nh_flags & RTNH_F_LINKDOWN)
1366                        rtm->rtm_flags |= RTNH_F_LINKDOWN;
1367                if (nh->nh_flags & RTNH_F_DEAD)
1368                        rtm->rtm_flags |= RTNH_F_DEAD;
1369        } else {
1370                struct rtnexthop *rtnh;
1371                struct nlattr *mp;
1372                int dead = 0;
1373                int linkdown = 0;
1374
1375                mp = nla_nest_start(skb, RTA_MULTIPATH);
1376                if (!mp)
1377                        goto nla_put_failure;
1378
1379                for_nexthops(rt) {
1380                        rtnh = nla_reserve_nohdr(skb, sizeof(*rtnh));
1381                        if (!rtnh)
1382                                goto nla_put_failure;
1383
1384                        dev = rtnl_dereference(nh->nh_dev);
1385                        if (dev)
1386                                rtnh->rtnh_ifindex = dev->ifindex;
1387                        if (nh->nh_flags & RTNH_F_LINKDOWN) {
1388                                rtnh->rtnh_flags |= RTNH_F_LINKDOWN;
1389                                linkdown++;
1390                        }
1391                        if (nh->nh_flags & RTNH_F_DEAD) {
1392                                rtnh->rtnh_flags |= RTNH_F_DEAD;
1393                                dead++;
1394                        }
1395
1396                        if (nh->nh_labels && nla_put_labels(skb, RTA_NEWDST,
1397                                                            nh->nh_labels,
1398                                                            nh->nh_label))
1399                                goto nla_put_failure;
1400                        if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC &&
1401                            nla_put_via(skb, nh->nh_via_table,
1402                                        mpls_nh_via(rt, nh),
1403                                        nh->nh_via_alen))
1404                                goto nla_put_failure;
1405
1406                        /* length of rtnetlink header + attributes */
1407                        rtnh->rtnh_len = nlmsg_get_pos(skb) - (void *)rtnh;
1408                } endfor_nexthops(rt);
1409
1410                if (linkdown == rt->rt_nhn)
1411                        rtm->rtm_flags |= RTNH_F_LINKDOWN;
1412                if (dead == rt->rt_nhn)
1413                        rtm->rtm_flags |= RTNH_F_DEAD;
1414
1415                nla_nest_end(skb, mp);
1416        }
1417
1418        nlmsg_end(skb, nlh);
1419        return 0;
1420
1421nla_put_failure:
1422        nlmsg_cancel(skb, nlh);
1423        return -EMSGSIZE;
1424}
1425
1426static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
1427{
1428        struct net *net = sock_net(skb->sk);
1429        struct mpls_route __rcu **platform_label;
1430        size_t platform_labels;
1431        unsigned int index;
1432
1433        ASSERT_RTNL();
1434
1435        index = cb->args[0];
1436        if (index < MPLS_LABEL_FIRST_UNRESERVED)
1437                index = MPLS_LABEL_FIRST_UNRESERVED;
1438
1439        platform_label = rtnl_dereference(net->mpls.platform_label);
1440        platform_labels = net->mpls.platform_labels;
1441        for (; index < platform_labels; index++) {
1442                struct mpls_route *rt;
1443                rt = rtnl_dereference(platform_label[index]);
1444                if (!rt)
1445                        continue;
1446
1447                if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid,
1448                                    cb->nlh->nlmsg_seq, RTM_NEWROUTE,
1449                                    index, rt, NLM_F_MULTI) < 0)
1450                        break;
1451        }
1452        cb->args[0] = index;
1453
1454        return skb->len;
1455}
1456
1457static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
1458{
1459        size_t payload =
1460                NLMSG_ALIGN(sizeof(struct rtmsg))
1461                + nla_total_size(4);                    /* RTA_DST */
1462
1463        if (rt->rt_nhn == 1) {
1464                struct mpls_nh *nh = rt->rt_nh;
1465
1466                if (nh->nh_dev)
1467                        payload += nla_total_size(4); /* RTA_OIF */
1468                if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC) /* RTA_VIA */
1469                        payload += nla_total_size(2 + nh->nh_via_alen);
1470                if (nh->nh_labels) /* RTA_NEWDST */
1471                        payload += nla_total_size(nh->nh_labels * 4);
1472        } else {
1473                /* each nexthop is packed in an attribute */
1474                size_t nhsize = 0;
1475
1476                for_nexthops(rt) {
1477                        nhsize += nla_total_size(sizeof(struct rtnexthop));
1478                        /* RTA_VIA */
1479                        if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC)
1480                                nhsize += nla_total_size(2 + nh->nh_via_alen);
1481                        if (nh->nh_labels)
1482                                nhsize += nla_total_size(nh->nh_labels * 4);
1483                } endfor_nexthops(rt);
1484                /* nested attribute */
1485                payload += nla_total_size(nhsize);
1486        }
1487
1488        return payload;
1489}
1490
1491static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
1492                       struct nlmsghdr *nlh, struct net *net, u32 portid,
1493                       unsigned int nlm_flags)
1494{
1495        struct sk_buff *skb;
1496        u32 seq = nlh ? nlh->nlmsg_seq : 0;
1497        int err = -ENOBUFS;
1498
1499        skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL);
1500        if (skb == NULL)
1501                goto errout;
1502
1503        err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags);
1504        if (err < 0) {
1505                /* -EMSGSIZE implies BUG in lfib_nlmsg_size */
1506                WARN_ON(err == -EMSGSIZE);
1507                kfree_skb(skb);
1508                goto errout;
1509        }
1510        rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL);
1511
1512        return;
1513errout:
1514        if (err < 0)
1515                rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err);
1516}
1517
1518static int resize_platform_label_table(struct net *net, size_t limit)
1519{
1520        size_t size = sizeof(struct mpls_route *) * limit;
1521        size_t old_limit;
1522        size_t cp_size;
1523        struct mpls_route __rcu **labels = NULL, **old;
1524        struct mpls_route *rt0 = NULL, *rt2 = NULL;
1525        unsigned index;
1526
1527        if (size) {
1528                labels = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
1529                if (!labels)
1530                        labels = vzalloc(size);
1531
1532                if (!labels)
1533                        goto nolabels;
1534        }
1535
1536        /* In case the predefined labels need to be populated */
1537        if (limit > MPLS_LABEL_IPV4NULL) {
1538                struct net_device *lo = net->loopback_dev;
1539                rt0 = mpls_rt_alloc(1, lo->addr_len);
1540                if (!rt0)
1541                        goto nort0;
1542                RCU_INIT_POINTER(rt0->rt_nh->nh_dev, lo);
1543                rt0->rt_protocol = RTPROT_KERNEL;
1544                rt0->rt_payload_type = MPT_IPV4;
1545                rt0->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
1546                rt0->rt_nh->nh_via_alen = lo->addr_len;
1547                memcpy(__mpls_nh_via(rt0, rt0->rt_nh), lo->dev_addr,
1548                       lo->addr_len);
1549        }
1550        if (limit > MPLS_LABEL_IPV6NULL) {
1551                struct net_device *lo = net->loopback_dev;
1552                rt2 = mpls_rt_alloc(1, lo->addr_len);
1553                if (!rt2)
1554                        goto nort2;
1555                RCU_INIT_POINTER(rt2->rt_nh->nh_dev, lo);
1556                rt2->rt_protocol = RTPROT_KERNEL;
1557                rt2->rt_payload_type = MPT_IPV6;
1558                rt2->rt_nh->nh_via_table = NEIGH_LINK_TABLE;
1559                rt2->rt_nh->nh_via_alen = lo->addr_len;
1560                memcpy(__mpls_nh_via(rt2, rt2->rt_nh), lo->dev_addr,
1561                       lo->addr_len);
1562        }
1563
1564        rtnl_lock();
1565        /* Remember the original table */
1566        old = rtnl_dereference(net->mpls.platform_label);
1567        old_limit = net->mpls.platform_labels;
1568
1569        /* Free any labels beyond the new table */
1570        for (index = limit; index < old_limit; index++)
1571                mpls_route_update(net, index, NULL, NULL);
1572
1573        /* Copy over the old labels */
1574        cp_size = size;
1575        if (old_limit < limit)
1576                cp_size = old_limit * sizeof(struct mpls_route *);
1577
1578        memcpy(labels, old, cp_size);
1579
1580        /* If needed set the predefined labels */
1581        if ((old_limit <= MPLS_LABEL_IPV6NULL) &&
1582            (limit > MPLS_LABEL_IPV6NULL)) {
1583                RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2);
1584                rt2 = NULL;
1585        }
1586
1587        if ((old_limit <= MPLS_LABEL_IPV4NULL) &&
1588            (limit > MPLS_LABEL_IPV4NULL)) {
1589                RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0);
1590                rt0 = NULL;
1591        }
1592
1593        /* Update the global pointers */
1594        net->mpls.platform_labels = limit;
1595        rcu_assign_pointer(net->mpls.platform_label, labels);
1596
1597        rtnl_unlock();
1598
1599        mpls_rt_free(rt2);
1600        mpls_rt_free(rt0);
1601
1602        if (old) {
1603                synchronize_rcu();
1604                kvfree(old);
1605        }
1606        return 0;
1607
1608nort2:
1609        mpls_rt_free(rt0);
1610nort0:
1611        kvfree(labels);
1612nolabels:
1613        return -ENOMEM;
1614}
1615
1616static int mpls_platform_labels(struct ctl_table *table, int write,
1617                                void __user *buffer, size_t *lenp, loff_t *ppos)
1618{
1619        struct net *net = table->data;
1620        int platform_labels = net->mpls.platform_labels;
1621        int ret;
1622        struct ctl_table tmp = {
1623                .procname       = table->procname,
1624                .data           = &platform_labels,
1625                .maxlen         = sizeof(int),
1626                .mode           = table->mode,
1627                .extra1         = &zero,
1628                .extra2         = &label_limit,
1629        };
1630
1631        ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos);
1632
1633        if (write && ret == 0)
1634                ret = resize_platform_label_table(net, platform_labels);
1635
1636        return ret;
1637}
1638
1639static const struct ctl_table mpls_table[] = {
1640        {
1641                .procname       = "platform_labels",
1642                .data           = NULL,
1643                .maxlen         = sizeof(int),
1644                .mode           = 0644,
1645                .proc_handler   = mpls_platform_labels,
1646        },
1647        { }
1648};
1649
1650static int mpls_net_init(struct net *net)
1651{
1652        struct ctl_table *table;
1653
1654        net->mpls.platform_labels = 0;
1655        net->mpls.platform_label = NULL;
1656
1657        table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL);
1658        if (table == NULL)
1659                return -ENOMEM;
1660
1661        table[0].data = net;
1662        net->mpls.ctl = register_net_sysctl(net, "net/mpls", table);
1663        if (net->mpls.ctl == NULL) {
1664                kfree(table);
1665                return -ENOMEM;
1666        }
1667
1668        return 0;
1669}
1670
1671static void mpls_net_exit(struct net *net)
1672{
1673        struct mpls_route __rcu **platform_label;
1674        size_t platform_labels;
1675        struct ctl_table *table;
1676        unsigned int index;
1677
1678        table = net->mpls.ctl->ctl_table_arg;
1679        unregister_net_sysctl_table(net->mpls.ctl);
1680        kfree(table);
1681
1682        /* An rcu grace period has passed since there was a device in
1683         * the network namespace (and thus the last in flight packet)
1684         * left this network namespace.  This is because
1685         * unregister_netdevice_many and netdev_run_todo has completed
1686         * for each network device that was in this network namespace.
1687         *
1688         * As such no additional rcu synchronization is necessary when
1689         * freeing the platform_label table.
1690         */
1691        rtnl_lock();
1692        platform_label = rtnl_dereference(net->mpls.platform_label);
1693        platform_labels = net->mpls.platform_labels;
1694        for (index = 0; index < platform_labels; index++) {
1695                struct mpls_route *rt = rtnl_dereference(platform_label[index]);
1696                RCU_INIT_POINTER(platform_label[index], NULL);
1697                mpls_rt_free(rt);
1698        }
1699        rtnl_unlock();
1700
1701        kvfree(platform_label);
1702}
1703
1704static struct pernet_operations mpls_net_ops = {
1705        .init = mpls_net_init,
1706        .exit = mpls_net_exit,
1707};
1708
1709static int __init mpls_init(void)
1710{
1711        int err;
1712
1713        BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4);
1714
1715        err = register_pernet_subsys(&mpls_net_ops);
1716        if (err)
1717                goto out;
1718
1719        err = register_netdevice_notifier(&mpls_dev_notifier);
1720        if (err)
1721                goto out_unregister_pernet;
1722
1723        dev_add_pack(&mpls_packet_type);
1724
1725        rtnl_register(PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, NULL);
1726        rtnl_register(PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, NULL);
1727        rtnl_register(PF_MPLS, RTM_GETROUTE, NULL, mpls_dump_routes, NULL);
1728        err = 0;
1729out:
1730        return err;
1731
1732out_unregister_pernet:
1733        unregister_pernet_subsys(&mpls_net_ops);
1734        goto out;
1735}
1736module_init(mpls_init);
1737
1738static void __exit mpls_exit(void)
1739{
1740        rtnl_unregister_all(PF_MPLS);
1741        dev_remove_pack(&mpls_packet_type);
1742        unregister_netdevice_notifier(&mpls_dev_notifier);
1743        unregister_pernet_subsys(&mpls_net_ops);
1744}
1745module_exit(mpls_exit);
1746
1747MODULE_DESCRIPTION("MultiProtocol Label Switching");
1748MODULE_LICENSE("GPL v2");
1749MODULE_ALIAS_NETPROTO(PF_MPLS);
1750