linux/net/bridge/br_multicast.c
<<
>>
Prefs
   1/*
   2 * Bridge multicast support.
   3 *
   4 * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
   5 *
   6 * This program is free software; you can redistribute it and/or modify it
   7 * under the terms of the GNU General Public License as published by the Free
   8 * Software Foundation; either version 2 of the License, or (at your option)
   9 * any later version.
  10 *
  11 */
  12
  13#include <linux/err.h>
  14#include <linux/if_ether.h>
  15#include <linux/igmp.h>
  16#include <linux/jhash.h>
  17#include <linux/kernel.h>
  18#include <linux/log2.h>
  19#include <linux/netdevice.h>
  20#include <linux/netfilter_bridge.h>
  21#include <linux/random.h>
  22#include <linux/rculist.h>
  23#include <linux/skbuff.h>
  24#include <linux/slab.h>
  25#include <linux/timer.h>
  26#include <net/ip.h>
  27#if IS_ENABLED(CONFIG_IPV6)
  28#include <net/ipv6.h>
  29#include <net/mld.h>
  30#include <net/ip6_checksum.h>
  31#endif
  32
  33#include "br_private.h"
  34
  35static void br_multicast_start_querier(struct net_bridge *br);
  36unsigned int br_mdb_rehash_seq;
  37
  38static inline int br_ip_equal(const struct br_ip *a, const struct br_ip *b)
  39{
  40        if (a->proto != b->proto)
  41                return 0;
  42        if (a->vid != b->vid)
  43                return 0;
  44        switch (a->proto) {
  45        case htons(ETH_P_IP):
  46                return a->u.ip4 == b->u.ip4;
  47#if IS_ENABLED(CONFIG_IPV6)
  48        case htons(ETH_P_IPV6):
  49                return ipv6_addr_equal(&a->u.ip6, &b->u.ip6);
  50#endif
  51        }
  52        return 0;
  53}
  54
  55static inline int __br_ip4_hash(struct net_bridge_mdb_htable *mdb, __be32 ip,
  56                                __u16 vid)
  57{
  58        return jhash_2words((__force u32)ip, vid, mdb->secret) & (mdb->max - 1);
  59}
  60
  61#if IS_ENABLED(CONFIG_IPV6)
  62static inline int __br_ip6_hash(struct net_bridge_mdb_htable *mdb,
  63                                const struct in6_addr *ip,
  64                                __u16 vid)
  65{
  66        return jhash_2words(ipv6_addr_hash(ip), vid,
  67                            mdb->secret) & (mdb->max - 1);
  68}
  69#endif
  70
  71static inline int br_ip_hash(struct net_bridge_mdb_htable *mdb,
  72                             struct br_ip *ip)
  73{
  74        switch (ip->proto) {
  75        case htons(ETH_P_IP):
  76                return __br_ip4_hash(mdb, ip->u.ip4, ip->vid);
  77#if IS_ENABLED(CONFIG_IPV6)
  78        case htons(ETH_P_IPV6):
  79                return __br_ip6_hash(mdb, &ip->u.ip6, ip->vid);
  80#endif
  81        }
  82        return 0;
  83}
  84
  85static struct net_bridge_mdb_entry *__br_mdb_ip_get(
  86        struct net_bridge_mdb_htable *mdb, struct br_ip *dst, int hash)
  87{
  88        struct net_bridge_mdb_entry *mp;
  89
  90        hlist_for_each_entry_rcu(mp, &mdb->mhash[hash], hlist[mdb->ver]) {
  91                if (br_ip_equal(&mp->addr, dst))
  92                        return mp;
  93        }
  94
  95        return NULL;
  96}
  97
  98struct net_bridge_mdb_entry *br_mdb_ip_get(struct net_bridge_mdb_htable *mdb,
  99                                           struct br_ip *dst)
 100{
 101        if (!mdb)
 102                return NULL;
 103
 104        return __br_mdb_ip_get(mdb, dst, br_ip_hash(mdb, dst));
 105}
 106
 107static struct net_bridge_mdb_entry *br_mdb_ip4_get(
 108        struct net_bridge_mdb_htable *mdb, __be32 dst, __u16 vid)
 109{
 110        struct br_ip br_dst;
 111
 112        br_dst.u.ip4 = dst;
 113        br_dst.proto = htons(ETH_P_IP);
 114        br_dst.vid = vid;
 115
 116        return br_mdb_ip_get(mdb, &br_dst);
 117}
 118
 119#if IS_ENABLED(CONFIG_IPV6)
 120static struct net_bridge_mdb_entry *br_mdb_ip6_get(
 121        struct net_bridge_mdb_htable *mdb, const struct in6_addr *dst,
 122        __u16 vid)
 123{
 124        struct br_ip br_dst;
 125
 126        br_dst.u.ip6 = *dst;
 127        br_dst.proto = htons(ETH_P_IPV6);
 128        br_dst.vid = vid;
 129
 130        return br_mdb_ip_get(mdb, &br_dst);
 131}
 132#endif
 133
 134struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br,
 135                                        struct sk_buff *skb, u16 vid)
 136{
 137        struct net_bridge_mdb_htable *mdb = rcu_dereference(br->mdb);
 138        struct br_ip ip;
 139
 140        if (br->multicast_disabled)
 141                return NULL;
 142
 143        if (BR_INPUT_SKB_CB(skb)->igmp)
 144                return NULL;
 145
 146        ip.proto = skb->protocol;
 147        ip.vid = vid;
 148
 149        switch (skb->protocol) {
 150        case htons(ETH_P_IP):
 151                ip.u.ip4 = ip_hdr(skb)->daddr;
 152                break;
 153#if IS_ENABLED(CONFIG_IPV6)
 154        case htons(ETH_P_IPV6):
 155                ip.u.ip6 = ipv6_hdr(skb)->daddr;
 156                break;
 157#endif
 158        default:
 159                return NULL;
 160        }
 161
 162        return br_mdb_ip_get(mdb, &ip);
 163}
 164
 165static void br_mdb_free(struct rcu_head *head)
 166{
 167        struct net_bridge_mdb_htable *mdb =
 168                container_of(head, struct net_bridge_mdb_htable, rcu);
 169        struct net_bridge_mdb_htable *old = mdb->old;
 170
 171        mdb->old = NULL;
 172        kfree(old->mhash);
 173        kfree(old);
 174}
 175
 176static int br_mdb_copy(struct net_bridge_mdb_htable *new,
 177                       struct net_bridge_mdb_htable *old,
 178                       int elasticity)
 179{
 180        struct net_bridge_mdb_entry *mp;
 181        int maxlen;
 182        int len;
 183        int i;
 184
 185        for (i = 0; i < old->max; i++)
 186                hlist_for_each_entry(mp, &old->mhash[i], hlist[old->ver])
 187                        hlist_add_head(&mp->hlist[new->ver],
 188                                       &new->mhash[br_ip_hash(new, &mp->addr)]);
 189
 190        if (!elasticity)
 191                return 0;
 192
 193        maxlen = 0;
 194        for (i = 0; i < new->max; i++) {
 195                len = 0;
 196                hlist_for_each_entry(mp, &new->mhash[i], hlist[new->ver])
 197                        len++;
 198                if (len > maxlen)
 199                        maxlen = len;
 200        }
 201
 202        return maxlen > elasticity ? -EINVAL : 0;
 203}
 204
 205void br_multicast_free_pg(struct rcu_head *head)
 206{
 207        struct net_bridge_port_group *p =
 208                container_of(head, struct net_bridge_port_group, rcu);
 209
 210        kfree(p);
 211}
 212
 213static void br_multicast_free_group(struct rcu_head *head)
 214{
 215        struct net_bridge_mdb_entry *mp =
 216                container_of(head, struct net_bridge_mdb_entry, rcu);
 217
 218        kfree(mp);
 219}
 220
 221static void br_multicast_group_expired(unsigned long data)
 222{
 223        struct net_bridge_mdb_entry *mp = (void *)data;
 224        struct net_bridge *br = mp->br;
 225        struct net_bridge_mdb_htable *mdb;
 226
 227        spin_lock(&br->multicast_lock);
 228        if (!netif_running(br->dev) || timer_pending(&mp->timer))
 229                goto out;
 230
 231        mp->mglist = false;
 232
 233        if (mp->ports)
 234                goto out;
 235
 236        mdb = mlock_dereference(br->mdb, br);
 237
 238        hlist_del_rcu(&mp->hlist[mdb->ver]);
 239        mdb->size--;
 240
 241        call_rcu_bh(&mp->rcu, br_multicast_free_group);
 242
 243out:
 244        spin_unlock(&br->multicast_lock);
 245}
 246
 247static void br_multicast_del_pg(struct net_bridge *br,
 248                                struct net_bridge_port_group *pg)
 249{
 250        struct net_bridge_mdb_htable *mdb;
 251        struct net_bridge_mdb_entry *mp;
 252        struct net_bridge_port_group *p;
 253        struct net_bridge_port_group __rcu **pp;
 254
 255        mdb = mlock_dereference(br->mdb, br);
 256
 257        mp = br_mdb_ip_get(mdb, &pg->addr);
 258        if (WARN_ON(!mp))
 259                return;
 260
 261        for (pp = &mp->ports;
 262             (p = mlock_dereference(*pp, br)) != NULL;
 263             pp = &p->next) {
 264                if (p != pg)
 265                        continue;
 266
 267                rcu_assign_pointer(*pp, p->next);
 268                hlist_del_init(&p->mglist);
 269                del_timer(&p->timer);
 270                call_rcu_bh(&p->rcu, br_multicast_free_pg);
 271
 272                if (!mp->ports && !mp->mglist &&
 273                    netif_running(br->dev))
 274                        mod_timer(&mp->timer, jiffies);
 275
 276                return;
 277        }
 278
 279        WARN_ON(1);
 280}
 281
 282static void br_multicast_port_group_expired(unsigned long data)
 283{
 284        struct net_bridge_port_group *pg = (void *)data;
 285        struct net_bridge *br = pg->port->br;
 286
 287        spin_lock(&br->multicast_lock);
 288        if (!netif_running(br->dev) || timer_pending(&pg->timer) ||
 289            hlist_unhashed(&pg->mglist) || pg->state & MDB_PERMANENT)
 290                goto out;
 291
 292        br_multicast_del_pg(br, pg);
 293
 294out:
 295        spin_unlock(&br->multicast_lock);
 296}
 297
 298static int br_mdb_rehash(struct net_bridge_mdb_htable __rcu **mdbp, int max,
 299                         int elasticity)
 300{
 301        struct net_bridge_mdb_htable *old = rcu_dereference_protected(*mdbp, 1);
 302        struct net_bridge_mdb_htable *mdb;
 303        int err;
 304
 305        mdb = kmalloc(sizeof(*mdb), GFP_ATOMIC);
 306        if (!mdb)
 307                return -ENOMEM;
 308
 309        mdb->max = max;
 310        mdb->old = old;
 311
 312        mdb->mhash = kzalloc(max * sizeof(*mdb->mhash), GFP_ATOMIC);
 313        if (!mdb->mhash) {
 314                kfree(mdb);
 315                return -ENOMEM;
 316        }
 317
 318        mdb->size = old ? old->size : 0;
 319        mdb->ver = old ? old->ver ^ 1 : 0;
 320
 321        if (!old || elasticity)
 322                get_random_bytes(&mdb->secret, sizeof(mdb->secret));
 323        else
 324                mdb->secret = old->secret;
 325
 326        if (!old)
 327                goto out;
 328
 329        err = br_mdb_copy(mdb, old, elasticity);
 330        if (err) {
 331                kfree(mdb->mhash);
 332                kfree(mdb);
 333                return err;
 334        }
 335
 336        br_mdb_rehash_seq++;
 337        call_rcu_bh(&mdb->rcu, br_mdb_free);
 338
 339out:
 340        rcu_assign_pointer(*mdbp, mdb);
 341
 342        return 0;
 343}
 344
 345static struct sk_buff *br_ip4_multicast_alloc_query(struct net_bridge *br,
 346                                                    __be32 group)
 347{
 348        struct sk_buff *skb;
 349        struct igmphdr *ih;
 350        struct ethhdr *eth;
 351        struct iphdr *iph;
 352
 353        skb = netdev_alloc_skb_ip_align(br->dev, sizeof(*eth) + sizeof(*iph) +
 354                                                 sizeof(*ih) + 4);
 355        if (!skb)
 356                goto out;
 357
 358        skb->protocol = htons(ETH_P_IP);
 359
 360        skb_reset_mac_header(skb);
 361        eth = eth_hdr(skb);
 362
 363        memcpy(eth->h_source, br->dev->dev_addr, 6);
 364        eth->h_dest[0] = 1;
 365        eth->h_dest[1] = 0;
 366        eth->h_dest[2] = 0x5e;
 367        eth->h_dest[3] = 0;
 368        eth->h_dest[4] = 0;
 369        eth->h_dest[5] = 1;
 370        eth->h_proto = htons(ETH_P_IP);
 371        skb_put(skb, sizeof(*eth));
 372
 373        skb_set_network_header(skb, skb->len);
 374        iph = ip_hdr(skb);
 375
 376        iph->version = 4;
 377        iph->ihl = 6;
 378        iph->tos = 0xc0;
 379        iph->tot_len = htons(sizeof(*iph) + sizeof(*ih) + 4);
 380        iph->id = 0;
 381        iph->frag_off = htons(IP_DF);
 382        iph->ttl = 1;
 383        iph->protocol = IPPROTO_IGMP;
 384        iph->saddr = 0;
 385        iph->daddr = htonl(INADDR_ALLHOSTS_GROUP);
 386        ((u8 *)&iph[1])[0] = IPOPT_RA;
 387        ((u8 *)&iph[1])[1] = 4;
 388        ((u8 *)&iph[1])[2] = 0;
 389        ((u8 *)&iph[1])[3] = 0;
 390        ip_send_check(iph);
 391        skb_put(skb, 24);
 392
 393        skb_set_transport_header(skb, skb->len);
 394        ih = igmp_hdr(skb);
 395        ih->type = IGMP_HOST_MEMBERSHIP_QUERY;
 396        ih->code = (group ? br->multicast_last_member_interval :
 397                            br->multicast_query_response_interval) /
 398                   (HZ / IGMP_TIMER_SCALE);
 399        ih->group = group;
 400        ih->csum = 0;
 401        ih->csum = ip_compute_csum((void *)ih, sizeof(struct igmphdr));
 402        skb_put(skb, sizeof(*ih));
 403
 404        __skb_pull(skb, sizeof(*eth));
 405
 406out:
 407        return skb;
 408}
 409
 410#if IS_ENABLED(CONFIG_IPV6)
 411static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
 412                                                    const struct in6_addr *group)
 413{
 414        struct sk_buff *skb;
 415        struct ipv6hdr *ip6h;
 416        struct mld_msg *mldq;
 417        struct ethhdr *eth;
 418        u8 *hopopt;
 419        unsigned long interval;
 420
 421        skb = netdev_alloc_skb_ip_align(br->dev, sizeof(*eth) + sizeof(*ip6h) +
 422                                                 8 + sizeof(*mldq));
 423        if (!skb)
 424                goto out;
 425
 426        skb->protocol = htons(ETH_P_IPV6);
 427
 428        /* Ethernet header */
 429        skb_reset_mac_header(skb);
 430        eth = eth_hdr(skb);
 431
 432        memcpy(eth->h_source, br->dev->dev_addr, 6);
 433        eth->h_proto = htons(ETH_P_IPV6);
 434        skb_put(skb, sizeof(*eth));
 435
 436        /* IPv6 header + HbH option */
 437        skb_set_network_header(skb, skb->len);
 438        ip6h = ipv6_hdr(skb);
 439
 440        *(__force __be32 *)ip6h = htonl(0x60000000);
 441        ip6h->payload_len = htons(8 + sizeof(*mldq));
 442        ip6h->nexthdr = IPPROTO_HOPOPTS;
 443        ip6h->hop_limit = 1;
 444        ipv6_addr_set(&ip6h->daddr, htonl(0xff020000), 0, 0, htonl(1));
 445        if (ipv6_dev_get_saddr(dev_net(br->dev), br->dev, &ip6h->daddr, 0,
 446                               &ip6h->saddr)) {
 447                kfree_skb(skb);
 448                return NULL;
 449        }
 450        ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
 451
 452        hopopt = (u8 *)(ip6h + 1);
 453        hopopt[0] = IPPROTO_ICMPV6;             /* next hdr */
 454        hopopt[1] = 0;                          /* length of HbH */
 455        hopopt[2] = IPV6_TLV_ROUTERALERT;       /* Router Alert */
 456        hopopt[3] = 2;                          /* Length of RA Option */
 457        hopopt[4] = 0;                          /* Type = 0x0000 (MLD) */
 458        hopopt[5] = 0;
 459        hopopt[6] = IPV6_TLV_PAD1;              /* Pad1 */
 460        hopopt[7] = IPV6_TLV_PAD1;              /* Pad1 */
 461
 462        skb_put(skb, sizeof(*ip6h) + 8);
 463
 464        /* ICMPv6 */
 465        skb_set_transport_header(skb, skb->len);
 466        mldq = (struct mld_msg *) icmp6_hdr(skb);
 467
 468        interval = ipv6_addr_any(group) ?
 469                        br->multicast_query_response_interval :
 470                        br->multicast_last_member_interval;
 471
 472        mldq->mld_type = ICMPV6_MGM_QUERY;
 473        mldq->mld_code = 0;
 474        mldq->mld_cksum = 0;
 475        mldq->mld_maxdelay = htons((u16)jiffies_to_msecs(interval));
 476        mldq->mld_reserved = 0;
 477        mldq->mld_mca = *group;
 478
 479        /* checksum */
 480        mldq->mld_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
 481                                          sizeof(*mldq), IPPROTO_ICMPV6,
 482                                          csum_partial(mldq,
 483                                                       sizeof(*mldq), 0));
 484        skb_put(skb, sizeof(*mldq));
 485
 486        __skb_pull(skb, sizeof(*eth));
 487
 488out:
 489        return skb;
 490}
 491#endif
 492
 493static struct sk_buff *br_multicast_alloc_query(struct net_bridge *br,
 494                                                struct br_ip *addr)
 495{
 496        switch (addr->proto) {
 497        case htons(ETH_P_IP):
 498                return br_ip4_multicast_alloc_query(br, addr->u.ip4);
 499#if IS_ENABLED(CONFIG_IPV6)
 500        case htons(ETH_P_IPV6):
 501                return br_ip6_multicast_alloc_query(br, &addr->u.ip6);
 502#endif
 503        }
 504        return NULL;
 505}
 506
 507static struct net_bridge_mdb_entry *br_multicast_get_group(
 508        struct net_bridge *br, struct net_bridge_port *port,
 509        struct br_ip *group, int hash)
 510{
 511        struct net_bridge_mdb_htable *mdb;
 512        struct net_bridge_mdb_entry *mp;
 513        unsigned int count = 0;
 514        unsigned int max;
 515        int elasticity;
 516        int err;
 517
 518        mdb = rcu_dereference_protected(br->mdb, 1);
 519        hlist_for_each_entry(mp, &mdb->mhash[hash], hlist[mdb->ver]) {
 520                count++;
 521                if (unlikely(br_ip_equal(group, &mp->addr)))
 522                        return mp;
 523        }
 524
 525        elasticity = 0;
 526        max = mdb->max;
 527
 528        if (unlikely(count > br->hash_elasticity && count)) {
 529                if (net_ratelimit())
 530                        br_info(br, "Multicast hash table "
 531                                "chain limit reached: %s\n",
 532                                port ? port->dev->name : br->dev->name);
 533
 534                elasticity = br->hash_elasticity;
 535        }
 536
 537        if (mdb->size >= max) {
 538                max *= 2;
 539                if (unlikely(max > br->hash_max)) {
 540                        br_warn(br, "Multicast hash table maximum of %d "
 541                                "reached, disabling snooping: %s\n",
 542                                br->hash_max,
 543                                port ? port->dev->name : br->dev->name);
 544                        err = -E2BIG;
 545disable:
 546                        br->multicast_disabled = 1;
 547                        goto err;
 548                }
 549        }
 550
 551        if (max > mdb->max || elasticity) {
 552                if (mdb->old) {
 553                        if (net_ratelimit())
 554                                br_info(br, "Multicast hash table "
 555                                        "on fire: %s\n",
 556                                        port ? port->dev->name : br->dev->name);
 557                        err = -EEXIST;
 558                        goto err;
 559                }
 560
 561                err = br_mdb_rehash(&br->mdb, max, elasticity);
 562                if (err) {
 563                        br_warn(br, "Cannot rehash multicast "
 564                                "hash table, disabling snooping: %s, %d, %d\n",
 565                                port ? port->dev->name : br->dev->name,
 566                                mdb->size, err);
 567                        goto disable;
 568                }
 569
 570                err = -EAGAIN;
 571                goto err;
 572        }
 573
 574        return NULL;
 575
 576err:
 577        mp = ERR_PTR(err);
 578        return mp;
 579}
 580
 581struct net_bridge_mdb_entry *br_multicast_new_group(struct net_bridge *br,
 582        struct net_bridge_port *port, struct br_ip *group)
 583{
 584        struct net_bridge_mdb_htable *mdb;
 585        struct net_bridge_mdb_entry *mp;
 586        int hash;
 587        int err;
 588
 589        mdb = rcu_dereference_protected(br->mdb, 1);
 590        if (!mdb) {
 591                err = br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0);
 592                if (err)
 593                        return ERR_PTR(err);
 594                goto rehash;
 595        }
 596
 597        hash = br_ip_hash(mdb, group);
 598        mp = br_multicast_get_group(br, port, group, hash);
 599        switch (PTR_ERR(mp)) {
 600        case 0:
 601                break;
 602
 603        case -EAGAIN:
 604rehash:
 605                mdb = rcu_dereference_protected(br->mdb, 1);
 606                hash = br_ip_hash(mdb, group);
 607                break;
 608
 609        default:
 610                goto out;
 611        }
 612
 613        mp = kzalloc(sizeof(*mp), GFP_ATOMIC);
 614        if (unlikely(!mp))
 615                return ERR_PTR(-ENOMEM);
 616
 617        mp->br = br;
 618        mp->addr = *group;
 619        setup_timer(&mp->timer, br_multicast_group_expired,
 620                    (unsigned long)mp);
 621
 622        hlist_add_head_rcu(&mp->hlist[mdb->ver], &mdb->mhash[hash]);
 623        mdb->size++;
 624
 625out:
 626        return mp;
 627}
 628
 629struct net_bridge_port_group *br_multicast_new_port_group(
 630                        struct net_bridge_port *port,
 631                        struct br_ip *group,
 632                        struct net_bridge_port_group __rcu *next,
 633                        unsigned char state)
 634{
 635        struct net_bridge_port_group *p;
 636
 637        p = kzalloc(sizeof(*p), GFP_ATOMIC);
 638        if (unlikely(!p))
 639                return NULL;
 640
 641        p->addr = *group;
 642        p->port = port;
 643        p->state = state;
 644        rcu_assign_pointer(p->next, next);
 645        hlist_add_head(&p->mglist, &port->mglist);
 646        setup_timer(&p->timer, br_multicast_port_group_expired,
 647                    (unsigned long)p);
 648        return p;
 649}
 650
 651static int br_multicast_add_group(struct net_bridge *br,
 652                                  struct net_bridge_port *port,
 653                                  struct br_ip *group)
 654{
 655        struct net_bridge_mdb_entry *mp;
 656        struct net_bridge_port_group *p;
 657        struct net_bridge_port_group __rcu **pp;
 658        unsigned long now = jiffies;
 659        int err;
 660
 661        spin_lock(&br->multicast_lock);
 662        if (!netif_running(br->dev) ||
 663            (port && port->state == BR_STATE_DISABLED))
 664                goto out;
 665
 666        mp = br_multicast_new_group(br, port, group);
 667        err = PTR_ERR(mp);
 668        if (IS_ERR(mp))
 669                goto err;
 670
 671        if (!port) {
 672                mp->mglist = true;
 673                mod_timer(&mp->timer, now + br->multicast_membership_interval);
 674                goto out;
 675        }
 676
 677        for (pp = &mp->ports;
 678             (p = mlock_dereference(*pp, br)) != NULL;
 679             pp = &p->next) {
 680                if (p->port == port)
 681                        goto found;
 682                if ((unsigned long)p->port < (unsigned long)port)
 683                        break;
 684        }
 685
 686        p = br_multicast_new_port_group(port, group, *pp, MDB_TEMPORARY);
 687        if (unlikely(!p))
 688                goto err;
 689        rcu_assign_pointer(*pp, p);
 690        br_mdb_notify(br->dev, port, group, RTM_NEWMDB);
 691
 692found:
 693        mod_timer(&p->timer, now + br->multicast_membership_interval);
 694out:
 695        err = 0;
 696
 697err:
 698        spin_unlock(&br->multicast_lock);
 699        return err;
 700}
 701
 702static int br_ip4_multicast_add_group(struct net_bridge *br,
 703                                      struct net_bridge_port *port,
 704                                      __be32 group,
 705                                      __u16 vid)
 706{
 707        struct br_ip br_group;
 708
 709        if (ipv4_is_local_multicast(group))
 710                return 0;
 711
 712        br_group.u.ip4 = group;
 713        br_group.proto = htons(ETH_P_IP);
 714        br_group.vid = vid;
 715
 716        return br_multicast_add_group(br, port, &br_group);
 717}
 718
 719#if IS_ENABLED(CONFIG_IPV6)
 720static int br_ip6_multicast_add_group(struct net_bridge *br,
 721                                      struct net_bridge_port *port,
 722                                      const struct in6_addr *group,
 723                                      __u16 vid)
 724{
 725        struct br_ip br_group;
 726
 727        if (!ipv6_is_transient_multicast(group))
 728                return 0;
 729
 730        br_group.u.ip6 = *group;
 731        br_group.proto = htons(ETH_P_IPV6);
 732        br_group.vid = vid;
 733
 734        return br_multicast_add_group(br, port, &br_group);
 735}
 736#endif
 737
 738static void br_multicast_router_expired(unsigned long data)
 739{
 740        struct net_bridge_port *port = (void *)data;
 741        struct net_bridge *br = port->br;
 742
 743        spin_lock(&br->multicast_lock);
 744        if (port->multicast_router != 1 ||
 745            timer_pending(&port->multicast_router_timer) ||
 746            hlist_unhashed(&port->rlist))
 747                goto out;
 748
 749        hlist_del_init_rcu(&port->rlist);
 750
 751out:
 752        spin_unlock(&br->multicast_lock);
 753}
 754
 755static void br_multicast_local_router_expired(unsigned long data)
 756{
 757}
 758
 759static void br_multicast_querier_expired(unsigned long data)
 760{
 761        struct net_bridge *br = (void *)data;
 762
 763        spin_lock(&br->multicast_lock);
 764        if (!netif_running(br->dev) || br->multicast_disabled)
 765                goto out;
 766
 767        br_multicast_start_querier(br);
 768
 769out:
 770        spin_unlock(&br->multicast_lock);
 771}
 772
 773static void __br_multicast_send_query(struct net_bridge *br,
 774                                      struct net_bridge_port *port,
 775                                      struct br_ip *ip)
 776{
 777        struct sk_buff *skb;
 778
 779        skb = br_multicast_alloc_query(br, ip);
 780        if (!skb)
 781                return;
 782
 783        if (port) {
 784                __skb_push(skb, sizeof(struct ethhdr));
 785                skb->dev = port->dev;
 786                NF_HOOK(NFPROTO_BRIDGE, NF_BR_LOCAL_OUT, skb, NULL, skb->dev,
 787                        dev_queue_xmit);
 788        } else
 789                netif_rx(skb);
 790}
 791
 792static void br_multicast_send_query(struct net_bridge *br,
 793                                    struct net_bridge_port *port, u32 sent)
 794{
 795        unsigned long time;
 796        struct br_ip br_group;
 797
 798        if (!netif_running(br->dev) || br->multicast_disabled ||
 799            !br->multicast_querier ||
 800            timer_pending(&br->multicast_querier_timer))
 801                return;
 802
 803        memset(&br_group.u, 0, sizeof(br_group.u));
 804
 805        br_group.proto = htons(ETH_P_IP);
 806        __br_multicast_send_query(br, port, &br_group);
 807
 808#if IS_ENABLED(CONFIG_IPV6)
 809        br_group.proto = htons(ETH_P_IPV6);
 810        __br_multicast_send_query(br, port, &br_group);
 811#endif
 812
 813        time = jiffies;
 814        time += sent < br->multicast_startup_query_count ?
 815                br->multicast_startup_query_interval :
 816                br->multicast_query_interval;
 817        mod_timer(port ? &port->multicast_query_timer :
 818                         &br->multicast_query_timer, time);
 819}
 820
 821static void br_multicast_port_query_expired(unsigned long data)
 822{
 823        struct net_bridge_port *port = (void *)data;
 824        struct net_bridge *br = port->br;
 825
 826        spin_lock(&br->multicast_lock);
 827        if (port->state == BR_STATE_DISABLED ||
 828            port->state == BR_STATE_BLOCKING)
 829                goto out;
 830
 831        if (port->multicast_startup_queries_sent <
 832            br->multicast_startup_query_count)
 833                port->multicast_startup_queries_sent++;
 834
 835        br_multicast_send_query(port->br, port,
 836                                port->multicast_startup_queries_sent);
 837
 838out:
 839        spin_unlock(&br->multicast_lock);
 840}
 841
 842void br_multicast_add_port(struct net_bridge_port *port)
 843{
 844        port->multicast_router = 1;
 845
 846        setup_timer(&port->multicast_router_timer, br_multicast_router_expired,
 847                    (unsigned long)port);
 848        setup_timer(&port->multicast_query_timer,
 849                    br_multicast_port_query_expired, (unsigned long)port);
 850}
 851
 852void br_multicast_del_port(struct net_bridge_port *port)
 853{
 854        del_timer_sync(&port->multicast_router_timer);
 855}
 856
 857static void __br_multicast_enable_port(struct net_bridge_port *port)
 858{
 859        port->multicast_startup_queries_sent = 0;
 860
 861        if (try_to_del_timer_sync(&port->multicast_query_timer) >= 0 ||
 862            del_timer(&port->multicast_query_timer))
 863                mod_timer(&port->multicast_query_timer, jiffies);
 864}
 865
 866void br_multicast_enable_port(struct net_bridge_port *port)
 867{
 868        struct net_bridge *br = port->br;
 869
 870        spin_lock(&br->multicast_lock);
 871        if (br->multicast_disabled || !netif_running(br->dev))
 872                goto out;
 873
 874        __br_multicast_enable_port(port);
 875
 876out:
 877        spin_unlock(&br->multicast_lock);
 878}
 879
 880void br_multicast_disable_port(struct net_bridge_port *port)
 881{
 882        struct net_bridge *br = port->br;
 883        struct net_bridge_port_group *pg;
 884        struct hlist_node *n;
 885
 886        spin_lock(&br->multicast_lock);
 887        hlist_for_each_entry_safe(pg, n, &port->mglist, mglist)
 888                br_multicast_del_pg(br, pg);
 889
 890        if (!hlist_unhashed(&port->rlist))
 891                hlist_del_init_rcu(&port->rlist);
 892        del_timer(&port->multicast_router_timer);
 893        del_timer(&port->multicast_query_timer);
 894        spin_unlock(&br->multicast_lock);
 895}
 896
 897static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
 898                                         struct net_bridge_port *port,
 899                                         struct sk_buff *skb)
 900{
 901        struct igmpv3_report *ih;
 902        struct igmpv3_grec *grec;
 903        int i;
 904        int len;
 905        int num;
 906        int type;
 907        int err = 0;
 908        __be32 group;
 909        u16 vid = 0;
 910
 911        if (!pskb_may_pull(skb, sizeof(*ih)))
 912                return -EINVAL;
 913
 914        br_vlan_get_tag(skb, &vid);
 915        ih = igmpv3_report_hdr(skb);
 916        num = ntohs(ih->ngrec);
 917        len = sizeof(*ih);
 918
 919        for (i = 0; i < num; i++) {
 920                len += sizeof(*grec);
 921                if (!pskb_may_pull(skb, len))
 922                        return -EINVAL;
 923
 924                grec = (void *)(skb->data + len - sizeof(*grec));
 925                group = grec->grec_mca;
 926                type = grec->grec_type;
 927
 928                len += ntohs(grec->grec_nsrcs) * 4;
 929                if (!pskb_may_pull(skb, len))
 930                        return -EINVAL;
 931
 932                /* We treat this as an IGMPv2 report for now. */
 933                switch (type) {
 934                case IGMPV3_MODE_IS_INCLUDE:
 935                case IGMPV3_MODE_IS_EXCLUDE:
 936                case IGMPV3_CHANGE_TO_INCLUDE:
 937                case IGMPV3_CHANGE_TO_EXCLUDE:
 938                case IGMPV3_ALLOW_NEW_SOURCES:
 939                case IGMPV3_BLOCK_OLD_SOURCES:
 940                        break;
 941
 942                default:
 943                        continue;
 944                }
 945
 946                err = br_ip4_multicast_add_group(br, port, group, vid);
 947                if (err)
 948                        break;
 949        }
 950
 951        return err;
 952}
 953
 954#if IS_ENABLED(CONFIG_IPV6)
 955static int br_ip6_multicast_mld2_report(struct net_bridge *br,
 956                                        struct net_bridge_port *port,
 957                                        struct sk_buff *skb)
 958{
 959        struct icmp6hdr *icmp6h;
 960        struct mld2_grec *grec;
 961        int i;
 962        int len;
 963        int num;
 964        int err = 0;
 965        u16 vid = 0;
 966
 967        if (!pskb_may_pull(skb, sizeof(*icmp6h)))
 968                return -EINVAL;
 969
 970        br_vlan_get_tag(skb, &vid);
 971        icmp6h = icmp6_hdr(skb);
 972        num = ntohs(icmp6h->icmp6_dataun.un_data16[1]);
 973        len = sizeof(*icmp6h);
 974
 975        for (i = 0; i < num; i++) {
 976                __be16 *nsrcs, _nsrcs;
 977
 978                nsrcs = skb_header_pointer(skb,
 979                                           len + offsetof(struct mld2_grec,
 980                                                          grec_nsrcs),
 981                                           sizeof(_nsrcs), &_nsrcs);
 982                if (!nsrcs)
 983                        return -EINVAL;
 984
 985                if (!pskb_may_pull(skb,
 986                                   len + sizeof(*grec) +
 987                                   sizeof(struct in6_addr) * ntohs(*nsrcs)))
 988                        return -EINVAL;
 989
 990                grec = (struct mld2_grec *)(skb->data + len);
 991                len += sizeof(*grec) +
 992                       sizeof(struct in6_addr) * ntohs(*nsrcs);
 993
 994                /* We treat these as MLDv1 reports for now. */
 995                switch (grec->grec_type) {
 996                case MLD2_MODE_IS_INCLUDE:
 997                case MLD2_MODE_IS_EXCLUDE:
 998                case MLD2_CHANGE_TO_INCLUDE:
 999                case MLD2_CHANGE_TO_EXCLUDE:
1000                case MLD2_ALLOW_NEW_SOURCES:
1001                case MLD2_BLOCK_OLD_SOURCES:
1002                        break;
1003
1004                default:
1005                        continue;
1006                }
1007
1008                err = br_ip6_multicast_add_group(br, port, &grec->grec_mca,
1009                                                 vid);
1010                if (!err)
1011                        break;
1012        }
1013
1014        return err;
1015}
1016#endif
1017
1018/*
1019 * Add port to rotuer_list
1020 *  list is maintained ordered by pointer value
1021 *  and locked by br->multicast_lock and RCU
1022 */
1023static void br_multicast_add_router(struct net_bridge *br,
1024                                    struct net_bridge_port *port)
1025{
1026        struct net_bridge_port *p;
1027        struct hlist_node *slot = NULL;
1028
1029        hlist_for_each_entry(p, &br->router_list, rlist) {
1030                if ((unsigned long) port >= (unsigned long) p)
1031                        break;
1032                slot = &p->rlist;
1033        }
1034
1035        if (slot)
1036                hlist_add_after_rcu(slot, &port->rlist);
1037        else
1038                hlist_add_head_rcu(&port->rlist, &br->router_list);
1039}
1040
1041static void br_multicast_mark_router(struct net_bridge *br,
1042                                     struct net_bridge_port *port)
1043{
1044        unsigned long now = jiffies;
1045
1046        if (!port) {
1047                if (br->multicast_router == 1)
1048                        mod_timer(&br->multicast_router_timer,
1049                                  now + br->multicast_querier_interval);
1050                return;
1051        }
1052
1053        if (port->multicast_router != 1)
1054                return;
1055
1056        if (!hlist_unhashed(&port->rlist))
1057                goto timer;
1058
1059        br_multicast_add_router(br, port);
1060
1061timer:
1062        mod_timer(&port->multicast_router_timer,
1063                  now + br->multicast_querier_interval);
1064}
1065
1066static void br_multicast_query_received(struct net_bridge *br,
1067                                        struct net_bridge_port *port,
1068                                        int saddr)
1069{
1070        if (saddr)
1071                mod_timer(&br->multicast_querier_timer,
1072                          jiffies + br->multicast_querier_interval);
1073        else if (timer_pending(&br->multicast_querier_timer))
1074                return;
1075
1076        br_multicast_mark_router(br, port);
1077}
1078
/*
 * Handle an IGMP membership query received on @port.
 * NOTE(review): @port appears to be NULL for queries from the bridge device
 * itself (inferred from the NULL checks) -- confirm against callers.
 *
 * Every query refreshes querier/router state.  For a group-specific query
 * the matching mdb entry's timer and the timers of all its port groups are
 * lowered to now + (maximum response time * multicast_last_member_count)
 * unless they already expire sooner.  General queries (group 0) and IGMPv3
 * group-and-source-specific queries (nsrcs != 0) stop after the bookkeeping.
 *
 * Runs with br->multicast_lock held for its whole body.
 * Returns 0, or -EINVAL if an IGMPv3 query is truncated.
 */
