linux/net/mptcp/pm_netlink.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Multipath TCP
   3 *
   4 * Copyright (c) 2020, Red Hat, Inc.
   5 */
   6
   7#define pr_fmt(fmt) "MPTCP: " fmt
   8
   9#include <linux/inet.h>
  10#include <linux/kernel.h>
  11#include <net/tcp.h>
  12#include <net/netns/generic.h>
  13#include <net/mptcp.h>
  14#include <net/genetlink.h>
  15#include <uapi/linux/mptcp.h>
  16
  17#include "protocol.h"
  18#include "mib.h"
  19
  20/* forward declaration */
  21static struct genl_family mptcp_genl_family;
  22
  23static int pm_nl_pernet_id;
  24
  25struct mptcp_pm_addr_entry {
  26        struct list_head        list;
  27        struct mptcp_addr_info  addr;
  28        u8                      flags;
  29        int                     ifindex;
  30        struct socket           *lsk;
  31};
  32
  33struct mptcp_pm_add_entry {
  34        struct list_head        list;
  35        struct mptcp_addr_info  addr;
  36        struct timer_list       add_timer;
  37        struct mptcp_sock       *sock;
  38        u8                      retrans_times;
  39};
  40
  41#define MAX_ADDR_ID             255
  42#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)
  43
  44struct pm_nl_pernet {
  45        /* protects pernet updates */
  46        spinlock_t              lock;
  47        struct list_head        local_addr_list;
  48        unsigned int            addrs;
  49        unsigned int            stale_loss_cnt;
  50        unsigned int            add_addr_signal_max;
  51        unsigned int            add_addr_accept_max;
  52        unsigned int            local_addr_max;
  53        unsigned int            subflows_max;
  54        unsigned int            next_id;
  55        unsigned long           id_bitmap[BITMAP_SZ];
  56};
  57
  58#define MPTCP_PM_ADDR_MAX       8
  59#define ADD_ADDR_RETRANS_MAX    3
  60
  61static bool addresses_equal(const struct mptcp_addr_info *a,
  62                            struct mptcp_addr_info *b, bool use_port)
  63{
  64        bool addr_equals = false;
  65
  66        if (a->family == b->family) {
  67                if (a->family == AF_INET)
  68                        addr_equals = a->addr.s_addr == b->addr.s_addr;
  69#if IS_ENABLED(CONFIG_MPTCP_IPV6)
  70                else
  71                        addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
  72        } else if (a->family == AF_INET) {
  73                if (ipv6_addr_v4mapped(&b->addr6))
  74                        addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
  75        } else if (b->family == AF_INET) {
  76                if (ipv6_addr_v4mapped(&a->addr6))
  77                        addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
  78#endif
  79        }
  80
  81        if (!addr_equals)
  82                return false;
  83        if (!use_port)
  84                return true;
  85
  86        return a->port == b->port;
  87}
  88
  89static bool address_zero(const struct mptcp_addr_info *addr)
  90{
  91        struct mptcp_addr_info zero;
  92
  93        memset(&zero, 0, sizeof(zero));
  94        zero.family = addr->family;
  95
  96        return addresses_equal(addr, &zero, true);
  97}
  98
  99static void local_address(const struct sock_common *skc,
 100                          struct mptcp_addr_info *addr)
 101{
 102        addr->family = skc->skc_family;
 103        addr->port = htons(skc->skc_num);
 104        if (addr->family == AF_INET)
 105                addr->addr.s_addr = skc->skc_rcv_saddr;
 106#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 107        else if (addr->family == AF_INET6)
 108                addr->addr6 = skc->skc_v6_rcv_saddr;
 109#endif
 110}
 111
 112static void remote_address(const struct sock_common *skc,
 113                           struct mptcp_addr_info *addr)
 114{
 115        addr->family = skc->skc_family;
 116        addr->port = skc->skc_dport;
 117        if (addr->family == AF_INET)
 118                addr->addr.s_addr = skc->skc_daddr;
 119#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 120        else if (addr->family == AF_INET6)
 121                addr->addr6 = skc->skc_v6_daddr;
 122#endif
 123}
 124
 125static bool lookup_subflow_by_saddr(const struct list_head *list,
 126                                    struct mptcp_addr_info *saddr)
 127{
 128        struct mptcp_subflow_context *subflow;
 129        struct mptcp_addr_info cur;
 130        struct sock_common *skc;
 131
 132        list_for_each_entry(subflow, list, node) {
 133                skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
 134
 135                local_address(skc, &cur);
 136                if (addresses_equal(&cur, saddr, saddr->port))
 137                        return true;
 138        }
 139
 140        return false;
 141}
 142
 143static bool lookup_subflow_by_daddr(const struct list_head *list,
 144                                    struct mptcp_addr_info *daddr)
 145{
 146        struct mptcp_subflow_context *subflow;
 147        struct mptcp_addr_info cur;
 148        struct sock_common *skc;
 149
 150        list_for_each_entry(subflow, list, node) {
 151                skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
 152
 153                remote_address(skc, &cur);
 154                if (addresses_equal(&cur, daddr, daddr->port))
 155                        return true;
 156        }
 157
 158        return false;
 159}
 160
 161static struct mptcp_pm_addr_entry *
 162select_local_address(const struct pm_nl_pernet *pernet,
 163                     struct mptcp_sock *msk)
 164{
 165        struct mptcp_pm_addr_entry *entry, *ret = NULL;
 166        struct sock *sk = (struct sock *)msk;
 167
 168        msk_owned_by_me(msk);
 169
 170        rcu_read_lock();
 171        __mptcp_flush_join_list(msk);
 172        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 173                if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
 174                        continue;
 175
 176                if (entry->addr.family != sk->sk_family) {
 177#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 178                        if ((entry->addr.family == AF_INET &&
 179                             !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
 180                            (sk->sk_family == AF_INET &&
 181                             !ipv6_addr_v4mapped(&entry->addr.addr6)))
 182#endif
 183                                continue;
 184                }
 185
 186                /* avoid any address already in use by subflows and
 187                 * pending join
 188                 */
 189                if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
 190                        ret = entry;
 191                        break;
 192                }
 193        }
 194        rcu_read_unlock();
 195        return ret;
 196}
 197
 198static struct mptcp_pm_addr_entry *
 199select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
 200{
 201        struct mptcp_pm_addr_entry *entry, *ret = NULL;
 202        int i = 0;
 203
 204        rcu_read_lock();
 205        /* do not keep any additional per socket state, just signal
 206         * the address list in order.
 207         * Note: removal from the local address list during the msk life-cycle
 208         * can lead to additional addresses not being announced.
 209         */
 210        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 211                if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
 212                        continue;
 213                if (i++ == pos) {
 214                        ret = entry;
 215                        break;
 216                }
 217        }
 218        rcu_read_unlock();
 219        return ret;
 220}
 221
 222unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
 223{
 224        struct pm_nl_pernet *pernet;
 225
 226        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 227        return READ_ONCE(pernet->add_addr_signal_max);
 228}
 229EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
 230
 231unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
 232{
 233        struct pm_nl_pernet *pernet;
 234
 235        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 236        return READ_ONCE(pernet->add_addr_accept_max);
 237}
 238EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
 239
 240unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
 241{
 242        struct pm_nl_pernet *pernet;
 243
 244        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 245        return READ_ONCE(pernet->subflows_max);
 246}
 247EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
 248
 249unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
 250{
 251        struct pm_nl_pernet *pernet;
 252
 253        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 254        return READ_ONCE(pernet->local_addr_max);
 255}
 256EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);
 257
 258static void check_work_pending(struct mptcp_sock *msk)
 259{
 260        if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
 261            (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
 262             msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
 263                WRITE_ONCE(msk->pm.work_pending, false);
 264}
 265
 266struct mptcp_pm_add_entry *
 267mptcp_lookup_anno_list_by_saddr(struct mptcp_sock *msk,
 268                                struct mptcp_addr_info *addr)
 269{
 270        struct mptcp_pm_add_entry *entry;
 271
 272        lockdep_assert_held(&msk->pm.lock);
 273
 274        list_for_each_entry(entry, &msk->pm.anno_list, list) {
 275                if (addresses_equal(&entry->addr, addr, true))
 276                        return entry;
 277        }
 278
 279        return NULL;
 280}
 281
 282bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
 283{
 284        struct mptcp_pm_add_entry *entry;
 285        struct mptcp_addr_info saddr;
 286        bool ret = false;
 287
 288        local_address((struct sock_common *)sk, &saddr);
 289
 290        spin_lock_bh(&msk->pm.lock);
 291        list_for_each_entry(entry, &msk->pm.anno_list, list) {
 292                if (addresses_equal(&entry->addr, &saddr, true)) {
 293                        ret = true;
 294                        goto out;
 295                }
 296        }
 297
 298out:
 299        spin_unlock_bh(&msk->pm.lock);
 300        return ret;
 301}
 302
 303static void mptcp_pm_add_timer(struct timer_list *timer)
 304{
 305        struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
 306        struct mptcp_sock *msk = entry->sock;
 307        struct sock *sk = (struct sock *)msk;
 308
 309        pr_debug("msk=%p", msk);
 310
 311        if (!msk)
 312                return;
 313
 314        if (inet_sk_state_load(sk) == TCP_CLOSE)
 315                return;
 316
 317        if (!entry->addr.id)
 318                return;
 319
 320        if (mptcp_pm_should_add_signal_addr(msk)) {
 321                sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
 322                goto out;
 323        }
 324
 325        spin_lock_bh(&msk->pm.lock);
 326
 327        if (!mptcp_pm_should_add_signal_addr(msk)) {
 328                pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
 329                mptcp_pm_announce_addr(msk, &entry->addr, false);
 330                mptcp_pm_add_addr_send_ack(msk);
 331                entry->retrans_times++;
 332        }
 333
 334        if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
 335                sk_reset_timer(sk, timer,
 336                               jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));
 337
 338        spin_unlock_bh(&msk->pm.lock);
 339
 340        if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
 341                mptcp_pm_subflow_established(msk);
 342
 343out:
 344        __sock_put(sk);
 345}
 346
 347struct mptcp_pm_add_entry *
 348mptcp_pm_del_add_timer(struct mptcp_sock *msk,
 349                       struct mptcp_addr_info *addr, bool check_id)
 350{
 351        struct mptcp_pm_add_entry *entry;
 352        struct sock *sk = (struct sock *)msk;
 353
 354        spin_lock_bh(&msk->pm.lock);
 355        entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
 356        if (entry && (!check_id || entry->addr.id == addr->id))
 357                entry->retrans_times = ADD_ADDR_RETRANS_MAX;
 358        spin_unlock_bh(&msk->pm.lock);
 359
 360        if (entry && (!check_id || entry->addr.id == addr->id))
 361                sk_stop_timer_sync(sk, &entry->add_timer);
 362
 363        return entry;
 364}
 365
 366static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
 367                                     struct mptcp_pm_addr_entry *entry)
 368{
 369        struct mptcp_pm_add_entry *add_entry = NULL;
 370        struct sock *sk = (struct sock *)msk;
 371        struct net *net = sock_net(sk);
 372
 373        lockdep_assert_held(&msk->pm.lock);
 374
 375        if (mptcp_lookup_anno_list_by_saddr(msk, &entry->addr))
 376                return false;
 377
 378        add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
 379        if (!add_entry)
 380                return false;
 381
 382        list_add(&add_entry->list, &msk->pm.anno_list);
 383
 384        add_entry->addr = entry->addr;
 385        add_entry->sock = msk;
 386        add_entry->retrans_times = 0;
 387
 388        timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
 389        sk_reset_timer(sk, &add_entry->add_timer,
 390                       jiffies + mptcp_get_add_addr_timeout(net));
 391
 392        return true;
 393}
 394
 395void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
 396{
 397        struct mptcp_pm_add_entry *entry, *tmp;
 398        struct sock *sk = (struct sock *)msk;
 399        LIST_HEAD(free_list);
 400
 401        pr_debug("msk=%p", msk);
 402
 403        spin_lock_bh(&msk->pm.lock);
 404        list_splice_init(&msk->pm.anno_list, &free_list);
 405        spin_unlock_bh(&msk->pm.lock);
 406
 407        list_for_each_entry_safe(entry, tmp, &free_list, list) {
 408                sk_stop_timer_sync(sk, &entry->add_timer);
 409                kfree(entry);
 410        }
 411}
 412
 413static bool lookup_address_in_vec(struct mptcp_addr_info *addrs, unsigned int nr,
 414                                  struct mptcp_addr_info *addr)
 415{
 416        int i;
 417
 418        for (i = 0; i < nr; i++) {
 419                if (addresses_equal(&addrs[i], addr, addr->port))
 420                        return true;
 421        }
 422
 423        return false;
 424}
 425
 426/* Fill all the remote addresses into the array addrs[],
 427 * and return the array size.
 428 */
 429static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullmesh,
 430                                              struct mptcp_addr_info *addrs)
 431{
 432        struct sock *sk = (struct sock *)msk, *ssk;
 433        struct mptcp_subflow_context *subflow;
 434        struct mptcp_addr_info remote = { 0 };
 435        unsigned int subflows_max;
 436        int i = 0;
 437
 438        subflows_max = mptcp_pm_get_subflows_max(msk);
 439
 440        /* Non-fullmesh endpoint, fill in the single entry
 441         * corresponding to the primary MPC subflow remote address
 442         */
 443        if (!fullmesh) {
 444                remote_address((struct sock_common *)sk, &remote);
 445                msk->pm.subflows++;
 446                addrs[i++] = remote;
 447        } else {
 448                mptcp_for_each_subflow(msk, subflow) {
 449                        ssk = mptcp_subflow_tcp_sock(subflow);
 450                        remote_address((struct sock_common *)ssk, &remote);
 451                        if (!lookup_address_in_vec(addrs, i, &remote) &&
 452                            msk->pm.subflows < subflows_max) {
 453                                msk->pm.subflows++;
 454                                addrs[i++] = remote;
 455                        }
 456                }
 457        }
 458
 459        return i;
 460}
 461
 462static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 463{
 464        struct sock *sk = (struct sock *)msk;
 465        struct mptcp_pm_addr_entry *local;
 466        unsigned int add_addr_signal_max;
 467        unsigned int local_addr_max;
 468        struct pm_nl_pernet *pernet;
 469        unsigned int subflows_max;
 470
 471        pernet = net_generic(sock_net(sk), pm_nl_pernet_id);
 472
 473        add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
 474        local_addr_max = mptcp_pm_get_local_addr_max(msk);
 475        subflows_max = mptcp_pm_get_subflows_max(msk);
 476
 477        pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
 478                 msk->pm.local_addr_used, local_addr_max,
 479                 msk->pm.add_addr_signaled, add_addr_signal_max,
 480                 msk->pm.subflows, subflows_max);
 481
 482        /* check first for announce */
 483        if (msk->pm.add_addr_signaled < add_addr_signal_max) {
 484                local = select_signal_address(pernet,
 485                                              msk->pm.add_addr_signaled);
 486
 487                if (local) {
 488                        if (mptcp_pm_alloc_anno_list(msk, local)) {
 489                                msk->pm.add_addr_signaled++;
 490                                mptcp_pm_announce_addr(msk, &local->addr, false);
 491                                mptcp_pm_nl_addr_send_ack(msk);
 492                        }
 493                } else {
 494                        /* pick failed, avoid fourther attempts later */
 495                        msk->pm.local_addr_used = add_addr_signal_max;
 496                }
 497
 498                check_work_pending(msk);
 499        }
 500
 501        /* check if should create a new subflow */
 502        if (msk->pm.local_addr_used < local_addr_max &&
 503            msk->pm.subflows < subflows_max &&
 504            !READ_ONCE(msk->pm.remote_deny_join_id0)) {
 505                local = select_local_address(pernet, msk);
 506                if (local) {
 507                        bool fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH);
 508                        struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
 509                        int i, nr;
 510
 511                        msk->pm.local_addr_used++;
 512                        check_work_pending(msk);
 513                        nr = fill_remote_addresses_vec(msk, fullmesh, addrs);
 514                        spin_unlock_bh(&msk->pm.lock);
 515                        for (i = 0; i < nr; i++)
 516                                __mptcp_subflow_connect(sk, &local->addr, &addrs[i]);
 517                        spin_lock_bh(&msk->pm.lock);
 518                        return;
 519                }
 520
 521                /* lookup failed, avoid fourther attempts later */
 522                msk->pm.local_addr_used = local_addr_max;
 523                check_work_pending(msk);
 524        }
 525}
 526
 527static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
 528{
 529        mptcp_pm_create_subflow_or_signal_addr(msk);
 530}
 531
 532static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
 533{
 534        mptcp_pm_create_subflow_or_signal_addr(msk);
 535}
 536
 537/* Fill all the local addresses into the array addrs[],
 538 * and return the array size.
 539 */
 540static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk,
 541                                             struct mptcp_addr_info *addrs)
 542{
 543        struct sock *sk = (struct sock *)msk;
 544        struct mptcp_pm_addr_entry *entry;
 545        struct mptcp_addr_info local;
 546        struct pm_nl_pernet *pernet;
 547        unsigned int subflows_max;
 548        int i = 0;
 549
 550        pernet = net_generic(sock_net(sk), pm_nl_pernet_id);
 551        subflows_max = mptcp_pm_get_subflows_max(msk);
 552
 553        rcu_read_lock();
 554        __mptcp_flush_join_list(msk);
 555        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 556                if (!(entry->flags & MPTCP_PM_ADDR_FLAG_FULLMESH))
 557                        continue;
 558
 559                if (entry->addr.family != sk->sk_family) {
 560#if IS_ENABLED(CONFIG_MPTCP_IPV6)
 561                        if ((entry->addr.family == AF_INET &&
 562                             !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
 563                            (sk->sk_family == AF_INET &&
 564                             !ipv6_addr_v4mapped(&entry->addr.addr6)))
 565#endif
 566                                continue;
 567                }
 568
 569                if (msk->pm.subflows < subflows_max) {
 570                        msk->pm.subflows++;
 571                        addrs[i++] = entry->addr;
 572                }
 573        }
 574        rcu_read_unlock();
 575
 576        /* If the array is empty, fill in the single
 577         * 'IPADDRANY' local address
 578         */
 579        if (!i) {
 580                memset(&local, 0, sizeof(local));
 581                local.family = msk->pm.remote.family;
 582
 583                msk->pm.subflows++;
 584                addrs[i++] = local;
 585        }
 586
 587        return i;
 588}
 589
 590static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
 591{
 592        struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX];
 593        struct sock *sk = (struct sock *)msk;
 594        unsigned int add_addr_accept_max;
 595        struct mptcp_addr_info remote;
 596        unsigned int subflows_max;
 597        int i, nr;
 598
 599        add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
 600        subflows_max = mptcp_pm_get_subflows_max(msk);
 601
 602        pr_debug("accepted %d:%d remote family %d",
 603                 msk->pm.add_addr_accepted, add_addr_accept_max,
 604                 msk->pm.remote.family);
 605
 606        if (lookup_subflow_by_daddr(&msk->conn_list, &msk->pm.remote))
 607                goto add_addr_echo;
 608
 609        /* connect to the specified remote address, using whatever
 610         * local address the routing configuration will pick.
 611         */
 612        remote = msk->pm.remote;
 613        if (!remote.port)
 614                remote.port = sk->sk_dport;
 615        nr = fill_local_addresses_vec(msk, addrs);
 616
 617        msk->pm.add_addr_accepted++;
 618        if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
 619            msk->pm.subflows >= subflows_max)
 620                WRITE_ONCE(msk->pm.accept_addr, false);
 621
 622        spin_unlock_bh(&msk->pm.lock);
 623        for (i = 0; i < nr; i++)
 624                __mptcp_subflow_connect(sk, &addrs[i], &remote);
 625        spin_lock_bh(&msk->pm.lock);
 626
 627add_addr_echo:
 628        mptcp_pm_announce_addr(msk, &msk->pm.remote, true);
 629        mptcp_pm_nl_addr_send_ack(msk);
 630}
 631
 632void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
 633{
 634        struct mptcp_subflow_context *subflow;
 635
 636        msk_owned_by_me(msk);
 637        lockdep_assert_held(&msk->pm.lock);
 638
 639        if (!mptcp_pm_should_add_signal(msk) &&
 640            !mptcp_pm_should_rm_signal(msk))
 641                return;
 642
 643        __mptcp_flush_join_list(msk);
 644        subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
 645        if (subflow) {
 646                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 647
 648                spin_unlock_bh(&msk->pm.lock);
 649                pr_debug("send ack for %s",
 650                         mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr");
 651
 652                mptcp_subflow_send_ack(ssk);
 653                spin_lock_bh(&msk->pm.lock);
 654        }
 655}
 656
 657int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
 658                                 struct mptcp_addr_info *addr,
 659                                 u8 bkup)
 660{
 661        struct mptcp_subflow_context *subflow;
 662
 663        pr_debug("bkup=%d", bkup);
 664
 665        mptcp_for_each_subflow(msk, subflow) {
 666                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 667                struct sock *sk = (struct sock *)msk;
 668                struct mptcp_addr_info local;
 669
 670                local_address((struct sock_common *)ssk, &local);
 671                if (!addresses_equal(&local, addr, addr->port))
 672                        continue;
 673
 674                subflow->backup = bkup;
 675                subflow->send_mp_prio = 1;
 676                subflow->request_bkup = bkup;
 677                __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);
 678
 679                spin_unlock_bh(&msk->pm.lock);
 680                pr_debug("send ack for mp_prio");
 681                mptcp_subflow_send_ack(ssk);
 682                spin_lock_bh(&msk->pm.lock);
 683
 684                return 0;
 685        }
 686
 687        return -EINVAL;
 688}
 689
 690static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
 691                                           const struct mptcp_rm_list *rm_list,
 692                                           enum linux_mptcp_mib_field rm_type)
 693{
 694        struct mptcp_subflow_context *subflow, *tmp;
 695        struct sock *sk = (struct sock *)msk;
 696        u8 i;
 697
 698        pr_debug("%s rm_list_nr %d",
 699                 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr);
 700
 701        msk_owned_by_me(msk);
 702
 703        if (!rm_list->nr)
 704                return;
 705
 706        if (list_empty(&msk->conn_list))
 707                return;
 708
 709        for (i = 0; i < rm_list->nr; i++) {
 710                list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
 711                        struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
 712                        int how = RCV_SHUTDOWN | SEND_SHUTDOWN;
 713                        u8 id = subflow->local_id;
 714
 715                        if (rm_type == MPTCP_MIB_RMADDR)
 716                                id = subflow->remote_id;
 717
 718                        if (rm_list->ids[i] != id)
 719                                continue;
 720
 721                        pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u",
 722                                 rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow",
 723                                 i, rm_list->ids[i], subflow->local_id, subflow->remote_id);
 724                        spin_unlock_bh(&msk->pm.lock);
 725                        mptcp_subflow_shutdown(sk, ssk, how);
 726                        mptcp_close_ssk(sk, ssk, subflow);
 727                        spin_lock_bh(&msk->pm.lock);
 728
 729                        if (rm_type == MPTCP_MIB_RMADDR) {
 730                                msk->pm.add_addr_accepted--;
 731                                WRITE_ONCE(msk->pm.accept_addr, true);
 732                        } else if (rm_type == MPTCP_MIB_RMSUBFLOW) {
 733                                msk->pm.local_addr_used--;
 734                        }
 735                        msk->pm.subflows--;
 736                        __MPTCP_INC_STATS(sock_net(sk), rm_type);
 737                }
 738        }
 739}
 740
 741static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
 742{
 743        mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR);
 744}
 745
 746void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
 747                                     const struct mptcp_rm_list *rm_list)
 748{
 749        mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW);
 750}
 751
 752void mptcp_pm_nl_work(struct mptcp_sock *msk)
 753{
 754        struct mptcp_pm_data *pm = &msk->pm;
 755
 756        msk_owned_by_me(msk);
 757
 758        spin_lock_bh(&msk->pm.lock);
 759
 760        pr_debug("msk=%p status=%x", msk, pm->status);
 761        if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
 762                pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
 763                mptcp_pm_nl_add_addr_received(msk);
 764        }
 765        if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
 766                pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
 767                mptcp_pm_nl_addr_send_ack(msk);
 768        }
 769        if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
 770                pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
 771                mptcp_pm_nl_rm_addr_received(msk);
 772        }
 773        if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
 774                pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
 775                mptcp_pm_nl_fully_established(msk);
 776        }
 777        if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
 778                pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
 779                mptcp_pm_nl_subflow_established(msk);
 780        }
 781
 782        spin_unlock_bh(&msk->pm.lock);
 783}
 784
 785static bool address_use_port(struct mptcp_pm_addr_entry *entry)
 786{
 787        return (entry->flags &
 788                (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
 789                MPTCP_PM_ADDR_FLAG_SIGNAL;
 790}
 791
 792static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
 793                                             struct mptcp_pm_addr_entry *entry)
 794{
 795        struct mptcp_pm_addr_entry *cur;
 796        unsigned int addr_max;
 797        int ret = -EINVAL;
 798
 799        spin_lock_bh(&pernet->lock);
 800        /* to keep the code simple, don't do IDR-like allocation for address ID,
 801         * just bail when we exceed limits
 802         */
 803        if (pernet->next_id == MAX_ADDR_ID)
 804                pernet->next_id = 1;
 805        if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
 806                goto out;
 807        if (test_bit(entry->addr.id, pernet->id_bitmap))
 808                goto out;
 809
 810        /* do not insert duplicate address, differentiate on port only
 811         * singled addresses
 812         */
 813        list_for_each_entry(cur, &pernet->local_addr_list, list) {
 814                if (addresses_equal(&cur->addr, &entry->addr,
 815                                    address_use_port(entry) &&
 816                                    address_use_port(cur)))
 817                        goto out;
 818        }
 819
 820        if (!entry->addr.id) {
 821find_next:
 822                entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
 823                                                    MAX_ADDR_ID + 1,
 824                                                    pernet->next_id);
 825                if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
 826                    pernet->next_id != 1) {
 827                        pernet->next_id = 1;
 828                        goto find_next;
 829                }
 830        }
 831
 832        if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
 833                goto out;
 834
 835        __set_bit(entry->addr.id, pernet->id_bitmap);
 836        if (entry->addr.id > pernet->next_id)
 837                pernet->next_id = entry->addr.id;
 838
 839        if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
 840                addr_max = pernet->add_addr_signal_max;
 841                WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
 842        }
 843        if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
 844                addr_max = pernet->local_addr_max;
 845                WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
 846        }
 847
 848        pernet->addrs++;
 849        list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
 850        ret = entry->addr.id;
 851
 852out:
 853        spin_unlock_bh(&pernet->lock);
 854        return ret;
 855}
 856
 857static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
 858                                            struct mptcp_pm_addr_entry *entry)
 859{
 860        struct sockaddr_storage addr;
 861        struct mptcp_sock *msk;
 862        struct socket *ssock;
 863        int backlog = 1024;
 864        int err;
 865
 866        err = sock_create_kern(sock_net(sk), entry->addr.family,
 867                               SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
 868        if (err)
 869                return err;
 870
 871        msk = mptcp_sk(entry->lsk->sk);
 872        if (!msk) {
 873                err = -EINVAL;
 874                goto out;
 875        }
 876
 877        ssock = __mptcp_nmpc_socket(msk);
 878        if (!ssock) {
 879                err = -EINVAL;
 880                goto out;
 881        }
 882
 883        mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
 884        err = kernel_bind(ssock, (struct sockaddr *)&addr,
 885                          sizeof(struct sockaddr_in));
 886        if (err) {
 887                pr_warn("kernel_bind error, err=%d", err);
 888                goto out;
 889        }
 890
 891        err = kernel_listen(ssock, backlog);
 892        if (err) {
 893                pr_warn("kernel_listen error, err=%d", err);
 894                goto out;
 895        }
 896
 897        return 0;
 898
 899out:
 900        sock_release(entry->lsk);
 901        return err;
 902}
 903
 904int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
 905{
 906        struct mptcp_pm_addr_entry *entry;
 907        struct mptcp_addr_info skc_local;
 908        struct mptcp_addr_info msk_local;
 909        struct pm_nl_pernet *pernet;
 910        int ret = -1;
 911
 912        if (WARN_ON_ONCE(!msk))
 913                return -1;
 914
 915        /* The 0 ID mapping is defined by the first subflow, copied into the msk
 916         * addr
 917         */
 918        local_address((struct sock_common *)msk, &msk_local);
 919        local_address((struct sock_common *)skc, &skc_local);
 920        if (addresses_equal(&msk_local, &skc_local, false))
 921                return 0;
 922
 923        if (address_zero(&skc_local))
 924                return 0;
 925
 926        pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
 927
 928        rcu_read_lock();
 929        list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
 930                if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
 931                        ret = entry->addr.id;
 932                        break;
 933                }
 934        }
 935        rcu_read_unlock();
 936        if (ret >= 0)
 937                return ret;
 938
 939        /* address not found, add to local list */
 940        entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
 941        if (!entry)
 942                return -ENOMEM;
 943
 944        entry->addr = skc_local;
 945        entry->addr.id = 0;
 946        entry->addr.port = 0;
 947        entry->ifindex = 0;
 948        entry->flags = 0;
 949        entry->lsk = NULL;
 950        ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
 951        if (ret < 0)
 952                kfree(entry);
 953
 954        return ret;
 955}
 956
 957void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
 958{
 959        struct mptcp_pm_data *pm = &msk->pm;
 960        bool subflows;
 961
 962        subflows = !!mptcp_pm_get_subflows_max(msk);
 963        WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
 964                   !!mptcp_pm_get_add_addr_signal_max(msk));
 965        WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
 966        WRITE_ONCE(pm->accept_subflow, subflows);
 967}
 968
 969#define MPTCP_PM_CMD_GRP_OFFSET       0
 970#define MPTCP_PM_EV_GRP_OFFSET        1
 971
 972static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
 973        [MPTCP_PM_CMD_GRP_OFFSET]       = { .name = MPTCP_PM_CMD_GRP_NAME, },
 974        [MPTCP_PM_EV_GRP_OFFSET]        = { .name = MPTCP_PM_EV_GRP_NAME,
 975                                            .flags = GENL_UNS_ADMIN_PERM,
 976                                          },
 977};
 978
 979static const struct nla_policy
 980mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
 981        [MPTCP_PM_ADDR_ATTR_FAMILY]     = { .type       = NLA_U16,      },
 982        [MPTCP_PM_ADDR_ATTR_ID]         = { .type       = NLA_U8,       },
 983        [MPTCP_PM_ADDR_ATTR_ADDR4]      = { .type       = NLA_U32,      },
 984        [MPTCP_PM_ADDR_ATTR_ADDR6]      =
 985                NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
 986        [MPTCP_PM_ADDR_ATTR_PORT]       = { .type       = NLA_U16       },
 987        [MPTCP_PM_ADDR_ATTR_FLAGS]      = { .type       = NLA_U32       },
 988        [MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type       = NLA_S32       },
 989};
 990
 991static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
 992        [MPTCP_PM_ATTR_ADDR]            =
 993                                        NLA_POLICY_NESTED(mptcp_pm_addr_policy),
 994        [MPTCP_PM_ATTR_RCV_ADD_ADDRS]   = { .type       = NLA_U32,      },
 995        [MPTCP_PM_ATTR_SUBFLOWS]        = { .type       = NLA_U32,      },
 996};
 997
 998void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk)
 999{
1000        struct mptcp_subflow_context *iter, *subflow = mptcp_subflow_ctx(ssk);
1001        struct sock *sk = (struct sock *)msk;
1002        unsigned int active_max_loss_cnt;
1003        struct net *net = sock_net(sk);
1004        unsigned int stale_loss_cnt;
1005        bool slow;
1006
1007        stale_loss_cnt = mptcp_stale_loss_cnt(net);
1008        if (subflow->stale || !stale_loss_cnt || subflow->stale_count <= stale_loss_cnt)
1009                return;
1010
1011        /* look for another available subflow not in loss state */
1012        active_max_loss_cnt = max_t(int, stale_loss_cnt - 1, 1);
1013        mptcp_for_each_subflow(msk, iter) {
1014                if (iter != subflow && mptcp_subflow_active(iter) &&
1015                    iter->stale_count < active_max_loss_cnt) {
1016                        /* we have some alternatives, try to mark this subflow as idle ...*/
1017                        slow = lock_sock_fast(ssk);
1018                        if (!tcp_rtx_and_write_queues_empty(ssk)) {
1019                                subflow->stale = 1;
1020                                __mptcp_retransmit_pending_data(sk);
1021                                MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE);
1022                        }
1023                        unlock_sock_fast(ssk, slow);
1024
1025                        /* always try to push the pending data regarless of re-injections:
1026                         * we can possibly use backup subflows now, and subflow selection
1027                         * is cheap under the msk socket lock
1028                         */
1029                        __mptcp_push_pending(sk, 0);
1030                        return;
1031                }
1032        }
1033}
1034
1035static int mptcp_pm_family_to_addr(int family)
1036{
1037#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1038        if (family == AF_INET6)
1039                return MPTCP_PM_ADDR_ATTR_ADDR6;
1040#endif
1041        return MPTCP_PM_ADDR_ATTR_ADDR4;
1042}
1043
1044static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
1045                               bool require_family,
1046                               struct mptcp_pm_addr_entry *entry)
1047{
1048        struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
1049        int err, addr_addr;
1050
1051        if (!attr) {
1052                GENL_SET_ERR_MSG(info, "missing address info");
1053                return -EINVAL;
1054        }
1055
1056        /* no validation needed - was already done via nested policy */
1057        err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
1058                                          mptcp_pm_addr_policy, info->extack);
1059        if (err)
1060                return err;
1061
1062        memset(entry, 0, sizeof(*entry));
1063        if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
1064                if (!require_family)
1065                        goto skip_family;
1066
1067                NL_SET_ERR_MSG_ATTR(info->extack, attr,
1068                                    "missing family");
1069                return -EINVAL;
1070        }
1071
1072        entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
1073        if (entry->addr.family != AF_INET
1074#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1075            && entry->addr.family != AF_INET6
1076#endif
1077            ) {
1078                NL_SET_ERR_MSG_ATTR(info->extack, attr,
1079                                    "unknown address family");
1080                return -EINVAL;
1081        }
1082        addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
1083        if (!tb[addr_addr]) {
1084                NL_SET_ERR_MSG_ATTR(info->extack, attr,
1085                                    "missing address data");
1086                return -EINVAL;
1087        }
1088
1089#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1090        if (entry->addr.family == AF_INET6)
1091                entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
1092        else
1093#endif
1094                entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);
1095
1096skip_family:
1097        if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
1098                u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);
1099
1100                entry->ifindex = val;
1101        }
1102
1103        if (tb[MPTCP_PM_ADDR_ATTR_ID])
1104                entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);
1105
1106        if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
1107                entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
1108
1109        if (tb[MPTCP_PM_ADDR_ATTR_PORT]) {
1110                if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
1111                        NL_SET_ERR_MSG_ATTR(info->extack, attr,
1112                                            "flags must have signal when using port");
1113                        return -EINVAL;
1114                }
1115                entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
1116        }
1117
1118        return 0;
1119}
1120
1121static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
1122{
1123        return net_generic(genl_info_net(info), pm_nl_pernet_id);
1124}
1125
1126static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
1127{
1128        struct mptcp_sock *msk;
1129        long s_slot = 0, s_num = 0;
1130
1131        while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1132                struct sock *sk = (struct sock *)msk;
1133
1134                if (!READ_ONCE(msk->fully_established))
1135                        goto next;
1136
1137                lock_sock(sk);
1138                spin_lock_bh(&msk->pm.lock);
1139                mptcp_pm_create_subflow_or_signal_addr(msk);
1140                spin_unlock_bh(&msk->pm.lock);
1141                release_sock(sk);
1142
1143next:
1144                sock_put(sk);
1145                cond_resched();
1146        }
1147
1148        return 0;
1149}
1150
1151static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
1152{
1153        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1154        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1155        struct mptcp_pm_addr_entry addr, *entry;
1156        int ret;
1157
1158        ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1159        if (ret < 0)
1160                return ret;
1161
1162        entry = kmalloc(sizeof(*entry), GFP_KERNEL);
1163        if (!entry) {
1164                GENL_SET_ERR_MSG(info, "can't allocate addr");
1165                return -ENOMEM;
1166        }
1167
1168        *entry = addr;
1169        if (entry->addr.port) {
1170                ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
1171                if (ret) {
1172                        GENL_SET_ERR_MSG(info, "create listen socket error");
1173                        kfree(entry);
1174                        return ret;
1175                }
1176        }
1177        ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
1178        if (ret < 0) {
1179                GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
1180                if (entry->lsk)
1181                        sock_release(entry->lsk);
1182                kfree(entry);
1183                return ret;
1184        }
1185
1186        mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
1187
1188        return 0;
1189}
1190
1191static struct mptcp_pm_addr_entry *
1192__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
1193{
1194        struct mptcp_pm_addr_entry *entry;
1195
1196        list_for_each_entry(entry, &pernet->local_addr_list, list) {
1197                if (entry->addr.id == id)
1198                        return entry;
1199        }
1200        return NULL;
1201}
1202
1203int mptcp_pm_get_flags_and_ifindex_by_id(struct net *net, unsigned int id,
1204                                         u8 *flags, int *ifindex)
1205{
1206        struct mptcp_pm_addr_entry *entry;
1207
1208        *flags = 0;
1209        *ifindex = 0;
1210
1211        if (id) {
1212                rcu_read_lock();
1213                entry = __lookup_addr_by_id(net_generic(net, pm_nl_pernet_id), id);
1214                if (entry) {
1215                        *flags = entry->flags;
1216                        *ifindex = entry->ifindex;
1217                }
1218                rcu_read_unlock();
1219        }
1220
1221        return 0;
1222}
1223
1224static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
1225                                      struct mptcp_addr_info *addr)
1226{
1227        struct mptcp_pm_add_entry *entry;
1228
1229        entry = mptcp_pm_del_add_timer(msk, addr, false);
1230        if (entry) {
1231                list_del(&entry->list);
1232                kfree(entry);
1233                return true;
1234        }
1235
1236        return false;
1237}
1238
1239static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
1240                                      struct mptcp_addr_info *addr,
1241                                      bool force)
1242{
1243        struct mptcp_rm_list list = { .nr = 0 };
1244        bool ret;
1245
1246        list.ids[list.nr++] = addr->id;
1247
1248        ret = remove_anno_list_by_saddr(msk, addr);
1249        if (ret || force) {
1250                spin_lock_bh(&msk->pm.lock);
1251                mptcp_pm_remove_addr(msk, &list);
1252                spin_unlock_bh(&msk->pm.lock);
1253        }
1254        return ret;
1255}
1256
1257static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
1258                                                   struct mptcp_addr_info *addr)
1259{
1260        struct mptcp_sock *msk;
1261        long s_slot = 0, s_num = 0;
1262        struct mptcp_rm_list list = { .nr = 0 };
1263
1264        pr_debug("remove_id=%d", addr->id);
1265
1266        list.ids[list.nr++] = addr->id;
1267
1268        while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1269                struct sock *sk = (struct sock *)msk;
1270                bool remove_subflow;
1271
1272                if (list_empty(&msk->conn_list)) {
1273                        mptcp_pm_remove_anno_addr(msk, addr, false);
1274                        goto next;
1275                }
1276
1277                lock_sock(sk);
1278                remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
1279                mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
1280                if (remove_subflow)
1281                        mptcp_pm_remove_subflow(msk, &list);
1282                release_sock(sk);
1283
1284next:
1285                sock_put(sk);
1286                cond_resched();
1287        }
1288
1289        return 0;
1290}
1291
1292/* caller must ensure the RCU grace period is already elapsed */
1293static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry)
1294{
1295        if (entry->lsk)
1296                sock_release(entry->lsk);
1297        kfree(entry);
1298}
1299
1300static int mptcp_nl_remove_id_zero_address(struct net *net,
1301                                           struct mptcp_addr_info *addr)
1302{
1303        struct mptcp_rm_list list = { .nr = 0 };
1304        long s_slot = 0, s_num = 0;
1305        struct mptcp_sock *msk;
1306
1307        list.ids[list.nr++] = 0;
1308
1309        while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1310                struct sock *sk = (struct sock *)msk;
1311                struct mptcp_addr_info msk_local;
1312
1313                if (list_empty(&msk->conn_list))
1314                        goto next;
1315
1316                local_address((struct sock_common *)msk, &msk_local);
1317                if (!addresses_equal(&msk_local, addr, addr->port))
1318                        goto next;
1319
1320                lock_sock(sk);
1321                spin_lock_bh(&msk->pm.lock);
1322                mptcp_pm_remove_addr(msk, &list);
1323                mptcp_pm_nl_rm_subflow_received(msk, &list);
1324                spin_unlock_bh(&msk->pm.lock);
1325                release_sock(sk);
1326
1327next:
1328                sock_put(sk);
1329                cond_resched();
1330        }
1331
1332        return 0;
1333}
1334
1335static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
1336{
1337        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1338        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1339        struct mptcp_pm_addr_entry addr, *entry;
1340        unsigned int addr_max;
1341        int ret;
1342
1343        ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1344        if (ret < 0)
1345                return ret;
1346
1347        /* the zero id address is special: the first address used by the msk
1348         * always gets such an id, so different subflows can have different zero
1349         * id addresses. Additionally zero id is not accounted for in id_bitmap.
1350         * Let's use an 'mptcp_rm_list' instead of the common remove code.
1351         */
1352        if (addr.addr.id == 0)
1353                return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr);
1354
1355        spin_lock_bh(&pernet->lock);
1356        entry = __lookup_addr_by_id(pernet, addr.addr.id);
1357        if (!entry) {
1358                GENL_SET_ERR_MSG(info, "address not found");
1359                spin_unlock_bh(&pernet->lock);
1360                return -EINVAL;
1361        }
1362        if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
1363                addr_max = pernet->add_addr_signal_max;
1364                WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
1365        }
1366        if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
1367                addr_max = pernet->local_addr_max;
1368                WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
1369        }
1370
1371        pernet->addrs--;
1372        list_del_rcu(&entry->list);
1373        __clear_bit(entry->addr.id, pernet->id_bitmap);
1374        spin_unlock_bh(&pernet->lock);
1375
1376        mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
1377        synchronize_rcu();
1378        __mptcp_pm_release_addr_entry(entry);
1379
1380        return ret;
1381}
1382
1383static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
1384                                               struct list_head *rm_list)
1385{
1386        struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 };
1387        struct mptcp_pm_addr_entry *entry;
1388
1389        list_for_each_entry(entry, rm_list, list) {
1390                if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) &&
1391                    alist.nr < MPTCP_RM_IDS_MAX &&
1392                    slist.nr < MPTCP_RM_IDS_MAX) {
1393                        alist.ids[alist.nr++] = entry->addr.id;
1394                        slist.ids[slist.nr++] = entry->addr.id;
1395                } else if (remove_anno_list_by_saddr(msk, &entry->addr) &&
1396                         alist.nr < MPTCP_RM_IDS_MAX) {
1397                        alist.ids[alist.nr++] = entry->addr.id;
1398                }
1399        }
1400
1401        if (alist.nr) {
1402                spin_lock_bh(&msk->pm.lock);
1403                mptcp_pm_remove_addr(msk, &alist);
1404                spin_unlock_bh(&msk->pm.lock);
1405        }
1406        if (slist.nr)
1407                mptcp_pm_remove_subflow(msk, &slist);
1408}
1409
/* Apply the removal of all the endpoints in @rm_list to every MPTCP socket
 * of @net, holding each socket lock in turn. An empty list is a no-op.
 */
