linux/net/mptcp/pm_netlink.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Multipath TCP
   3 *
   4 * Copyright (c) 2020, Red Hat, Inc.
   5 */
   6
   7#define pr_fmt(fmt) "MPTCP: " fmt
   8
   9#include <linux/inet.h>
  10#include <linux/kernel.h>
  11#include <net/tcp.h>
  12#include <net/netns/generic.h>
  13#include <net/mptcp.h>
  14#include <net/genetlink.h>
  15#include <uapi/linux/mptcp.h>
  16
  17#include "protocol.h"
  18
  19/* forward declaration */
  20static struct genl_family mptcp_genl_family;
  21
  22static int pm_nl_pernet_id;
  23
  24struct mptcp_pm_addr_entry {
  25        struct list_head        list;
  26        unsigned int            flags;
  27        int                     ifindex;
  28        struct mptcp_addr_info  addr;
  29        struct rcu_head         rcu;
  30};
  31
  32struct pm_nl_pernet {
  33        /* protects pernet updates */
  34        spinlock_t              lock;
  35        struct list_head        local_addr_list;
  36        unsigned int            addrs;
  37        unsigned int            add_addr_signal_max;
  38        unsigned int            add_addr_accept_max;
  39        unsigned int            local_addr_max;
  40        unsigned int            subflows_max;
  41        unsigned int            next_id;
  42};
  43
  44#define MPTCP_PM_ADDR_MAX       8
  45
  46static bool addresses_equal(const struct mptcp_addr_info *a,
  47                            struct mptcp_addr_info *b, bool use_port)
  48{
  49        bool addr_equals = false;
  50
  51        if (a->family != b->family)
  52                return false;
  53
  54        if (a->family == AF_INET)
  55                addr_equals = a->addr.s_addr == b->addr.s_addr;
  56#if IS_ENABLED(CONFIG_MPTCP_IPV6)
  57        else
  58                addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
  59#endif
  60
  61        if (!addr_equals)
  62                return false;
  63        if (!use_port)
  64                return true;
  65
  66        return a->port == b->port;
  67}
  68
  69static void local_address(const struct sock_common *skc,
  70                          struct mptcp_addr_info *addr)
  71{
  72        addr->port = 0;
  73        addr->family = skc->skc_family;
  74        if (addr->family == AF_INET)
  75                addr->addr.s_addr = skc->skc_rcv_saddr;
  76#if IS_ENABLED(CONFIG_MPTCP_IPV6)
  77        else if (addr->family == AF_INET6)
  78                addr->addr6 = skc->skc_v6_rcv_saddr;
  79#endif
  80}
  81
  82static void remote_address(const struct sock_common *skc,
  83                           struct mptcp_addr_info *addr)
  84{
  85        addr->family = skc->skc_family;
  86        addr->port = skc->skc_dport;
  87        if (addr->family == AF_INET)
  88                addr->addr.s_addr = skc->skc_daddr;
  89#if IS_ENABLED(CONFIG_MPTCP_IPV6)
  90        else if (addr->family == AF_INET6)
  91                addr->addr6 = skc->skc_v6_daddr;
  92#endif
  93}
  94
  95static bool lookup_subflow_by_saddr(const struct list_head *list,
  96                                    struct mptcp_addr_info *saddr)
  97{
  98        struct mptcp_subflow_context *subflow;
  99        struct mptcp_addr_info cur;
 100        struct sock_common *skc;
 101
 102        list_for_each_entry(subflow, list, node) {
 103                skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
 104
 105                local_address(skc, &cur);
 106                if (addresses_equal(&cur, saddr, false))
 107                        return true;
 108        }
 109
 110        return false;
 111}
 112
 113static struct mptcp_pm_addr_entry *
 114select_local_address(const struct pm_nl_pernet *pernet,
 115                     struct mptcp_sock *msk)
 116{
 117        struct mptcp_pm_addr_entry *entry, *ret = NULL;
 118
 119        rcu_read_lock();
 120        spin_lock_bh(&msk->join_list_lock);
 121        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 122                if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
 123                        continue;
 124
 125                /* avoid any address already in use by subflows and
 126                 * pending join
 127                 */
 128                if (entry->addr.family == ((struct sock *)msk)->sk_family &&
 129                    !lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
 130                    !lookup_subflow_by_saddr(&msk->join_list, &entry->addr)) {
 131                        ret = entry;
 132                        break;
 133                }
 134        }
 135        spin_unlock_bh(&msk->join_list_lock);
 136        rcu_read_unlock();
 137        return ret;
 138}
 139
 140static struct mptcp_pm_addr_entry *
 141select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
 142{
 143        struct mptcp_pm_addr_entry *entry, *ret = NULL;
 144        int i = 0;
 145
 146        rcu_read_lock();
 147        /* do not keep any additional per socket state, just signal
 148         * the address list in order.
 149         * Note: removal from the local address list during the msk life-cycle
 150         * can lead to additional addresses not being announced.
 151         */
 152        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 153                if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
 154                        continue;
 155                if (i++ == pos) {
 156                        ret = entry;
 157                        break;
 158                }
 159        }
 160        rcu_read_unlock();
 161        return ret;
 162}
 163
 164static void check_work_pending(struct mptcp_sock *msk)
 165{
 166        if (msk->pm.add_addr_signaled == msk->pm.add_addr_signal_max &&
 167            (msk->pm.local_addr_used == msk->pm.local_addr_max ||
 168             msk->pm.subflows == msk->pm.subflows_max))
 169                WRITE_ONCE(msk->pm.work_pending, false);
 170}
 171
 172static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 173{
 174        struct sock *sk = (struct sock *)msk;
 175        struct mptcp_pm_addr_entry *local;
 176        struct mptcp_addr_info remote;
 177        struct pm_nl_pernet *pernet;
 178
 179        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 180
 181        pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
 182                 msk->pm.local_addr_used, msk->pm.local_addr_max,
 183                 msk->pm.add_addr_signaled, msk->pm.add_addr_signal_max,
 184                 msk->pm.subflows, msk->pm.subflows_max);
 185
 186        /* check first for announce */
 187        if (msk->pm.add_addr_signaled < msk->pm.add_addr_signal_max) {
 188                local = select_signal_address(pernet,
 189                                              msk->pm.add_addr_signaled);
 190
 191                if (local) {
 192                        msk->pm.add_addr_signaled++;
 193                        mptcp_pm_announce_addr(msk, &local->addr);
 194                } else {
 195                        /* pick failed, avoid fourther attempts later */
 196                        msk->pm.local_addr_used = msk->pm.add_addr_signal_max;
 197                }
 198
 199                check_work_pending(msk);
 200        }
 201
 202        /* check if should create a new subflow */
 203        if (msk->pm.local_addr_used < msk->pm.local_addr_max &&
 204            msk->pm.subflows < msk->pm.subflows_max) {
 205                remote_address((struct sock_common *)sk, &remote);
 206
 207                local = select_local_address(pernet, msk);
 208                if (local) {
 209                        msk->pm.local_addr_used++;
 210                        msk->pm.subflows++;
 211                        check_work_pending(msk);
 212                        spin_unlock_bh(&msk->pm.lock);
 213                        __mptcp_subflow_connect(sk, local->ifindex,
 214                                                &local->addr, &remote);
 215                        spin_lock_bh(&msk->pm.lock);
 216                        return;
 217                }
 218
 219                /* lookup failed, avoid fourther attempts later */
 220                msk->pm.local_addr_used = msk->pm.local_addr_max;
 221                check_work_pending(msk);
 222        }
 223}
 224
 /*
 * PM event hook (name suggests: the msk became fully established; callers
 * live outside this file). Tries to announce an address and/or open a
 * subflow via mptcp_pm_create_subflow_or_signal_addr().
 */