static int br_ip4_multicast_query(struct net_bridge *br,
				  struct net_bridge_port *port,
				  struct sk_buff *skb)
{
	const struct iphdr *iph = ip_hdr(skb);
	struct igmphdr *ih = igmp_hdr(skb);
	struct net_bridge_mdb_entry *mp;
	struct igmpv3_query *ih3;
	struct net_bridge_port_group *p;
	struct net_bridge_port_group __rcu **pp;
	unsigned long max_delay;
	unsigned long now = jiffies;
	__be32 group;
	int err = 0;
	u16 vid = 0;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) ||
	    (port && port->state == BR_STATE_DISABLED))
		goto out;

	br_multicast_query_received(br, port, !!iph->saddr);

	group = ih->group;

	/* An igmphdr-sized message is IGMPv1/v2; anything longer is IGMPv3. */
	if (skb->len == sizeof(*ih)) {
		max_delay = ih->code * (HZ / IGMP_TIMER_SCALE);

		/* Code 0: use a 10 s response time and treat as general. */
		if (!max_delay) {
			max_delay = 10 * HZ;
			group = 0;
		}
	} else {
		if (!pskb_may_pull(skb, sizeof(struct igmpv3_query))) {
			err = -EINVAL;
			goto out;
		}

		ih3 = igmpv3_query_hdr(skb);
		/* Group-and-source-specific queries are not handled. */
		if (ih3->nsrcs)
			goto out;

		max_delay = ih3->code ?
			    IGMPV3_MRC(ih3->code) * (HZ / IGMP_TIMER_SCALE) : 1;
	}

	if (!group)
		goto out;

	br_vlan_get_tag(skb, &vid);
	mp = br_mdb_ip4_get(mlock_dereference(br->mdb, br), group, vid);
	if (!mp)
		goto out;

	max_delay *= br->multicast_last_member_count;

	/*
	 * Shorten, never lengthen: a pending timer is only moved if it would
	 * fire later than now + max_delay; an idle timer is only re-armed if
	 * try_to_del_timer_sync() reports its handler is not running (>= 0).
	 */
	if (mp->mglist &&
	    (timer_pending(&mp->timer) ?
	     time_after(mp->timer.expires, now + max_delay) :
	     try_to_del_timer_sync(&mp->timer) >= 0))
		mod_timer(&mp->timer, now + max_delay);

	/* Apply the same rule to every port group of this entry. */
	for (pp = &mp->ports;
	     (p = mlock_dereference(*pp, br)) != NULL;
	     pp = &p->next) {
		if (timer_pending(&p->timer) ?
		    time_after(p->timer.expires, now + max_delay) :
		    try_to_del_timer_sync(&p->timer) >= 0)
			mod_timer(&p->timer, now + max_delay);
	}