static void mptcp_nl_remove_addrs_list(struct net *net,
                                       struct list_head *rm_list)
{
        long slot = 0, num = 0;
        struct mptcp_sock *msk;

        if (list_empty(rm_list))
                return;

        for (;;) {
                struct sock *sk;

                msk = mptcp_token_iter_next(net, &slot, &num);
                if (!msk)
                        break;

                sk = (struct sock *)msk;
                lock_sock(sk);
                mptcp_pm_remove_addrs_and_subflows(msk, rm_list);
                release_sock(sk);

                sock_put(sk);
                cond_resched();
        }
}
1430
1431/* caller must ensure the RCU grace period is already elapsed */
1432static void __flush_addrs(struct list_head *list)
1433{
1434        while (!list_empty(list)) {
1435                struct mptcp_pm_addr_entry *cur;
1436
1437                cur = list_entry(list->next,
1438                                 struct mptcp_pm_addr_entry, list);
1439                list_del_rcu(&cur->list);
1440                __mptcp_pm_release_addr_entry(cur);
1441        }
1442}
1443
1444static void __reset_counters(struct pm_nl_pernet *pernet)
1445{
1446        WRITE_ONCE(pernet->add_addr_signal_max, 0);
1447        WRITE_ONCE(pernet->add_addr_accept_max, 0);
1448        WRITE_ONCE(pernet->local_addr_max, 0);
1449        pernet->addrs = 0;
1450}
1451
1452static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
1453{
1454        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1455        LIST_HEAD(free_list);
1456
1457        spin_lock_bh(&pernet->lock);
1458        list_splice_init(&pernet->local_addr_list, &free_list);
1459        __reset_counters(pernet);
1460        pernet->next_id = 1;
1461        bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
1462        spin_unlock_bh(&pernet->lock);
1463        mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list);
1464        synchronize_rcu();
1465        __flush_addrs(&free_list);
1466        return 0;
1467}
1468
1469static int mptcp_nl_fill_addr(struct sk_buff *skb,
1470                              struct mptcp_pm_addr_entry *entry)
1471{
1472        struct mptcp_addr_info *addr = &entry->addr;
1473        struct nlattr *attr;
1474
1475        attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
1476        if (!attr)
1477                return -EMSGSIZE;
1478
1479        if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
1480                goto nla_put_failure;
1481        if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
1482                goto nla_put_failure;
1483        if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
1484                goto nla_put_failure;
1485        if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags))
1486                goto nla_put_failure;
1487        if (entry->ifindex &&
1488            nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex))
1489                goto nla_put_failure;
1490
1491        if (addr->family == AF_INET &&
1492            nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
1493                            addr->addr.s_addr))
1494                goto nla_put_failure;
1495#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1496        else if (addr->family == AF_INET6 &&
1497                 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
1498                goto nla_put_failure;
1499#endif
1500        nla_nest_end(skb, attr);
1501        return 0;
1502
1503nla_put_failure:
1504        nla_nest_cancel(skb, attr);
1505        return -EMSGSIZE;
1506}
1507
1508static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
1509{
1510        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1511        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1512        struct mptcp_pm_addr_entry addr, *entry;
1513        struct sk_buff *msg;
1514        void *reply;
1515        int ret;
1516
1517        ret = mptcp_pm_parse_addr(attr, info, false, &addr);
1518        if (ret < 0)
1519                return ret;
1520
1521        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1522        if (!msg)
1523                return -ENOMEM;
1524
1525        reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1526                                  info->genlhdr->cmd);
1527        if (!reply) {
1528                GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1529                ret = -EMSGSIZE;
1530                goto fail;
1531        }
1532
1533        spin_lock_bh(&pernet->lock);
1534        entry = __lookup_addr_by_id(pernet, addr.addr.id);
1535        if (!entry) {
1536                GENL_SET_ERR_MSG(info, "address not found");
1537                ret = -EINVAL;
1538                goto unlock_fail;
1539        }
1540
1541        ret = mptcp_nl_fill_addr(msg, entry);
1542        if (ret)
1543                goto unlock_fail;
1544
1545        genlmsg_end(msg, reply);
1546        ret = genlmsg_reply(msg, info);
1547        spin_unlock_bh(&pernet->lock);
1548        return ret;
1549
1550unlock_fail:
1551        spin_unlock_bh(&pernet->lock);
1552
1553fail:
1554        nlmsg_free(msg);
1555        return ret;
1556}
1557
1558static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
1559                                   struct netlink_callback *cb)
1560{
1561        struct net *net = sock_net(msg->sk);
1562        struct mptcp_pm_addr_entry *entry;
1563        struct pm_nl_pernet *pernet;
1564        int id = cb->args[0];
1565        void *hdr;
1566        int i;
1567
1568        pernet = net_generic(net, pm_nl_pernet_id);
1569
1570        spin_lock_bh(&pernet->lock);
1571        for (i = id; i < MAX_ADDR_ID + 1; i++) {
1572                if (test_bit(i, pernet->id_bitmap)) {
1573                        entry = __lookup_addr_by_id(pernet, i);
1574                        if (!entry)
1575                                break;
1576
1577                        if (entry->addr.id <= id)
1578                                continue;
1579
1580                        hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
1581                                          cb->nlh->nlmsg_seq, &mptcp_genl_family,
1582                                          NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
1583                        if (!hdr)
1584                                break;
1585
1586                        if (mptcp_nl_fill_addr(msg, entry) < 0) {
1587                                genlmsg_cancel(msg, hdr);
1588                                break;
1589                        }
1590
1591                        id = entry->addr.id;
1592                        genlmsg_end(msg, hdr);
1593                }
1594        }
1595        spin_unlock_bh(&pernet->lock);
1596
1597        cb->args[0] = id;
1598        return msg->len;
1599}
1600
1601static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
1602{
1603        struct nlattr *attr = info->attrs[id];
1604
1605        if (!attr)
1606                return 0;
1607
1608        *limit = nla_get_u32(attr);
1609        if (*limit > MPTCP_PM_ADDR_MAX) {
1610                GENL_SET_ERR_MSG(info, "limit greater than maximum");
1611                return -EINVAL;
1612        }
1613        return 0;
1614}
1615
1616static int
1617mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
1618{
1619        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1620        unsigned int rcv_addrs, subflows;
1621        int ret;
1622
1623        spin_lock_bh(&pernet->lock);
1624        rcv_addrs = pernet->add_addr_accept_max;
1625        ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
1626        if (ret)
1627                goto unlock;
1628
1629        subflows = pernet->subflows_max;
1630        ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
1631        if (ret)
1632                goto unlock;
1633
1634        WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
1635        WRITE_ONCE(pernet->subflows_max, subflows);
1636
1637unlock:
1638        spin_unlock_bh(&pernet->lock);
1639        return ret;
1640}
1641
1642static int
1643mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
1644{
1645        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1646        struct sk_buff *msg;
1647        void *reply;
1648
1649        msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1650        if (!msg)
1651                return -ENOMEM;
1652
1653        reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1654                                  MPTCP_PM_CMD_GET_LIMITS);
1655        if (!reply)
1656                goto fail;
1657
1658        if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
1659                        READ_ONCE(pernet->add_addr_accept_max)))
1660                goto fail;
1661
1662        if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
1663                        READ_ONCE(pernet->subflows_max)))
1664                goto fail;
1665
1666        genlmsg_end(msg, reply);
1667        return genlmsg_reply(msg, info);
1668
1669fail:
1670        GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1671        nlmsg_free(msg);
1672        return -EMSGSIZE;
1673}
1674
1675static int mptcp_nl_addr_backup(struct net *net,
1676                                struct mptcp_addr_info *addr,
1677                                u8 bkup)
1678{
1679        long s_slot = 0, s_num = 0;
1680        struct mptcp_sock *msk;
1681        int ret = -EINVAL;
1682
1683        while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
1684                struct sock *sk = (struct sock *)msk;
1685
1686                if (list_empty(&msk->conn_list))
1687                        goto next;
1688
1689                lock_sock(sk);
1690                spin_lock_bh(&msk->pm.lock);
1691                ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
1692                spin_unlock_bh(&msk->pm.lock);
1693                release_sock(sk);
1694
1695next:
1696                sock_put(sk);
1697                cond_resched();
1698        }
1699
1700        return ret;
1701}
1702
1703static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
1704{
1705        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1706        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1707        struct mptcp_pm_addr_entry addr, *entry;
1708        struct net *net = sock_net(skb->sk);
1709        u8 bkup = 0;
1710        int ret;
1711
1712        ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1713        if (ret < 0)
1714                return ret;
1715
1716        if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
1717                bkup = 1;
1718
1719        list_for_each_entry(entry, &pernet->local_addr_list, list) {
1720                if (addresses_equal(&entry->addr, &addr.addr, true)) {
1721                        mptcp_nl_addr_backup(net, &entry->addr, bkup);
1722
1723                        if (bkup)
1724                                entry->flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
1725                        else
1726                                entry->flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
1727                }
1728        }
1729
1730        return 0;
1731}
1732
1733static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
1734{
1735        genlmsg_multicast_netns(&mptcp_genl_family, net,
1736                                nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
1737}
1738
1739static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
1740{
1741        const struct inet_sock *issk = inet_sk(ssk);
1742        const struct mptcp_subflow_context *sf;
1743
1744        if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
1745                return -EMSGSIZE;
1746
1747        switch (ssk->sk_family) {
1748        case AF_INET:
1749                if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
1750                        return -EMSGSIZE;
1751                if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
1752                        return -EMSGSIZE;
1753                break;
1754#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1755        case AF_INET6: {
1756                const struct ipv6_pinfo *np = inet6_sk(ssk);
1757
1758                if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
1759                        return -EMSGSIZE;
1760                if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
1761                        return -EMSGSIZE;
1762                break;
1763        }
1764#endif
1765        default:
1766                WARN_ON_ONCE(1);
1767                return -EMSGSIZE;
1768        }
1769
1770        if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
1771                return -EMSGSIZE;
1772        if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
1773                return -EMSGSIZE;
1774
1775        sf = mptcp_subflow_ctx(ssk);
1776        if (WARN_ON_ONCE(!sf))
1777                return -EINVAL;
1778
1779        if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
1780                return -EMSGSIZE;
1781
1782        if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
1783                return -EMSGSIZE;
1784
1785        return 0;
1786}
1787
1788static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
1789                                         const struct mptcp_sock *msk,
1790                                         const struct sock *ssk)
1791{
1792        const struct sock *sk = (const struct sock *)msk;
1793        const struct mptcp_subflow_context *sf;
1794        u8 sk_err;
1795
1796        if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1797                return -EMSGSIZE;
1798
1799        if (mptcp_event_add_subflow(skb, ssk))
1800                return -EMSGSIZE;
1801
1802        sf = mptcp_subflow_ctx(ssk);
1803        if (WARN_ON_ONCE(!sf))
1804                return -EINVAL;
1805
1806        if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
1807                return -EMSGSIZE;
1808
1809        if (ssk->sk_bound_dev_if &&
1810            nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
1811                return -EMSGSIZE;
1812
1813        sk_err = ssk->sk_err;
1814        if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
1815            nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
1816                return -EMSGSIZE;
1817
1818        return 0;
1819}
1820
/* Fill @skb for a subflow-level event on @ssk: the payload is exactly what
 * mptcp_event_put_token_and_ssk() emits, and its return value is passed
 * back unchanged.
 */
static int mptcp_event_sub_established(struct sk_buff *skb,
				       const struct mptcp_sock *msk,
				       const struct sock *ssk)
{
	int err;

	err = mptcp_event_put_token_and_ssk(skb, msk, ssk);

	return err;
}
1827
1828static int mptcp_event_sub_closed(struct sk_buff *skb,
1829                                  const struct mptcp_sock *msk,
1830                                  const struct sock *ssk)
1831{
1832        const struct mptcp_subflow_context *sf;
1833
1834        if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
1835                return -EMSGSIZE;
1836
1837        sf = mptcp_subflow_ctx(ssk);
1838        if (!sf->reset_seen)
1839                return 0;
1840
1841        if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason))
1842                return -EMSGSIZE;
1843
1844        if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient))
1845                return -EMSGSIZE;
1846
1847        return 0;
1848}
1849
1850static int mptcp_event_created(struct sk_buff *skb,
1851                               const struct mptcp_sock *msk,
1852                               const struct sock *ssk)
1853{
1854        int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
1855
1856        if (err)
1857                return err;
1858
1859        return mptcp_event_add_subflow(skb, ssk);
1860}
1861
1862void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
1863{
1864        struct net *net = sock_net((const struct sock *)msk);
1865        struct nlmsghdr *nlh;
1866        struct sk_buff *skb;
1867
1868        if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1869                return;
1870
1871        skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1872        if (!skb)
1873                return;
1874
1875        nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
1876        if (!nlh)
1877                goto nla_put_failure;
1878
1879        if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1880                goto nla_put_failure;
1881
1882        if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
1883                goto nla_put_failure;
1884
1885        genlmsg_end(skb, nlh);
1886        mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1887        return;
1888
1889nla_put_failure:
1890        kfree_skb(skb);
1891}
1892
1893void mptcp_event_addr_announced(const struct mptcp_sock *msk,
1894                                const struct mptcp_addr_info *info)
1895{
1896        struct net *net = sock_net((const struct sock *)msk);
1897        struct nlmsghdr *nlh;
1898        struct sk_buff *skb;
1899
1900        if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1901                return;
1902
1903        skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
1904        if (!skb)
1905                return;
1906
1907        nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
1908                          MPTCP_EVENT_ANNOUNCED);
1909        if (!nlh)
1910                goto nla_put_failure;
1911
1912        if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
1913                goto nla_put_failure;
1914
1915        if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
1916                goto nla_put_failure;
1917
1918        if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port))
1919                goto nla_put_failure;
1920
1921        switch (info->family) {
1922        case AF_INET:
1923                if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
1924                        goto nla_put_failure;
1925                break;
1926#if IS_ENABLED(CONFIG_MPTCP_IPV6)
1927        case AF_INET6:
1928                if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
1929                        goto nla_put_failure;
1930                break;
1931#endif
1932        default:
1933                WARN_ON_ONCE(1);
1934                goto nla_put_failure;
1935        }
1936
1937        genlmsg_end(skb, nlh);
1938        mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
1939        return;
1940
1941nla_put_failure:
1942        kfree_skb(skb);
1943}
1944
1945void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
1946                 const struct sock *ssk, gfp_t gfp)
1947{
1948        struct net *net = sock_net((const struct sock *)msk);
1949        struct nlmsghdr *nlh;
1950        struct sk_buff *skb;
1951
1952        if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
1953                return;
1954
1955        skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
1956        if (!skb)
1957                return;
1958
1959        nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
1960        if (!nlh)
1961                goto nla_put_failure;
1962
1963        switch (type) {
1964        case MPTCP_EVENT_UNSPEC:
1965                WARN_ON_ONCE(1);
1966                break;
1967        case MPTCP_EVENT_CREATED:
1968        case MPTCP_EVENT_ESTABLISHED:
1969                if (mptcp_event_created(skb, msk, ssk) < 0)
1970                        goto nla_put_failure;
1971                break;
1972        case MPTCP_EVENT_CLOSED:
1973                if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
1974                        goto nla_put_failure;
1975                break;
1976        case MPTCP_EVENT_ANNOUNCED:
1977        case MPTCP_EVENT_REMOVED:
1978                /* call mptcp_event_addr_announced()/removed instead */
1979                WARN_ON_ONCE(1);
1980                break;
1981        case MPTCP_EVENT_SUB_ESTABLISHED:
1982        case MPTCP_EVENT_SUB_PRIORITY:
1983                if (mptcp_event_sub_established(skb, msk, ssk) < 0)
1984                        goto nla_put_failure;
1985                break;
1986        case MPTCP_EVENT_SUB_CLOSED:
1987                if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
1988                        goto nla_put_failure;
1989                break;
1990        }
1991
1992        genlmsg_end(skb, nlh);
1993        mptcp_nl_mcast_send(net, skb, gfp);
1994        return;
1995
1996nla_put_failure:
1997        kfree_skb(skb);
1998}
1999
2000static const struct genl_small_ops mptcp_pm_ops[] = {
2001        {
2002                .cmd    = MPTCP_PM_CMD_ADD_ADDR,
2003                .doit   = mptcp_nl_cmd_add_addr,
2004                .flags  = GENL_ADMIN_PERM,
2005        },
2006        {
2007                .cmd    = MPTCP_PM_CMD_DEL_ADDR,
2008                .doit   = mptcp_nl_cmd_del_addr,
2009                .flags  = GENL_ADMIN_PERM,
2010        },
2011        {
2012                .cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
2013                .doit   = mptcp_nl_cmd_flush_addrs,
2014                .flags  = GENL_ADMIN_PERM,
2015        },
2016        {
2017                .cmd    = MPTCP_PM_CMD_GET_ADDR,
2018                .doit   = mptcp_nl_cmd_get_addr,
2019                .dumpit   = mptcp_nl_cmd_dump_addrs,
2020        },
2021        {
2022                .cmd    = MPTCP_PM_CMD_SET_LIMITS,
2023                .doit   = mptcp_nl_cmd_set_limits,
2024                .flags  = GENL_ADMIN_PERM,
2025        },
2026        {
2027                .cmd    = MPTCP_PM_CMD_GET_LIMITS,
2028                .doit   = mptcp_nl_cmd_get_limits,
2029        },
2030        {
2031                .cmd    = MPTCP_PM_CMD_SET_FLAGS,
2032                .doit   = mptcp_nl_cmd_set_flags,
2033                .flags  = GENL_ADMIN_PERM,
2034        },
2035};
2036
2037static struct genl_family mptcp_genl_family __ro_after_init = {
2038        .name           = MPTCP_PM_NAME,
2039        .version        = MPTCP_PM_VER,
2040        .maxattr        = MPTCP_PM_ATTR_MAX,
2041        .policy         = mptcp_pm_policy,
2042        .netnsok        = true,
2043        .module         = THIS_MODULE,
2044        .small_ops      = mptcp_pm_ops,
2045        .n_small_ops    = ARRAY_SIZE(mptcp_pm_ops),
2046        .mcgrps         = mptcp_pm_mcgrps,
2047        .n_mcgrps       = ARRAY_SIZE(mptcp_pm_mcgrps),
2048};
2049
2050static int __net_init pm_nl_init_net(struct net *net)
2051{
2052        struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
2053
2054        INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
2055        pernet->next_id = 1;
2056        pernet->stale_loss_cnt = 4;
2057        spin_lock_init(&pernet->lock);
2058
2059        /* No need to initialize other pernet fields, the struct is zeroed at
2060         * allocation time.
2061         */
2062
2063        return 0;
2064}
2065
2066static void __net_exit pm_nl_exit_net(struct list_head *net_list)
2067{
2068        struct net *net;
2069
2070        list_for_each_entry(net, net_list, exit_list) {
2071                struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
2072
2073                /* net is removed from namespace list, can't race with
2074                 * other modifiers, also netns core already waited for a
2075                 * RCU grace period.
2076                 */
2077                __flush_addrs(&pernet->local_addr_list);
2078        }
2079}
2080
2081static struct pernet_operations mptcp_pm_pernet_ops = {
2082        .init = pm_nl_init_net,
2083        .exit_batch = pm_nl_exit_net,
2084        .id = &pm_nl_pernet_id,
2085        .size = sizeof(struct pm_nl_pernet),
2086};
2087
2088void __init mptcp_pm_nl_init(void)
2089{
2090        if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
2091                panic("Failed to register MPTCP PM pernet subsystem.\n");
2092
2093        if (genl_register_family(&mptcp_genl_family))
2094                panic("Failed to register MPTCP PM netlink family\n");
2095}
2096