void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}

/*
 * PM event hook (name suggests: a subflow finished its handshake; callers
 * live outside this file). Same action as the fully-established hook.
 */
void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
 234
 235void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
 236{
 237        struct sock *sk = (struct sock *)msk;
 238        struct mptcp_addr_info remote;
 239        struct mptcp_addr_info local;
 240
 241        pr_debug("accepted %d:%d remote family %d",
 242                 msk->pm.add_addr_accepted, msk->pm.add_addr_accept_max,
 243                 msk->pm.remote.family);
 244        msk->pm.add_addr_accepted++;
 245        msk->pm.subflows++;
 246        if (msk->pm.add_addr_accepted >= msk->pm.add_addr_accept_max ||
 247            msk->pm.subflows >= msk->pm.subflows_max)
 248                WRITE_ONCE(msk->pm.accept_addr, false);
 249
 250        /* connect to the specified remote address, using whatever
 251         * local address the routing configuration will pick.
 252         */
 253        remote = msk->pm.remote;
 254        if (!remote.port)
 255                remote.port = sk->sk_dport;
 256        memset(&local, 0, sizeof(local));
 257        local.family = remote.family;
 258
 259        spin_unlock_bh(&msk->pm.lock);
 260        __mptcp_subflow_connect((struct sock *)msk, 0, &local, &remote);
 261        spin_lock_bh(&msk->pm.lock);
 262}
 263
 264static bool address_use_port(struct mptcp_pm_addr_entry *entry)
 265{
 266        return (entry->flags &
 267                (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
 268                MPTCP_PM_ADDR_FLAG_SIGNAL;
 269}
 270
 271static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
 272                                             struct mptcp_pm_addr_entry *entry)
 273{
 274        struct mptcp_pm_addr_entry *cur;
 275        int ret = -EINVAL;
 276
 277        spin_lock_bh(&pernet->lock);
 278        /* to keep the code simple, don't do IDR-like allocation for address ID,
 279         * just bail when we exceed limits
 280         */
 281        if (pernet->next_id > 255)
 282                goto out;
 283        if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
 284                goto out;
 285
 286        /* do not insert duplicate address, differentiate on port only
 287         * singled addresses
 288         */
 289        list_for_each_entry(cur, &pernet->local_addr_list, list) {
 290                if (addresses_equal(&cur->addr, &entry->addr,
 291                                    address_use_port(entry) &&
 292                                    address_use_port(cur)))
 293                        goto out;
 294        }
 295
 296        if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
 297                pernet->add_addr_signal_max++;
 298        if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
 299                pernet->local_addr_max++;
 300
 301        entry->addr.id = pernet->next_id++;
 302        pernet->addrs++;
 303        list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
 304        ret = entry->addr.id;
 305
 306out:
 307        spin_unlock_bh(&pernet->lock);
 308        return ret;
 309}
 310
 311int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
 312{
 313        struct mptcp_pm_addr_entry *entry;
 314        struct mptcp_addr_info skc_local;
 315        struct mptcp_addr_info msk_local;
 316        struct pm_nl_pernet *pernet;
 317        int ret = -1;
 318
 319        if (WARN_ON_ONCE(!msk))
 320                return -1;
 321
 322        /* The 0 ID mapping is defined by the first subflow, copied into the msk
 323         * addr
 324         */
 325        local_address((struct sock_common *)msk, &msk_local);
 326        local_address((struct sock_common *)msk, &skc_local);
 327        if (addresses_equal(&msk_local, &skc_local, false))
 328                return 0;
 329
 330        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 331
 332        rcu_read_lock();
 333        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 334                if (addresses_equal(&entry->addr, &skc_local, false)) {
 335                        ret = entry->addr.id;
 336                        break;
 337                }
 338        }
 339        rcu_read_unlock();
 340        if (ret >= 0)
 341                return ret;
 342
 343        /* address not found, add to local list */
 344        entry = kmalloc(sizeof(*entry), GFP_KERNEL);
 345        if (!entry)
 346                return -ENOMEM;
 347
 348        entry->flags = 0;
 349        entry->addr = skc_local;
 350        ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
 351        if (ret < 0)
 352                kfree(entry);
 353
 354        return ret;
 355}
 356
 357void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
 358{
 359        struct mptcp_pm_data *pm = &msk->pm;
 360        struct pm_nl_pernet *pernet;
 361        bool subflows;
 362
 363        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 364
 365        pm->add_addr_signal_max = READ_ONCE(pernet->add_addr_signal_max);
 366        pm->add_addr_accept_max = READ_ONCE(pernet->add_addr_accept_max);
 367        pm->local_addr_max = READ_ONCE(pernet->local_addr_max);
 368        pm->subflows_max = READ_ONCE(pernet->subflows_max);
 369        subflows = !!pm->subflows_max;
 370        WRITE_ONCE(pm->work_pending, (!!pm->local_addr_max && subflows) ||
 371                   !!pm->add_addr_signal_max);
 372        WRITE_ONCE(pm->accept_addr, !!pm->add_addr_accept_max && subflows);
 373        WRITE_ONCE(pm->accept_subflow, subflows);
 374}
 375
 376#define MPTCP_PM_CMD_GRP_OFFSET 0
 377
 378static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
 379        [MPTCP_PM_CMD_GRP_OFFSET]       = { .name = MPTCP_PM_CMD_GRP_NAME, },
 380};
 381
 382static const struct nla_policy
 383mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
 384        [MPTCP_PM_ADDR_ATTR_FAMILY]     = { .type       = NLA_U16,      },
 385        [MPTCP_PM_ADDR_ATTR_ID]         = { .type       = NLA_U8,       },
 386        [MPTCP_PM_ADDR_ATTR_ADDR4]      = { .type       = NLA_U32,      },
 387        [MPTCP_PM_ADDR_ATTR_ADDR6]      = { .type       = NLA_EXACT_LEN,
 388                                            .len   = sizeof(struct in6_addr), },
 389        [MPTCP_PM_ADDR_ATTR_PORT]       = { .type       = NLA_U16       },
 390        [MPTCP_PM_ADDR_ATTR_FLAGS]      = { .type       = NLA_U32       },
 391        [MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type       = NLA_S32       },
 392};
 393
 394static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
 395        [MPTCP_PM_ATTR_ADDR]            =
 396                                        NLA_POLICY_NESTED(mptcp_pm_addr_policy),
 397        [MPTCP_PM_ATTR_RCV_ADD_ADDRS]   = { .type       = NLA_U32,      },
 398        [MPTCP_PM_ATTR_SUBFLOWS]        = { .type       = NLA_U32,      },
 399};
 400
 401static int mptcp_pm_family_to_addr(int family)
 402{
 403#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 404        if (family == AF_INET6)
 405                return MPTCP_PM_ADDR_ATTR_ADDR6;
 406#endif
 407        return MPTCP_PM_ADDR_ATTR_ADDR4;
 408}
 409
 410static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
 411                               bool require_family,
 412                               struct mptcp_pm_addr_entry *entry)
 413{
 414        struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
 415        int err, addr_addr;
 416
 417        if (!attr) {
 418                GENL_SET_ERR_MSG(info, "missing address info");
 419                return -EINVAL;
 420        }
 421
 422        /* no validation needed - was already done via nested policy */
 423        err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
 424                                          mptcp_pm_addr_policy, info->extack);
 425        if (err)
 426                return err;
 427
 428        memset(entry, 0, sizeof(*entry));
 429        if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
 430                if (!require_family)
 431                        goto skip_family;
 432
 433                NL_SET_ERR_MSG_ATTR(info->extack, attr,
 434                                    "missing family");
 435                return -EINVAL;
 436        }
 437
 438        entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
 439        if (entry->addr.family != AF_INET
 440#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 441            && entry->addr.family != AF_INET6
 442#endif
 443            ) {
 444                NL_SET_ERR_MSG_ATTR(info->extack, attr,
 445                                    "unknown address family");
 446                return -EINVAL;
 447        }
 448        addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
 449        if (!tb[addr_addr]) {
 450                NL_SET_ERR_MSG_ATTR(info->extack, attr,
 451                                    "missing address data");
 452                return -EINVAL;
 453        }
 454
 455#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 456        if (entry->addr.family == AF_INET6)
 457                entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
 458        else
 459#endif
 460                entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
 461
 462skip_family:
 463        if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX])
 464                entry->ifindex = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
 465
 466        if (tb[MPTCP_PM_ADDR_ATTR_ID])
 467                entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
 468
 469        if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
 470                entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
 471
 472        return 0;
 473}
 474
 475static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
 476{
 477        return net_generic(genl_info_net(info), pm_nl_pernet_id);
 478}
 479
 480static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
 481{
 482        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
 483        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
 484        struct mptcp_pm_addr_entry addr, *entry;
 485        int ret;
 486
 487        ret = mptcp_pm_parse_addr(attr, info, true, &addr);
 488        if (ret < 0)
 489                return ret;
 490
 491        entry = kmalloc(sizeof(*entry), GFP_KERNEL);
 492        if (!entry) {
 493                GENL_SET_ERR_MSG(info, "can't allocate addr");
 494                return -ENOMEM;
 495        }
 496
 497        *entry = addr;
 498        ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
 499        if (ret < 0) {
 500                GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
 501                kfree(entry);
 502                return ret;
 503        }
 504
 505        return 0;
 506}
 507
 508static struct mptcp_pm_addr_entry *
 509__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
 510{
 511        struct mptcp_pm_addr_entry *entry;
 512
 513        list_for_each_entry(entry, &pernet->local_addr_list, list) {
 514                if (entry->addr.id == id)
 515                        return entry;
 516        }
 517        return NULL;
 518}
 519
 520static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
 521{
 522        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
 523        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
 524        struct mptcp_pm_addr_entry addr, *entry;
 525        int ret;
 526
 527        ret = mptcp_pm_parse_addr(attr, info, false, &addr);
 528        if (ret < 0)
 529                return ret;
 530
 531        spin_lock_bh(&pernet->lock);
 532        entry = __lookup_addr_by_id(pernet, addr.addr.id);
 533        if (!entry) {
 534                GENL_SET_ERR_MSG(info, "address not found");
 535                ret = -EINVAL;
 536                goto out;
 537        }
 538        if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
 539                pernet->add_addr_signal_max--;
 540        if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
 541                pernet->local_addr_max--;
 542
 543        pernet->addrs--;
 544        list_del_rcu(&entry->list);
 545        kfree_rcu(entry, rcu);
 546out:
 547        spin_unlock_bh(&pernet->lock);
 548        return ret;
 549}
 550
 551static void __flush_addrs(struct pm_nl_pernet *pernet)
 552{
 553        while (!list_empty(&pernet->local_addr_list)) {
 554                struct mptcp_pm_addr_entry *cur;
 555
 556                cur = list_entry(pernet->local_addr_list.next,
 557                                 struct mptcp_pm_addr_entry, list);
 558                list_del_rcu(&cur->list);
 559                kfree_rcu(cur, rcu);
 560        }
 561}
 562
 563static void __reset_counters(struct pm_nl_pernet *pernet)
 564{
 565        pernet->add_addr_signal_max = 0;
 566        pernet->add_addr_accept_max = 0;
 567        pernet->local_addr_max = 0;
 568        pernet->addrs = 0;
 569}
 570
 571static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
 572{
 573        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
 574
 575        spin_lock_bh(&pernet->lock);
 576        __flush_addrs(pernet);
 577        __reset_counters(pernet);
 578        spin_unlock_bh(&pernet->lock);
 579        return 0;
 580}
 581
 582static int mptcp_nl_fill_addr(struct sk_buff *skb,
 583                              struct mptcp_pm_addr_entry *entry)
 584{
 585        struct mptcp_addr_info *addr = &entry->addr;
 586        struct nlattr *attr;
 587
 588        attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
 589        if (!attr)
 590                return -EMSGSIZE;
 591
 592        if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
 593                goto nla_put_failure;
 594        if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
 595                goto nla_put_failure;
 596        if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
 597                goto nla_put_failure;
 598        if (entry->ifindex &&
 599            nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
 600                goto nla_put_failure;
 601
 602        if (addr->family == AF_INET &&
 603            nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
 604                            addr->addr.s_addr))
 605                goto nla_put_failure;
 606#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 607        else if (addr->family == AF_INET6 &&
 608                 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
 609                goto nla_put_failure;
 610#endif
 611        nla_nest_end(skb, attr);
 612        return 0;
 613
 614nla_put_failure:
 615        nla_nest_cancel(skb, attr);
 616        return -EMSGSIZE;
 617}
 618
 619static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
 620{
 621        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
 622        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
 623        struct mptcp_pm_addr_entry addr, *entry;
 624        struct sk_buff *msg;
 625        void *reply;
 626        int ret;
 627
 628        ret = mptcp_pm_parse_addr(attr, info, false, &addr);
 629        if (ret < 0)
 630                return ret;
 631
 632        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
 633        if (!msg)
 634                return -ENOMEM;
 635
 636        reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
 637                                  info->genlhdr->cmd);
 638        if (!reply) {
 639                GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
 640                ret = -EMSGSIZE;
 641                goto fail;
 642        }
 643
 644        spin_lock_bh(&pernet->lock);
 645        entry = __lookup_addr_by_id(pernet, addr.addr.id);
 646        if (!entry) {
 647                GENL_SET_ERR_MSG(info, "address not found");
 648                ret = -EINVAL;
 649                goto unlock_fail;
 650        }
 651
 652        ret = mptcp_nl_fill_addr(msg, entry);
 653        if (ret)
 654                goto unlock_fail;
 655
 656        genlmsg_end(msg, reply);
 657        ret = genlmsg_reply(msg, info);
 658        spin_unlock_bh(&pernet->lock);
 659        return ret;
 660
 661unlock_fail:
 662        spin_unlock_bh(&pernet->lock);
 663
 664fail:
 665        nlmsg_free(msg);
 666        return ret;
 667}
 668
 669static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
 670                                   struct netlink_callback *cb)
 671{
 672        struct net *net = sock_net(msg->sk);
 673        struct mptcp_pm_addr_entry *entry;
 674        struct pm_nl_pernet *pernet;
 675        int id = cb->args[0];
 676        void *hdr;
 677
 678        pernet = net_generic(net, pm_nl_pernet_id);
 679
 680        spin_lock_bh(&pernet->lock);
 681        list_for_each_entry(entry, &pernet->local_addr_list, list) {
 682                if (entry->addr.id <= id)
 683                        continue;
 684
 685                hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
 686                                  cb->nlh->nlmsg_seq, &mptcp_genl_family,
 687                                  NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
 688                if (!hdr)
 689                        break;
 690
 691                if (mptcp_nl_fill_addr(msg, entry) < 0) {
 692                        genlmsg_cancel(msg, hdr);
 693                        break;
 694                }
 695
 696                id = entry->addr.id;
 697                genlmsg_end(msg, hdr);
 698        }
 699        spin_unlock_bh(&pernet->lock);
 700
 701        cb->args[0] = id;
 702        return msg->len;
 703}
 704
 705static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
 706{
 707        struct nlattr *attr = info->attrs[id];
 708
 709        if (!attr)
 710                return 0;
 711
 712        *limit = nla_get_u32(attr);
 713        if (*limit > MPTCP_PM_ADDR_MAX) {
 714                GENL_SET_ERR_MSG(info, "limit greater than maximum");
 715                return -EINVAL;
 716        }
 717        return 0;
 718}
 719
 720static int
 721mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
 722{
 723        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
 724        unsigned int rcv_addrs, subflows;
 725        int ret;
 726
 727        spin_lock_bh(&pernet->lock);
 728        rcv_addrs = pernet->add_addr_accept_max;
 729        ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
 730        if (ret)
 731                goto unlock;
 732
 733        subflows = pernet->subflows_max;
 734        ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
 735        if (ret)
 736                goto unlock;
 737
 738        WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
 739        WRITE_ONCE(pernet->subflows_max, subflows);
 740
 741unlock:
 742        spin_unlock_bh(&pernet->lock);
 743        return ret;
 744}
 745
 746static int
 747mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
 748{
 749        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
 750        struct sk_buff *msg;
 751        void *reply;
 752
 753        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
 754        if (!msg)
 755                return -ENOMEM;
 756
 757        reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
 758                                  MPTCP_PM_CMD_GET_LIMITS);
 759        if (!reply)
 760                goto fail;
 761
 762        if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
 763                        READ_ONCE(pernet->add_addr_accept_max)))
 764                goto fail;
 765
 766        if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
 767                        READ_ONCE(pernet->subflows_max)))
 768                goto fail;
 769
 770        genlmsg_end(msg, reply);
 771        return genlmsg_reply(msg, info);
 772
 773fail:
 774        GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
 775        nlmsg_free(msg);
 776        return -EMSGSIZE;
 777}
 778
 779static struct genl_ops mptcp_pm_ops[] = {
 780        {
 781                .cmd    = MPTCP_PM_CMD_ADD_ADDR,
 782                .doit   = mptcp_nl_cmd_add_addr,
 783                .flags  = GENL_ADMIN_PERM,
 784        },
 785        {
 786                .cmd    = MPTCP_PM_CMD_DEL_ADDR,
 787                .doit   = mptcp_nl_cmd_del_addr,
 788                .flags  = GENL_ADMIN_PERM,
 789        },
 790        {
 791                .cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
 792                .doit   = mptcp_nl_cmd_flush_addrs,
 793                .flags  = GENL_ADMIN_PERM,
 794        },
 795        {
 796                .cmd    = MPTCP_PM_CMD_GET_ADDR,
 797                .doit   = mptcp_nl_cmd_get_addr,
 798                .dumpit   = mptcp_nl_cmd_dump_addrs,
 799        },
 800        {
 801                .cmd    = MPTCP_PM_CMD_SET_LIMITS,
 802                .doit   = mptcp_nl_cmd_set_limits,
 803                .flags  = GENL_ADMIN_PERM,
 804        },
 805        {
 806                .cmd    = MPTCP_PM_CMD_GET_LIMITS,
 807                .doit   = mptcp_nl_cmd_get_limits,
 808        },
 809};
 810
 811static struct genl_family mptcp_genl_family __ro_after_init = {
 812        .name           = MPTCP_PM_NAME,
 813        .version        = MPTCP_PM_VER,
 814        .maxattr        = MPTCP_PM_ATTR_MAX,
 815        .policy         = mptcp_pm_policy,
 816        .netnsok        = true,
 817        .module         = THIS_MODULE,
 818        .ops            = mptcp_pm_ops,
 819        .n_ops          = ARRAY_SIZE(mptcp_pm_ops),
 820        .mcgrps         = mptcp_pm_mcgrps,
 821        .n_mcgrps       = ARRAY_SIZE(mptcp_pm_mcgrps),
 822};
 823
 824static int __net_init pm_nl_init_net(struct net *net)
 825{
 826        struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
 827
 828        INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
 829        __reset_counters(pernet);
 830        pernet->next_id = 1;
 831        spin_lock_init(&pernet->lock);
 832        return 0;
 833}
 834
 835static void __net_exit pm_nl_exit_net(struct list_head *net_list)
 836{
 837        struct net *net;
 838
 839        list_for_each_entry(net, net_list, exit_list) {
 840                /* net is removed from namespace list, can't race with
 841                 * other modifiers
 842                 */
 843                __flush_addrs(net_generic(net, pm_nl_pernet_id));
 844        }
 845}
 846
 847static struct pernet_operations mptcp_pm_pernet_ops = {
 848        .init = pm_nl_init_net,
 849        .exit_batch = pm_nl_exit_net,
 850        .id = &pm_nl_pernet_id,
 851        .size = sizeof(struct pm_nl_pernet),
 852};
 853
 854void mptcp_pm_nl_init(void)
 855{
 856        if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
 857                panic("Failed to register MPTCP PM pernet subsystem.\n");
 858
 859        if (genl_register_family(&mptcp_genl_family))
 860                panic("Failed to register MPTCP PM netlink family\n");
 861}
 862