out:
	spin_unlock(&br->multicast_lock);
	return err;
}
1154
1155#if IS_ENABLED(CONFIG_IPV6)
1156static int br_ip6_multicast_query(struct net_bridge *br,
1157                                  struct net_bridge_port *port,
1158                                  struct sk_buff *skb)
1159{
1160        const struct ipv6hdr *ip6h = ipv6_hdr(skb);
1161        struct mld_msg *mld;
1162        struct net_bridge_mdb_entry *mp;
1163        struct mld2_query *mld2q;
1164        struct net_bridge_port_group *p;
1165        struct net_bridge_port_group __rcu **pp;
1166        unsigned long max_delay;
1167        unsigned long now = jiffies;
1168        const struct in6_addr *group = NULL;
1169        int err = 0;
1170        u16 vid = 0;
1171
1172        spin_lock(&br->multicast_lock);
1173        if (!netif_running(br->dev) ||
1174            (port && port->state == BR_STATE_DISABLED))
1175                goto out;
1176
1177        br_multicast_query_received(br, port, !ipv6_addr_any(&ip6h->saddr));
1178
1179        if (skb->len == sizeof(*mld)) {
1180                if (!pskb_may_pull(skb, sizeof(*mld))) {
1181                        err = -EINVAL;
1182                        goto out;
1183                }
1184                mld = (struct mld_msg *) icmp6_hdr(skb);
1185                max_delay = msecs_to_jiffies(ntohs(mld->mld_maxdelay));
1186                if (max_delay)
1187                        group = &mld->mld_mca;
1188        } else if (skb->len >= sizeof(*mld2q)) {
1189                if (!pskb_may_pull(skb, sizeof(*mld2q))) {
1190                        err = -EINVAL;
1191                        goto out;
1192                }
1193                mld2q = (struct mld2_query *)icmp6_hdr(skb);
1194                if (!mld2q->mld2q_nsrcs)
1195                        group = &mld2q->mld2q_mca;
1196                max_delay = mld2q->mld2q_mrc ? MLDV2_MRC(ntohs(mld2q->mld2q_mrc)) : 1;
1197        }
1198
1199        if (!group)
1200                goto out;
1201
1202        br_vlan_get_tag(skb, &vid);
1203        mp = br_mdb_ip6_get(mlock_dereference(br->mdb, br), group, vid);
1204        if (!mp)
1205                goto out;
1206
1207        max_delay *= br->multicast_last_member_count;
1208        if (mp->mglist &&
1209            (timer_pending(&mp->timer) ?
1210             time_after(mp->timer.expires, now + max_delay) :
1211             try_to_del_timer_sync(&mp->timer) >= 0))
1212                mod_timer(&mp->timer, now + max_delay);
1213
1214        for (pp = &mp->ports;
1215             (p = mlock_dereference(*pp, br)) != NULL;
1216             pp = &p->next) {
1217                if (timer_pending(&p->timer) ?
1218                    time_after(p->timer.expires, now + max_delay) :
1219                    try_to_del_timer_sync(&p->timer) >= 0)
1220                        mod_timer(&p->timer, now + max_delay);
1221        }
1222
1223out:
1224        spin_unlock(&br->multicast_lock);
1225        return err;
1226}
1227#endif
1228
/*
 * Process a leave for @group received on @port.
 * NOTE(review): a NULL @port appears to mean the bridge device itself
 * (inferred from the !port branch) -- confirm against callers.
 *
 * Nothing is done while multicast_querier_timer is pending, i.e. while a
 * query with a non-zero source address was seen recently (see
 * br_multicast_query_received()).  Ports flagged BR_MULTICAST_FAST_LEAVE
 * lose their membership immediately.  Otherwise the relevant membership
 * timer is lowered to now + multicast_last_member_count *
 * multicast_last_member_interval, unless it already expires sooner.
 */
static void br_multicast_leave_group(struct net_bridge *br,
				     struct net_bridge_port *port,
				     struct br_ip *group)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	struct net_bridge_port_group *p;
	unsigned long now;
	unsigned long time;

	spin_lock(&br->multicast_lock);
	if (!netif_running(br->dev) ||
	    (port && port->state == BR_STATE_DISABLED) ||
	    timer_pending(&br->multicast_querier_timer))
		goto out;

	mdb = mlock_dereference(br->mdb, br);
	mp = br_mdb_ip_get(mdb, group);
	if (!mp)
		goto out;

	if (port && (port->flags & BR_MULTICAST_FAST_LEAVE)) {
		struct net_bridge_port_group __rcu **pp;

		for (pp = &mp->ports;
		     (p = mlock_dereference(*pp, br)) != NULL;
		     pp = &p->next) {
			if (p->port != port)
				continue;

			/*
			 * Unlink this port's entry for RCU readers and free it
			 * after an RCU-bh grace period, so stepping through
			 * p->next on the next iteration is still valid.
			 */
			rcu_assign_pointer(*pp, p->next);
			hlist_del_init(&p->mglist);
			del_timer(&p->timer);
			call_rcu_bh(&p->rcu, br_multicast_free_pg);
			br_mdb_notify(br->dev, port, group, RTM_DELMDB);

			/* No ports and no local membership left: expire now. */
			if (!mp->ports && !mp->mglist &&
			    netif_running(br->dev))
				mod_timer(&mp->timer, jiffies);
		}
		goto out;
	}

	now = jiffies;
	time = now + br->multicast_last_member_count *
		     br->multicast_last_member_interval;

	/* Leave from the bridge itself: shorten the entry's own timer. */
	if (!port) {
		if (mp->mglist &&
		    (timer_pending(&mp->timer) ?
		     time_after(mp->timer.expires, time) :
		     try_to_del_timer_sync(&mp->timer) >= 0)) {
			mod_timer(&mp->timer, time);
		}

		goto out;
	}

	/* Leave from a port: shorten that port group's timer only. */
	for (p = mlock_dereference(mp->ports, br);
	     p != NULL;
	     p = mlock_dereference(p->next, br)) {
		if (p->port != port)
			continue;

		if (!hlist_unhashed(&p->mglist) &&
		    (timer_pending(&p->timer) ?
		     time_after(p->timer.expires, time) :
		     try_to_del_timer_sync(&p->timer) >= 0)) {
			mod_timer(&p->timer, time);
		}

		break;
	}

out:
	spin_unlock(&br->multicast_lock);
}
1306
1307static void br_ip4_multicast_leave_group(struct net_bridge *br,
1308                                         struct net_bridge_port *port,
1309                                         __be32 group,
1310                                         __u16 vid)
1311{
1312        struct br_ip br_group;
1313
1314        if (ipv4_is_local_multicast(group))
1315                return;
1316
1317        br_group.u.ip4 = group;
1318        br_group.proto = htons(ETH_P_IP);
1319        br_group.vid = vid;
1320
1321        br_multicast_leave_group(br, port, &br_group);
1322}
1323
1324#if IS_ENABLED(CONFIG_IPV6)
1325static void br_ip6_multicast_leave_group(struct net_bridge *br,
1326                                         struct net_bridge_port *port,
1327                                         const struct in6_addr *group,
1328                                         __u16 vid)
1329{
1330        struct br_ip br_group;
1331
1332        if (!ipv6_is_transient_multicast(group))
1333                return;
1334
1335        br_group.u.ip6 = *group;
1336        br_group.proto = htons(ETH_P_IPV6);
1337        br_group.vid = vid;
1338
1339        br_multicast_leave_group(br, port, &br_group);
1340}
1341#endif
1342
/*
 * Inspect an IPv4 multicast packet entering the bridge.
 *
 * Non-IGMP packets are only classified: unless the destination is a
 * link-local group, BR_INPUT_SKB_CB(skb)->mrouters_only is set.  IGMP
 * packets are validated (IP header checksum, lengths, IGMP checksum) and
 * dispatched to the report/query/leave handlers; the per-skb flags are set
 * on the caller's @skb.
 *
 * If the packet has bytes beyond iph->tot_len, a clone is trimmed instead
 * so @skb itself is left unmodified; the clone is freed before returning.
 *
 * Returns 0 on success or for non-IGMP traffic, -EINVAL for malformed
 * packets, -ENOMEM if cloning fails, or the handler's error.
 */
static int br_multicast_ipv4_rcv(struct net_bridge *br,
				 struct net_bridge_port *port,
				 struct sk_buff *skb)
{
	struct sk_buff *skb2 = skb;
	const struct iphdr *iph;
	struct igmphdr *ih;
	unsigned int len;
	unsigned int offset;
	int err;
	u16 vid = 0;

	/* We treat OOM as packet loss for now. */
	if (!pskb_may_pull(skb, sizeof(*iph)))
		return -EINVAL;

	iph = ip_hdr(skb);

	if (iph->ihl < 5 || iph->version != 4)
		return -EINVAL;

	if (!pskb_may_pull(skb, ip_hdrlen(skb)))
		return -EINVAL;

	/* Reload: pskb_may_pull() may have moved the header. */
	iph = ip_hdr(skb);

	if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
		return -EINVAL;

	if (iph->protocol != IPPROTO_IGMP) {
		if (!ipv4_is_local_multicast(iph->daddr))
			BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		return 0;
	}

	len = ntohs(iph->tot_len);
	if (skb->len < len || len < ip_hdrlen(skb))
		return -EINVAL;

	/* Trailing padding: trim a clone rather than the caller's skb. */
	if (skb->len > len) {
		skb2 = skb_clone(skb, GFP_ATOMIC);
		if (!skb2)
			return -ENOMEM;

		err = pskb_trim_rcsum(skb2, len);
		if (err)
			goto err_out;
	}

	/* NOTE(review): len is not read after this point. */
	len -= ip_hdrlen(skb2);
	offset = skb_network_offset(skb2) + ip_hdrlen(skb2);
	/* Temporarily strip the IP header; undone by __skb_push() at out. */
	__skb_pull(skb2, offset);
	skb_reset_transport_header(skb2);

	err = -EINVAL;
	if (!pskb_may_pull(skb2, sizeof(*ih)))
		goto out;

	/* Verify the IGMP checksum, in software unless the device's sum folds to 0. */
	switch (skb2->ip_summed) {
	case CHECKSUM_COMPLETE:
		if (!csum_fold(skb2->csum))
			break;
		/* fall through */
	case CHECKSUM_NONE:
		skb2->csum = 0;
		if (skb_checksum_complete(skb2))
			goto out;
	}

	err = 0;

	br_vlan_get_tag(skb2, &vid);
	BR_INPUT_SKB_CB(skb)->igmp = 1;
	ih = igmp_hdr(skb2);

	switch (ih->type) {
	case IGMP_HOST_MEMBERSHIP_REPORT:
	case IGMPV2_HOST_MEMBERSHIP_REPORT:
		BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		err = br_ip4_multicast_add_group(br, port, ih->group, vid);
		break;
	case IGMPV3_HOST_MEMBERSHIP_REPORT:
		err = br_ip4_multicast_igmp3_report(br, port, skb2);
		break;
	case IGMP_HOST_MEMBERSHIP_QUERY:
		err = br_ip4_multicast_query(br, port, skb2);
		break;
	case IGMP_HOST_LEAVE_MESSAGE:
		br_ip4_multicast_leave_group(br, port, ih->group, vid);
		break;
	}

out:
	__skb_push(skb2, offset);
err_out:
	if (skb2 != skb)
		kfree_skb(skb2);
	return err;
}
1442
1443#if IS_ENABLED(CONFIG_IPV6)
/*
 * Inspect an IPv6 packet entering the bridge and dispatch MLD messages.
 *
 * Only packets whose first extension header is hop-by-hop and whose chain
 * ends in ICMPv6 are examined further; everything else returns 0.  The MLD
 * message is parsed from a clone (skb2) that is pulled to the ICMPv6
 * header, trimmed to the IPv6 payload length and checksum-verified; the
 * per-skb flags are set on the caller's @skb.  The clone is always freed.
 *
 * Returns 0 on success or for non-MLD traffic, -EINVAL for malformed
 * packets, -ENOMEM if cloning fails, or the handler's error.
 */
static int br_multicast_ipv6_rcv(struct net_bridge *br,
				 struct net_bridge_port *port,
				 struct sk_buff *skb)
{
	struct sk_buff *skb2;
	const struct ipv6hdr *ip6h;
	u8 icmp6_type;
	u8 nexthdr;
	__be16 frag_off;
	unsigned int len;
	int offset;
	int err;
	u16 vid = 0;

	if (!pskb_may_pull(skb, sizeof(*ip6h)))
		return -EINVAL;

	ip6h = ipv6_hdr(skb);

	/*
	 * We're interested in MLD messages only.
	 *  - Version is 6
	 *  - MLD messages always carry a Router Alert hop-by-hop option
	 *  - But we do not support jumbograms.
	 */
	if (ip6h->version != 6 ||
	    ip6h->nexthdr != IPPROTO_HOPOPTS ||
	    ip6h->payload_len == 0)
		return 0;

	/* Total IPv6 packet length (fixed header + payload). */
	len = ntohs(ip6h->payload_len) + sizeof(*ip6h);
	if (skb->len < len)
		return -EINVAL;

	nexthdr = ip6h->nexthdr;
	offset = ipv6_skip_exthdr(skb, sizeof(*ip6h), &nexthdr, &frag_off);

	if (offset < 0 || nexthdr != IPPROTO_ICMPV6)
		return 0;

	/* Okay, we found ICMPv6 header */
	skb2 = skb_clone(skb, GFP_ATOMIC);
	if (!skb2)
		return -ENOMEM;

	err = -EINVAL;
	if (!pskb_may_pull(skb2, offset + sizeof(struct icmp6hdr)))
		goto out;

	/* From here on len counts bytes from the ICMPv6 header to the end. */
	len -= offset - skb_network_offset(skb2);

	__skb_pull(skb2, offset);
	skb_reset_transport_header(skb2);
	/* Keep skb2->csum consistent after pulling the IPv6/ext headers. */
	skb_postpull_rcsum(skb2, skb_network_header(skb2),
			   skb_network_header_len(skb2));

	icmp6_type = icmp6_hdr(skb2)->icmp6_type;

	switch (icmp6_type) {
	case ICMPV6_MGM_QUERY:
	case ICMPV6_MGM_REPORT:
	case ICMPV6_MGM_REDUCTION:
	case ICMPV6_MLD2_REPORT:
		break;
	default:
		err = 0;
		goto out;
	}

	/* Okay, we found MLD message. Check further. */
	if (skb2->len > len) {
		err = pskb_trim_rcsum(skb2, len);
		if (err)
			goto out;
		err = -EINVAL;
	}

	ip6h = ipv6_hdr(skb2);

	/* Verify the ICMPv6 checksum over the IPv6 pseudo-header. */
	switch (skb2->ip_summed) {
	case CHECKSUM_COMPLETE:
		if (!csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, skb2->len,
					IPPROTO_ICMPV6, skb2->csum))
			break;
		/*FALLTHROUGH*/
	case CHECKSUM_NONE:
		skb2->csum = ~csum_unfold(csum_ipv6_magic(&ip6h->saddr,
							&ip6h->daddr,
							skb2->len,
							IPPROTO_ICMPV6, 0));
		if (__skb_checksum_complete(skb2))
			goto out;
	}

	err = 0;

	br_vlan_get_tag(skb, &vid);
	BR_INPUT_SKB_CB(skb)->igmp = 1;

	switch (icmp6_type) {
	case ICMPV6_MGM_REPORT:
	    {
		struct mld_msg *mld;
		if (!pskb_may_pull(skb2, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *)skb_transport_header(skb2);
		BR_INPUT_SKB_CB(skb)->mrouters_only = 1;
		err = br_ip6_multicast_add_group(br, port, &mld->mld_mca, vid);
		break;
	    }
	case ICMPV6_MLD2_REPORT:
		err = br_ip6_multicast_mld2_report(br, port, skb2);
		break;
	case ICMPV6_MGM_QUERY:
		err = br_ip6_multicast_query(br, port, skb2);
		break;
	case ICMPV6_MGM_REDUCTION:
	    {
		struct mld_msg *mld;
		if (!pskb_may_pull(skb2, sizeof(*mld))) {
			err = -EINVAL;
			goto out;
		}
		mld = (struct mld_msg *)skb_transport_header(skb2);
		br_ip6_multicast_leave_group(br, port, &mld->mld_mca, vid);
	    }
	}

out:
	kfree_skb(skb2);
	return err;
}
1578#endif
1579
1580int br_multicast_rcv(struct net_bridge *br, struct net_bridge_port *port,
1581                     struct sk_buff *skb)
1582{
1583        BR_INPUT_SKB_CB(skb)->igmp = 0;
1584        BR_INPUT_SKB_CB(skb)->mrouters_only = 0;
1585
1586        if (br->multicast_disabled)
1587                return 0;
1588
1589        switch (skb->protocol) {
1590        case htons(ETH_P_IP):
1591                return br_multicast_ipv4_rcv(br, port, skb);
1592#if IS_ENABLED(CONFIG_IPV6)
1593        case htons(ETH_P_IPV6):
1594                return br_multicast_ipv6_rcv(br, port, skb);
1595#endif
1596        }
1597
1598        return 0;
1599}
1600
1601static void br_multicast_query_expired(unsigned long data)
1602{
1603        struct net_bridge *br = (void *)data;
1604
1605        spin_lock(&br->multicast_lock);
1606        if (br->multicast_startup_queries_sent <
1607            br->multicast_startup_query_count)
1608                br->multicast_startup_queries_sent++;
1609
1610        br_multicast_send_query(br, NULL, br->multicast_startup_queries_sent);
1611
1612        spin_unlock(&br->multicast_lock);
1613}
1614
1615void br_multicast_init(struct net_bridge *br)
1616{
1617        br->hash_elasticity = 4;
1618        br->hash_max = 512;
1619
1620        br->multicast_router = 1;
1621        br->multicast_querier = 0;
1622        br->multicast_last_member_count = 2;
1623        br->multicast_startup_query_count = 2;
1624
1625        br->multicast_last_member_interval = HZ;
1626        br->multicast_query_response_interval = 10 * HZ;
1627        br->multicast_startup_query_interval = 125 * HZ / 4;
1628        br->multicast_query_interval = 125 * HZ;
1629        br->multicast_querier_interval = 255 * HZ;
1630        br->multicast_membership_interval = 260 * HZ;
1631
1632        spin_lock_init(&br->multicast_lock);
1633        setup_timer(&br->multicast_router_timer,
1634                    br_multicast_local_router_expired, 0);
1635        setup_timer(&br->multicast_querier_timer,
1636                    br_multicast_querier_expired, (unsigned long)br);
1637        setup_timer(&br->multicast_query_timer, br_multicast_query_expired,
1638                    (unsigned long)br);
1639}
1640
1641void br_multicast_open(struct net_bridge *br)
1642{
1643        br->multicast_startup_queries_sent = 0;
1644
1645        if (br->multicast_disabled)
1646                return;
1647
1648        mod_timer(&br->multicast_query_timer, jiffies);
1649}
1650
/*
 * Tear down bridge-level multicast state: synchronously stop the router,
 * querier and query timers, detach br->mdb, and schedule every mdb entry
 * and the table itself for freeing via call_rcu_bh().
 */
void br_multicast_stop(struct net_bridge *br)
{
	struct net_bridge_mdb_htable *mdb;
	struct net_bridge_mdb_entry *mp;
	struct hlist_node *n;
	u32 ver;
	int i;

	del_timer_sync(&br->multicast_router_timer);
	del_timer_sync(&br->multicast_querier_timer);
	del_timer_sync(&br->multicast_query_timer);

	spin_lock_bh(&br->multicast_lock);
	mdb = mlock_dereference(br->mdb, br);
	if (!mdb)
		goto out;

	br->mdb = NULL;

	/* Entries are chained through hlist[ver] of the current table version. */
	ver = mdb->ver;
	for (i = 0; i < mdb->max; i++) {
		hlist_for_each_entry_safe(mp, n, &mdb->mhash[i],
					  hlist[ver]) {
			del_timer(&mp->timer);
			call_rcu_bh(&mp->rcu, br_multicast_free_group);
		}
	}

	/*
	 * A previous table is still referenced: drop the lock and wait for
	 * all pending RCU-bh callbacks; the WARN_ON checks that this cleared
	 * mdb->old.
	 */
	if (mdb->old) {
		spin_unlock_bh(&br->multicast_lock);
		rcu_barrier_bh();
		spin_lock_bh(&br->multicast_lock);
		WARN_ON(mdb->old);
	}

	/*
	 * NOTE(review): br_mdb_free() appears to release mdb->old, hence the
	 * table is pointed at itself before being queued -- confirm.
	 */
	mdb->old = mdb;
	call_rcu_bh(&mdb->rcu, br_mdb_free);

out:
	spin_unlock_bh(&br->multicast_lock);
}
1692
1693int br_multicast_set_router(struct net_bridge *br, unsigned long val)
1694{
1695        int err = -ENOENT;
1696
1697        spin_lock_bh(&br->multicast_lock);
1698        if (!netif_running(br->dev))
1699                goto unlock;
1700
1701        switch (val) {
1702        case 0:
1703        case 2:
1704                del_timer(&br->multicast_router_timer);
1705                /* fall through */
1706        case 1:
1707                br->multicast_router = val;
1708                err = 0;
1709                break;
1710
1711        default:
1712                err = -EINVAL;
1713                break;
1714        }
1715
1716unlock:
1717        spin_unlock_bh(&br->multicast_lock);
1718
1719        return err;
1720}
1721
1722int br_multicast_set_port_router(struct net_bridge_port *p, unsigned long val)
1723{
1724        struct net_bridge *br = p->br;
1725        int err = -ENOENT;
1726
1727        spin_lock(&br->multicast_lock);
1728        if (!netif_running(br->dev) || p->state == BR_STATE_DISABLED)
1729                goto unlock;
1730
1731        switch (val) {
1732        case 0:
1733        case 1:
1734        case 2:
1735                p->multicast_router = val;
1736                err = 0;
1737
1738                if (val < 2 && !hlist_unhashed(&p->rlist))
1739                        hlist_del_init_rcu(&p->rlist);
1740
1741                if (val == 1)
1742                        break;
1743
1744                del_timer(&p->multicast_router_timer);
1745
1746                if (val == 0)
1747                        break;
1748
1749                br_multicast_add_router(br, p);
1750                break;
1751
1752        default:
1753                err = -EINVAL;
1754                break;
1755        }
1756
1757unlock:
1758        spin_unlock(&br->multicast_lock);
1759
1760        return err;
1761}
1762
1763static void br_multicast_start_querier(struct net_bridge *br)
1764{
1765        struct net_bridge_port *port;
1766
1767        br_multicast_open(br);
1768
1769        list_for_each_entry(port, &br->port_list, list) {
1770                if (port->state == BR_STATE_DISABLED ||
1771                    port->state == BR_STATE_BLOCKING)
1772                        continue;
1773
1774                __br_multicast_enable_port(port);
1775        }
1776}
1777
1778int br_multicast_toggle(struct net_bridge *br, unsigned long val)
1779{
1780        int err = 0;
1781        struct net_bridge_mdb_htable *mdb;
1782
1783        spin_lock_bh(&br->multicast_lock);
1784        if (br->multicast_disabled == !val)
1785                goto unlock;
1786
1787        br->multicast_disabled = !val;
1788        if (br->multicast_disabled)
1789                goto unlock;
1790
1791        if (!netif_running(br->dev))
1792                goto unlock;
1793
1794        mdb = mlock_dereference(br->mdb, br);
1795        if (mdb) {
1796                if (mdb->old) {
1797                        err = -EEXIST;
1798rollback:
1799                        br->multicast_disabled = !!val;
1800                        goto unlock;
1801                }
1802
1803                err = br_mdb_rehash(&br->mdb, mdb->max,
1804                                    br->hash_elasticity);
1805                if (err)
1806                        goto rollback;
1807        }
1808
1809        br_multicast_start_querier(br);
1810
1811unlock:
1812        spin_unlock_bh(&br->multicast_lock);
1813
1814        return err;
1815}
1816
1817int br_multicast_set_querier(struct net_bridge *br, unsigned long val)
1818{
1819        val = !!val;
1820
1821        spin_lock_bh(&br->multicast_lock);
1822        if (br->multicast_querier == val)
1823                goto unlock;
1824
1825        br->multicast_querier = val;
1826        if (val)
1827                br_multicast_start_querier(br);
1828
1829unlock:
1830        spin_unlock_bh(&br->multicast_lock);
1831
1832        return 0;
1833}
1834
1835int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
1836{
1837        int err = -ENOENT;
1838        u32 old;
1839        struct net_bridge_mdb_htable *mdb;
1840
1841        spin_lock(&br->multicast_lock);
1842        if (!netif_running(br->dev))
1843                goto unlock;
1844
1845        err = -EINVAL;
1846        if (!is_power_of_2(val))
1847                goto unlock;
1848
1849        mdb = mlock_dereference(br->mdb, br);
1850        if (mdb && val < mdb->size)
1851                goto unlock;
1852
1853        err = 0;
1854
1855        old = br->hash_max;
1856        br->hash_max = val;
1857
1858        if (mdb) {
1859                if (mdb->old) {
1860                        err = -EEXIST;
1861rollback:
1862                        br->hash_max = old;
1863                        goto unlock;
1864                }
1865
1866                err = br_mdb_rehash(&br->mdb, br->hash_max,
1867                                    br->hash_elasticity);
1868                if (err)
1869                        goto rollback;
1870        }
1871
1872unlock:
1873        spin_unlock(&br->multicast_lock);
1874
1875        return err;
1876}
1877