linux/net/netfilter/ipset/ip_set_core.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
   3 *                         Patrick Schaaf <bof@bof.de>
   4 * Copyright (C) 2003-2013 Jozsef Kadlecsik <kadlec@netfilter.org>
   5 */
   6
   7/* Kernel module for IP set management */
   8
   9#include <linux/init.h>
  10#include <linux/module.h>
  11#include <linux/moduleparam.h>
  12#include <linux/ip.h>
  13#include <linux/skbuff.h>
  14#include <linux/spinlock.h>
  15#include <linux/rculist.h>
  16#include <net/netlink.h>
  17#include <net/net_namespace.h>
  18#include <net/netns/generic.h>
  19
  20#include <linux/netfilter.h>
  21#include <linux/netfilter/x_tables.h>
  22#include <linux/netfilter/nfnetlink.h>
  23#include <linux/netfilter/ipset/ip_set.h>
  24
  25static LIST_HEAD(ip_set_type_list);             /* all registered set types */
  26static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
  27static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
  28
  29struct ip_set_net {
  30        struct ip_set * __rcu *ip_set_list;     /* all individual sets */
  31        ip_set_id_t     ip_set_max;     /* max number of sets */
  32        bool            is_deleted;     /* deleted by ip_set_net_exit */
  33        bool            is_destroyed;   /* all sets are destroyed */
  34};
  35
  36static unsigned int ip_set_net_id __read_mostly;
  37
  38static inline struct ip_set_net *ip_set_pernet(struct net *net)
  39{
  40        return net_generic(net, ip_set_net_id);
  41}
  42
  43#define IP_SET_INC      64
  44#define STRNCMP(a, b)   (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
  45
  46static unsigned int max_sets;
  47
  48module_param(max_sets, int, 0600);
  49MODULE_PARM_DESC(max_sets, "maximal number of sets");
  50MODULE_LICENSE("GPL");
  51MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@netfilter.org>");
  52MODULE_DESCRIPTION("core IP set support");
  53MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
  54
  55/* When the nfnl mutex or ip_set_ref_lock is held: */
  56#define ip_set_dereference(p)           \
  57        rcu_dereference_protected(p,    \
  58                lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \
  59                lockdep_is_held(&ip_set_ref_lock))
  60#define ip_set(inst, id)                \
  61        ip_set_dereference((inst)->ip_set_list)[id]
  62#define ip_set_ref_netlink(inst,id)     \
  63        rcu_dereference_raw((inst)->ip_set_list)[id]
  64
  65/* The set types are implemented in modules and registered set types
  66 * can be found in ip_set_type_list. Adding/deleting types is
  67 * serialized by ip_set_type_mutex.
  68 */
  69
  70static inline void
  71ip_set_type_lock(void)
  72{
  73        mutex_lock(&ip_set_type_mutex);
  74}
  75
  76static inline void
  77ip_set_type_unlock(void)
  78{
  79        mutex_unlock(&ip_set_type_mutex);
  80}
  81
  82/* Register and deregister settype */
  83
  84static struct ip_set_type *
  85find_set_type(const char *name, u8 family, u8 revision)
  86{
  87        struct ip_set_type *type;
  88
  89        list_for_each_entry_rcu(type, &ip_set_type_list, list)
  90                if (STRNCMP(type->name, name) &&
  91                    (type->family == family ||
  92                     type->family == NFPROTO_UNSPEC) &&
  93                    revision >= type->revision_min &&
  94                    revision <= type->revision_max)
  95                        return type;
  96        return NULL;
  97}
  98
  99/* Unlock, try to load a set type module and lock again */
 100static bool
 101load_settype(const char *name)
 102{
 103        nfnl_unlock(NFNL_SUBSYS_IPSET);
 104        pr_debug("try to load ip_set_%s\n", name);
 105        if (request_module("ip_set_%s", name) < 0) {
 106                pr_warn("Can't find ip_set type %s\n", name);
 107                nfnl_lock(NFNL_SUBSYS_IPSET);
 108                return false;
 109        }
 110        nfnl_lock(NFNL_SUBSYS_IPSET);
 111        return true;
 112}
 113
 114/* Find a set type and reference it */
 115#define find_set_type_get(name, family, revision, found)        \
 116        __find_set_type_get(name, family, revision, found, false)
 117
 118static int
 119__find_set_type_get(const char *name, u8 family, u8 revision,
 120                    struct ip_set_type **found, bool retry)
 121{
 122        struct ip_set_type *type;
 123        int err;
 124
 125        if (retry && !load_settype(name))
 126                return -IPSET_ERR_FIND_TYPE;
 127
 128        rcu_read_lock();
 129        *found = find_set_type(name, family, revision);
 130        if (*found) {
 131                err = !try_module_get((*found)->me) ? -EFAULT : 0;
 132                goto unlock;
 133        }
 134        /* Make sure the type is already loaded
 135         * but we don't support the revision
 136         */
 137        list_for_each_entry_rcu(type, &ip_set_type_list, list)
 138                if (STRNCMP(type->name, name)) {
 139                        err = -IPSET_ERR_FIND_TYPE;
 140                        goto unlock;
 141                }
 142        rcu_read_unlock();
 143
 144        return retry ? -IPSET_ERR_FIND_TYPE :
 145                __find_set_type_get(name, family, revision, found, true);
 146
 147unlock:
 148        rcu_read_unlock();
 149        return err;
 150}
 151
 152/* Find a given set type by name and family.
 153 * If we succeeded, the supported minimal and maximum revisions are
 154 * filled out.
 155 */
 156#define find_set_type_minmax(name, family, min, max) \
 157        __find_set_type_minmax(name, family, min, max, false)
 158
 159static int
 160__find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
 161                       bool retry)
 162{
 163        struct ip_set_type *type;
 164        bool found = false;
 165
 166        if (retry && !load_settype(name))
 167                return -IPSET_ERR_FIND_TYPE;
 168
 169        *min = 255; *max = 0;
 170        rcu_read_lock();
 171        list_for_each_entry_rcu(type, &ip_set_type_list, list)
 172                if (STRNCMP(type->name, name) &&
 173                    (type->family == family ||
 174                     type->family == NFPROTO_UNSPEC)) {
 175                        found = true;
 176                        if (type->revision_min < *min)
 177                                *min = type->revision_min;
 178                        if (type->revision_max > *max)
 179                                *max = type->revision_max;
 180                }
 181        rcu_read_unlock();
 182        if (found)
 183                return 0;
 184
 185        return retry ? -IPSET_ERR_FIND_TYPE :
 186                __find_set_type_minmax(name, family, min, max, true);
 187}
 188
 189#define family_name(f)  ((f) == NFPROTO_IPV4 ? "inet" : \
 190                         (f) == NFPROTO_IPV6 ? "inet6" : "any")
 191
 192/* Register a set type structure. The type is identified by
 193 * the unique triple of name, family and revision.
 194 */
 195int
 196ip_set_type_register(struct ip_set_type *type)
 197{
 198        int ret = 0;
 199
 200        if (type->protocol != IPSET_PROTOCOL) {
 201                pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n",
 202                        type->name, family_name(type->family),
 203                        type->revision_min, type->revision_max,
 204                        type->protocol, IPSET_PROTOCOL);
 205                return -EINVAL;
 206        }
 207
 208        ip_set_type_lock();
 209        if (find_set_type(type->name, type->family, type->revision_min)) {
 210                /* Duplicate! */
 211                pr_warn("ip_set type %s, family %s with revision min %u already registered!\n",
 212                        type->name, family_name(type->family),
 213                        type->revision_min);
 214                ip_set_type_unlock();
 215                return -EINVAL;
 216        }
 217        list_add_rcu(&type->list, &ip_set_type_list);
 218        pr_debug("type %s, family %s, revision %u:%u registered.\n",
 219                 type->name, family_name(type->family),
 220                 type->revision_min, type->revision_max);
 221        ip_set_type_unlock();
 222
 223        return ret;
 224}
 225EXPORT_SYMBOL_GPL(ip_set_type_register);
 226
 227/* Unregister a set type. There's a small race with ip_set_create */
 228void
 229ip_set_type_unregister(struct ip_set_type *type)
 230{
 231        ip_set_type_lock();
 232        if (!find_set_type(type->name, type->family, type->revision_min)) {
 233                pr_warn("ip_set type %s, family %s with revision min %u not registered\n",
 234                        type->name, family_name(type->family),
 235                        type->revision_min);
 236                ip_set_type_unlock();
 237                return;
 238        }
 239        list_del_rcu(&type->list);
 240        pr_debug("type %s, family %s with revision min %u unregistered.\n",
 241                 type->name, family_name(type->family), type->revision_min);
 242        ip_set_type_unlock();
 243
 244        synchronize_rcu();
 245}
 246EXPORT_SYMBOL_GPL(ip_set_type_unregister);
 247
 248/* Utility functions */
 249void *
 250ip_set_alloc(size_t size)
 251{
 252        void *members = NULL;
 253
 254        if (size < KMALLOC_MAX_SIZE)
 255                members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
 256
 257        if (members) {
 258                pr_debug("%p: allocated with kmalloc\n", members);
 259                return members;
 260        }
 261
 262        members = vzalloc(size);
 263        if (!members)
 264                return NULL;
 265        pr_debug("%p: allocated with vmalloc\n", members);
 266
 267        return members;
 268}
 269EXPORT_SYMBOL_GPL(ip_set_alloc);
 270
 271void
 272ip_set_free(void *members)
 273{
 274        pr_debug("%p: free with %s\n", members,
 275                 is_vmalloc_addr(members) ? "vfree" : "kfree");
 276        kvfree(members);
 277}
 278EXPORT_SYMBOL_GPL(ip_set_free);
 279
 280static inline bool
 281flag_nested(const struct nlattr *nla)
 282{
 283        return nla->nla_type & NLA_F_NESTED;
 284}
 285
 286static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
 287        [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
 288        [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
 289                                            .len = sizeof(struct in6_addr) },
 290};
 291
 292int
 293ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
 294{
 295        struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
 296
 297        if (unlikely(!flag_nested(nla)))
 298                return -IPSET_ERR_PROTOCOL;
 299        if (nla_parse_nested_deprecated(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy, NULL))
 300                return -IPSET_ERR_PROTOCOL;
 301        if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
 302                return -IPSET_ERR_PROTOCOL;
 303
 304        *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
 305        return 0;
 306}
 307EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
 308
 309int
 310ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
 311{
 312        struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
 313
 314        if (unlikely(!flag_nested(nla)))
 315                return -IPSET_ERR_PROTOCOL;
 316
 317        if (nla_parse_nested_deprecated(tb, IPSET_ATTR_IPADDR_MAX, nla, ipaddr_policy, NULL))
 318                return -IPSET_ERR_PROTOCOL;
 319        if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
 320                return -IPSET_ERR_PROTOCOL;
 321
 322        memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
 323               sizeof(struct in6_addr));
 324        return 0;
 325}
 326EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
 327
 328typedef void (*destroyer)(struct ip_set *, void *);
 329/* ipset data extension types, in size order */
 330
 331const struct ip_set_ext_type ip_set_extensions[] = {
 332        [IPSET_EXT_ID_COUNTER] = {
 333                .type   = IPSET_EXT_COUNTER,
 334                .flag   = IPSET_FLAG_WITH_COUNTERS,
 335                .len    = sizeof(struct ip_set_counter),
 336                .align  = __alignof__(struct ip_set_counter),
 337        },
 338        [IPSET_EXT_ID_TIMEOUT] = {
 339                .type   = IPSET_EXT_TIMEOUT,
 340                .len    = sizeof(unsigned long),
 341                .align  = __alignof__(unsigned long),
 342        },
 343        [IPSET_EXT_ID_SKBINFO] = {
 344                .type   = IPSET_EXT_SKBINFO,
 345                .flag   = IPSET_FLAG_WITH_SKBINFO,
 346                .len    = sizeof(struct ip_set_skbinfo),
 347                .align  = __alignof__(struct ip_set_skbinfo),
 348        },
 349        [IPSET_EXT_ID_COMMENT] = {
 350                .type    = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY,
 351                .flag    = IPSET_FLAG_WITH_COMMENT,
 352                .len     = sizeof(struct ip_set_comment),
 353                .align   = __alignof__(struct ip_set_comment),
 354                .destroy = (destroyer) ip_set_comment_free,
 355        },
 356};
 357EXPORT_SYMBOL_GPL(ip_set_extensions);
 358
 359static inline bool
 360add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[])
 361{
 362        return ip_set_extensions[id].flag ?
 363                (flags & ip_set_extensions[id].flag) :
 364                !!tb[IPSET_ATTR_TIMEOUT];
 365}
 366
 367size_t
 368ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len,
 369                size_t align)
 370{
 371        enum ip_set_ext_id id;
 372        u32 cadt_flags = 0;
 373
 374        if (tb[IPSET_ATTR_CADT_FLAGS])
 375                cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
 376        if (cadt_flags & IPSET_FLAG_WITH_FORCEADD)
 377                set->flags |= IPSET_CREATE_FLAG_FORCEADD;
 378        if (!align)
 379                align = 1;
 380        for (id = 0; id < IPSET_EXT_ID_MAX; id++) {
 381                if (!add_extension(id, cadt_flags, tb))
 382                        continue;
 383                len = ALIGN(len, ip_set_extensions[id].align);
 384                set->offset[id] = len;
 385                set->extensions |= ip_set_extensions[id].type;
 386                len += ip_set_extensions[id].len;
 387        }
 388        return ALIGN(len, align);
 389}
 390EXPORT_SYMBOL_GPL(ip_set_elem_len);
 391
 392int
 393ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[],
 394                      struct ip_set_ext *ext)
 395{
 396        u64 fullmark;
 397
 398        if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
 399                     !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) ||
 400                     !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) ||
 401                     !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) ||
 402                     !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) ||
 403                     !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE)))
 404                return -IPSET_ERR_PROTOCOL;
 405
 406        if (tb[IPSET_ATTR_TIMEOUT]) {
 407                if (!SET_WITH_TIMEOUT(set))
 408                        return -IPSET_ERR_TIMEOUT;
 409                ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
 410        }
 411        if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) {
 412                if (!SET_WITH_COUNTER(set))
 413                        return -IPSET_ERR_COUNTER;
 414                if (tb[IPSET_ATTR_BYTES])
 415                        ext->bytes = be64_to_cpu(nla_get_be64(
 416                                                 tb[IPSET_ATTR_BYTES]));
 417                if (tb[IPSET_ATTR_PACKETS])
 418                        ext->packets = be64_to_cpu(nla_get_be64(
 419                                                   tb[IPSET_ATTR_PACKETS]));
 420        }
 421        if (tb[IPSET_ATTR_COMMENT]) {
 422                if (!SET_WITH_COMMENT(set))
 423                        return -IPSET_ERR_COMMENT;
 424                ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]);
 425        }
 426        if (tb[IPSET_ATTR_SKBMARK]) {
 427                if (!SET_WITH_SKBINFO(set))
 428                        return -IPSET_ERR_SKBINFO;
 429                fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK]));
 430                ext->skbinfo.skbmark = fullmark >> 32;
 431                ext->skbinfo.skbmarkmask = fullmark & 0xffffffff;
 432        }
 433        if (tb[IPSET_ATTR_SKBPRIO]) {
 434                if (!SET_WITH_SKBINFO(set))
 435                        return -IPSET_ERR_SKBINFO;
 436                ext->skbinfo.skbprio =
 437                        be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO]));
 438        }
 439        if (tb[IPSET_ATTR_SKBQUEUE]) {
 440                if (!SET_WITH_SKBINFO(set))
 441                        return -IPSET_ERR_SKBINFO;
 442                ext->skbinfo.skbqueue =
 443                        be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE]));
 444        }
 445        return 0;
 446}
 447EXPORT_SYMBOL_GPL(ip_set_get_extensions);
 448
 449int
 450ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set,
 451                      const void *e, bool active)
 452{
 453        if (SET_WITH_TIMEOUT(set)) {
 454                unsigned long *timeout = ext_timeout(e, set);
 455
 456                if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
 457                        htonl(active ? ip_set_timeout_get(timeout)
 458                                : *timeout)))
 459                        return -EMSGSIZE;
 460        }
 461        if (SET_WITH_COUNTER(set) &&
 462            ip_set_put_counter(skb, ext_counter(e, set)))
 463                return -EMSGSIZE;
 464        if (SET_WITH_COMMENT(set) &&
 465            ip_set_put_comment(skb, ext_comment(e, set)))
 466                return -EMSGSIZE;
 467        if (SET_WITH_SKBINFO(set) &&
 468            ip_set_put_skbinfo(skb, ext_skbinfo(e, set)))
 469                return -EMSGSIZE;
 470        return 0;
 471}
 472EXPORT_SYMBOL_GPL(ip_set_put_extensions);
 473
 474bool
 475ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext,
 476                        struct ip_set_ext *mext, u32 flags, void *data)
 477{
 478        if (SET_WITH_TIMEOUT(set) &&
 479            ip_set_timeout_expired(ext_timeout(data, set)))
 480                return false;
 481        if (SET_WITH_COUNTER(set)) {
 482                struct ip_set_counter *counter = ext_counter(data, set);
 483
 484                if (flags & IPSET_FLAG_MATCH_COUNTERS &&
 485                    !(ip_set_match_counter(ip_set_get_packets(counter),
 486                                mext->packets, mext->packets_op) &&
 487                      ip_set_match_counter(ip_set_get_bytes(counter),
 488                                mext->bytes, mext->bytes_op)))
 489                        return false;
 490                ip_set_update_counter(counter, ext, flags);
 491        }
 492        if (SET_WITH_SKBINFO(set))
 493                ip_set_get_skbinfo(ext_skbinfo(data, set),
 494                                   ext, mext, flags);
 495        return true;
 496}
 497EXPORT_SYMBOL_GPL(ip_set_match_extensions);
 498
 499/* Creating/destroying/renaming/swapping affect the existence and
 500 * the properties of a set. All of these can be executed from userspace
 501 * only and serialized by the nfnl mutex indirectly from nfnetlink.
 502 *
 503 * Sets are identified by their index in ip_set_list and the index
 504 * is used by the external references (set/SET netfilter modules).
 505 *
 506 * The set behind an index may change by swapping only, from userspace.
 507 */
 508
 509static inline void
 510__ip_set_get(struct ip_set *set)
 511{
 512        write_lock_bh(&ip_set_ref_lock);
 513        set->ref++;
 514        write_unlock_bh(&ip_set_ref_lock);
 515}
 516
 517static inline void
 518__ip_set_put(struct ip_set *set)
 519{
 520        write_lock_bh(&ip_set_ref_lock);
 521        BUG_ON(set->ref == 0);
 522        set->ref--;
 523        write_unlock_bh(&ip_set_ref_lock);
 524}
 525
 526/* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need
 527 * a separate reference counter
 528 */
 529static inline void
 530__ip_set_put_netlink(struct ip_set *set)
 531{
 532        write_lock_bh(&ip_set_ref_lock);
 533        BUG_ON(set->ref_netlink == 0);
 534        set->ref_netlink--;
 535        write_unlock_bh(&ip_set_ref_lock);
 536}
 537
 538/* Add, del and test set entries from kernel.
 539 *
 540 * The set behind the index must exist and must be referenced
 541 * so it can't be destroyed (or changed) under our foot.
 542 */
 543
 544static inline struct ip_set *
 545ip_set_rcu_get(struct net *net, ip_set_id_t index)
 546{
 547        struct ip_set *set;
 548        struct ip_set_net *inst = ip_set_pernet(net);
 549
 550        rcu_read_lock();
 551        /* ip_set_list itself needs to be protected */
 552        set = rcu_dereference(inst->ip_set_list)[index];
 553        rcu_read_unlock();
 554
 555        return set;
 556}
 557
 558int
 559ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
 560            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
 561{
 562        struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
 563        int ret = 0;
 564
 565        BUG_ON(!set);
 566        pr_debug("set %s, index %u\n", set->name, index);
 567
 568        if (opt->dim < set->type->dimension ||
 569            !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
 570                return 0;
 571
 572        rcu_read_lock_bh();
 573        ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
 574        rcu_read_unlock_bh();
 575
 576        if (ret == -EAGAIN) {
 577                /* Type requests element to be completed */
 578                pr_debug("element must be completed, ADD is triggered\n");
 579                spin_lock_bh(&set->lock);
 580                set->variant->kadt(set, skb, par, IPSET_ADD, opt);
 581                spin_unlock_bh(&set->lock);
 582                ret = 1;
 583        } else {
 584                /* --return-nomatch: invert matched element */
 585                if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) &&
 586                    (set->type->features & IPSET_TYPE_NOMATCH) &&
 587                    (ret > 0 || ret == -ENOTEMPTY))
 588                        ret = -ret;
 589        }
 590
 591        /* Convert error codes to nomatch */
 592        return (ret < 0 ? 0 : ret);
 593}
 594EXPORT_SYMBOL_GPL(ip_set_test);
 595
 596int
 597ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
 598           const struct xt_action_param *par, struct ip_set_adt_opt *opt)
 599{
 600        struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
 601        int ret;
 602
 603        BUG_ON(!set);
 604        pr_debug("set %s, index %u\n", set->name, index);
 605
 606        if (opt->dim < set->type->dimension ||
 607            !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
 608                return -IPSET_ERR_TYPE_MISMATCH;
 609
 610        spin_lock_bh(&set->lock);
 611        ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
 612        spin_unlock_bh(&set->lock);
 613
 614        return ret;
 615}
 616EXPORT_SYMBOL_GPL(ip_set_add);
 617
 618int
 619ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
 620           const struct xt_action_param *par, struct ip_set_adt_opt *opt)
 621{
 622        struct ip_set *set = ip_set_rcu_get(xt_net(par), index);
 623        int ret = 0;
 624
 625        BUG_ON(!set);
 626        pr_debug("set %s, index %u\n", set->name, index);
 627
 628        if (opt->dim < set->type->dimension ||
 629            !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
 630                return -IPSET_ERR_TYPE_MISMATCH;
 631
 632        spin_lock_bh(&set->lock);
 633        ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
 634        spin_unlock_bh(&set->lock);
 635
 636        return ret;
 637}
 638EXPORT_SYMBOL_GPL(ip_set_del);
 639
 640/* Find set by name, reference it once. The reference makes sure the
 641 * thing pointed to, does not go away under our feet.
 642 *
 643 */
 644ip_set_id_t
 645ip_set_get_byname(struct net *net, const char *name, struct ip_set **set)
 646{
 647        ip_set_id_t i, index = IPSET_INVALID_ID;
 648        struct ip_set *s;
 649        struct ip_set_net *inst = ip_set_pernet(net);
 650
 651        rcu_read_lock();
 652        for (i = 0; i < inst->ip_set_max; i++) {
 653                s = rcu_dereference(inst->ip_set_list)[i];
 654                if (s && STRNCMP(s->name, name)) {
 655                        __ip_set_get(s);
 656                        index = i;
 657                        *set = s;
 658                        break;
 659                }
 660        }
 661        rcu_read_unlock();
 662
 663        return index;
 664}
 665EXPORT_SYMBOL_GPL(ip_set_get_byname);
 666
 667/* If the given set pointer points to a valid set, decrement
 668 * reference count by 1. The caller shall not assume the index
 669 * to be valid, after calling this function.
 670 *
 671 */
 672
 673static inline void
 674__ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index)
 675{
 676        struct ip_set *set;
 677
 678        rcu_read_lock();
 679        set = rcu_dereference(inst->ip_set_list)[index];
 680        if (set)
 681                __ip_set_put(set);
 682        rcu_read_unlock();
 683}
 684
 685void
 686ip_set_put_byindex(struct net *net, ip_set_id_t index)
 687{
 688        struct ip_set_net *inst = ip_set_pernet(net);
 689
 690        __ip_set_put_byindex(inst, index);
 691}
 692EXPORT_SYMBOL_GPL(ip_set_put_byindex);
 693
 694/* Get the name of a set behind a set index.
 695 * Set itself is protected by RCU, but its name isn't: to protect against
 696 * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the
 697 * name.
 698 */
 699void
 700ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name)
 701{
 702        struct ip_set *set = ip_set_rcu_get(net, index);
 703
 704        BUG_ON(!set);
 705
 706        read_lock_bh(&ip_set_ref_lock);
 707        strncpy(name, set->name, IPSET_MAXNAMELEN);
 708        read_unlock_bh(&ip_set_ref_lock);
 709}
 710EXPORT_SYMBOL_GPL(ip_set_name_byindex);
 711
 712/* Routines to call by external subsystems, which do not
 713 * call nfnl_lock for us.
 714 */
 715
 716/* Find set by index, reference it once. The reference makes sure the
 717 * thing pointed to, does not go away under our feet.
 718 *
 719 * The nfnl mutex is used in the function.
 720 */
 721ip_set_id_t
 722ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index)
 723{
 724        struct ip_set *set;
 725        struct ip_set_net *inst = ip_set_pernet(net);
 726
 727        if (index >= inst->ip_set_max)
 728                return IPSET_INVALID_ID;
 729
 730        nfnl_lock(NFNL_SUBSYS_IPSET);
 731        set = ip_set(inst, index);
 732        if (set)
 733                __ip_set_get(set);
 734        else
 735                index = IPSET_INVALID_ID;
 736        nfnl_unlock(NFNL_SUBSYS_IPSET);
 737
 738        return index;
 739}
 740EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
 741
 742/* If the given set pointer points to a valid set, decrement
 743 * reference count by 1. The caller shall not assume the index
 744 * to be valid, after calling this function.
 745 *
 746 * The nfnl mutex is used in the function.
 747 */
 748void
 749ip_set_nfnl_put(struct net *net, ip_set_id_t index)
 750{
 751        struct ip_set *set;
 752        struct ip_set_net *inst = ip_set_pernet(net);
 753
 754        nfnl_lock(NFNL_SUBSYS_IPSET);
 755        if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */
 756                set = ip_set(inst, index);
 757                if (set)
 758                        __ip_set_put(set);
 759        }
 760        nfnl_unlock(NFNL_SUBSYS_IPSET);
 761}
 762EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
 763
 764/* Communication protocol with userspace over netlink.
 765 *
 766 * The commands are serialized by the nfnl mutex.
 767 */
 768
 769static inline u8 protocol(const struct nlattr * const tb[])
 770{
 771        return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]);
 772}
 773
 774static inline bool
 775protocol_failed(const struct nlattr * const tb[])
 776{
 777        return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL;
 778}
 779
 780static inline bool
 781protocol_min_failed(const struct nlattr * const tb[])
 782{
 783        return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN;
 784}
 785
 786static inline u32
 787flag_exist(const struct nlmsghdr *nlh)
 788{
 789        return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
 790}
 791
 792static struct nlmsghdr *
 793start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
 794          enum ipset_cmd cmd)
 795{
 796        struct nlmsghdr *nlh;
 797        struct nfgenmsg *nfmsg;
 798
 799        nlh = nlmsg_put(skb, portid, seq, nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd),
 800                        sizeof(*nfmsg), flags);
 801        if (!nlh)
 802                return NULL;
 803
 804        nfmsg = nlmsg_data(nlh);
 805        nfmsg->nfgen_family = NFPROTO_IPV4;
 806        nfmsg->version = NFNETLINK_V0;
 807        nfmsg->res_id = 0;
 808
 809        return nlh;
 810}
 811
 812/* Create a set */
 813
 814static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
 815        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
 816        [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
 817                                    .len = IPSET_MAXNAMELEN - 1 },
 818        [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
 819                                    .len = IPSET_MAXNAMELEN - 1},
 820        [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
 821        [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
 822        [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
 823};
 824
 825static struct ip_set *
 826find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id)
 827{
 828        struct ip_set *set = NULL;
 829        ip_set_id_t i;
 830
 831        *id = IPSET_INVALID_ID;
 832        for (i = 0; i < inst->ip_set_max; i++) {
 833                set = ip_set(inst, i);
 834                if (set && STRNCMP(set->name, name)) {
 835                        *id = i;
 836                        break;
 837                }
 838        }
 839        return (*id == IPSET_INVALID_ID ? NULL : set);
 840}
 841
 842static inline struct ip_set *
 843find_set(struct ip_set_net *inst, const char *name)
 844{
 845        ip_set_id_t id;
 846
 847        return find_set_and_id(inst, name, &id);
 848}
 849
 850static int
 851find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index,
 852             struct ip_set **set)
 853{
 854        struct ip_set *s;
 855        ip_set_id_t i;
 856
 857        *index = IPSET_INVALID_ID;
 858        for (i = 0;  i < inst->ip_set_max; i++) {
 859                s = ip_set(inst, i);
 860                if (!s) {
 861                        if (*index == IPSET_INVALID_ID)
 862                                *index = i;
 863                } else if (STRNCMP(name, s->name)) {
 864                        /* Name clash */
 865                        *set = s;
 866                        return -EEXIST;
 867                }
 868        }
 869        if (*index == IPSET_INVALID_ID)
 870                /* No free slot remained */
 871                return -IPSET_ERR_MAX_SETS;
 872        return 0;
 873}
 874
/* Placeholder netlink handler: ignores all of its arguments and
 * always reports the command as unsupported (-EOPNOTSUPP).
 */
static int ip_set_none(struct net *net, struct sock *ctnl, struct sk_buff *skb,
                       const struct nlmsghdr *nlh,
                       const struct nlattr * const attr[],
                       struct netlink_ext_ack *extack)
{
        return -EOPNOTSUPP;
}
 882
/* Netlink handler: create a new set.
 *
 * The message must carry SETNAME, TYPENAME, REVISION and FAMILY;
 * DATA is optional but must be flagged as nested when present.
 *
 * Returns 0 on success. Errors:
 *   -IPSET_ERR_PROTOCOL  - attribute validation or nested parsing failed
 *   -ENOMEM              - the base set structure could not be allocated
 *   -EEXIST              - a set with this name already exists (turned
 *                          into 0 when IPSET_FLAG_EXIST is set and the
 *                          existing set matches; the freshly built set
 *                          is destroyed in both cases)
 *   -IPSET_ERR_MAX_SETS  - the array is full and could not be grown
 *   anything else        - propagated from find_set_type_get() or the
 *                          type's create() callback
 */
static int ip_set_create(struct net *net, struct sock *ctnl,
                         struct sk_buff *skb, const struct nlmsghdr *nlh,
                         const struct nlattr * const attr[],
                         struct netlink_ext_ack *extack)
{
        struct ip_set_net *inst = ip_set_pernet(net);
        struct ip_set *set, *clash = NULL;
        ip_set_id_t index = IPSET_INVALID_ID;
        struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {};
        const char *name, *typename;
        u8 family, revision;
        u32 flags = flag_exist(nlh);
        int ret = 0;

        if (unlikely(protocol_min_failed(attr) ||
                     !attr[IPSET_ATTR_SETNAME] ||
                     !attr[IPSET_ATTR_TYPENAME] ||
                     !attr[IPSET_ATTR_REVISION] ||
                     !attr[IPSET_ATTR_FAMILY] ||
                     (attr[IPSET_ATTR_DATA] &&
                      !flag_nested(attr[IPSET_ATTR_DATA]))))
                return -IPSET_ERR_PROTOCOL;

        name = nla_data(attr[IPSET_ATTR_SETNAME]);
        typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
        family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
        revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
        pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
                 name, typename, family_name(family), revision);

        /* First, and without any locks, allocate and initialize
         * a normal base set structure.
         */
        set = kzalloc(sizeof(*set), GFP_KERNEL);
        if (!set)
                return -ENOMEM;
        spin_lock_init(&set->lock);
        strlcpy(set->name, name, IPSET_MAXNAMELEN);
        set->family = family;
        set->revision = revision;

        /* Next, check that we know the type, and take
         * a reference on the type, to make sure it stays available
         * while constructing our new set.
         *
         * After referencing the type, we try to create the type
         * specific part of the set without holding any locks.
         */
        ret = find_set_type_get(typename, family, revision, &set->type);
        if (ret)
                goto out;

        /* Without holding any locks, create private part.
         * tb stays all-NULL when no DATA attribute was sent.
         */
        if (attr[IPSET_ATTR_DATA] &&
            nla_parse_nested_deprecated(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA], set->type->create_policy, NULL)) {
                ret = -IPSET_ERR_PROTOCOL;
                goto put_out;
        }

        ret = set->type->create(net, set, tb, flags);
        if (ret != 0)
                goto put_out;

        /* BTW, ret==0 here. */

        /* Here, we have a valid, constructed set and we are protected
         * by the nfnl mutex. Find the first free index in ip_set_list
         * and check clashing.
         */
        ret = find_free_id(inst, set->name, &index, &clash);
        if (ret == -EEXIST) {
                /* If this is the same set and requested, ignore error.
                 * Either way the new set is redundant and is torn down
                 * via the cleanup label.
                 */
                if ((flags & IPSET_FLAG_EXIST) &&
                    STRNCMP(set->type->name, clash->type->name) &&
                    set->type->family == clash->type->family &&
                    set->type->revision_min == clash->type->revision_min &&
                    set->type->revision_max == clash->type->revision_max &&
                    set->variant->same_set(set, clash))
                        ret = 0;
                goto cleanup;
        } else if (ret == -IPSET_ERR_MAX_SETS) {
                /* Array is full: try to grow it by IP_SET_INC slots */
                struct ip_set **list, **tmp;
                ip_set_id_t i = inst->ip_set_max + IP_SET_INC;

                if (i < inst->ip_set_max || i == IPSET_INVALID_ID)
                        /* Wraparound; ret is still -IPSET_ERR_MAX_SETS */
                        goto cleanup;

                list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL);
                if (!list)
                        /* ret is still -IPSET_ERR_MAX_SETS here */
                        goto cleanup;
                /* nfnl mutex is held, both lists are valid */
                tmp = ip_set_dereference(inst->ip_set_list);
                memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max);
                rcu_assign_pointer(inst->ip_set_list, list);
                /* Make sure all current packets have passed through */
                synchronize_net();
                /* Use new list: the first newly added slot is free */
                index = inst->ip_set_max;
                inst->ip_set_max = i;
                kvfree(tmp);
                ret = 0;
        } else if (ret) {
                goto cleanup;
        }

        /* Finally! Add our shiny new set to the list, and be done. */
        pr_debug("create: '%s' created with index %u!\n", set->name, index);
        ip_set(inst, index) = set;

        return ret;

        /* Error unwinding, in reverse order of construction */
cleanup:
        set->variant->destroy(set);
put_out:
        module_put(set->type->me);
out:
        kfree(set);
        return ret;
}
1003
1004/* Destroy sets */
1005
/* Netlink attribute policy for commands that take at most a set name:
 * PROTOCOL is a u8, SETNAME is a NUL-terminated string of at most
 * IPSET_MAXNAMELEN - 1 characters. Attributes not listed here get no
 * type-specific validation from this table.
 */
static const struct nla_policy
ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
        [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
                                    .len = IPSET_MAXNAMELEN - 1 },
};
1012
/* Tear down one set: run the variant's destroy callback, drop the
 * module reference held on the set type and free the set structure.
 * The set pointer is invalid after this returns.
 */
static void
ip_set_destroy_set(struct ip_set *set)
{
        pr_debug("set: %s\n",  set->name);

        /* Must call it without holding any lock */
        set->variant->destroy(set);
        /* set->type is still read here, so kfree() has to come last */
        module_put(set->type->me);
        kfree(set);
}
1023
1024static int ip_set_destroy(struct net *net, struct sock *ctnl,
1025                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1026                          const struct nlattr * const attr[],
1027                          struct netlink_ext_ack *extack)
1028{
1029        struct ip_set_net *inst = ip_set_pernet(net);
1030        struct ip_set *s;
1031        ip_set_id_t i;
1032        int ret = 0;
1033
1034        if (unlikely(protocol_min_failed(attr)))
1035                return -IPSET_ERR_PROTOCOL;
1036
1037        /* Must wait for flush to be really finished in list:set */
1038        rcu_barrier();
1039
1040        /* Commands are serialized and references are
1041         * protected by the ip_set_ref_lock.
1042         * External systems (i.e. xt_set) must call
1043         * ip_set_put|get_nfnl_* functions, that way we
1044         * can safely check references here.
1045         *
1046         * list:set timer can only decrement the reference
1047         * counter, so if it's already zero, we can proceed
1048         * without holding the lock.
1049         */
1050        read_lock_bh(&ip_set_ref_lock);
1051        if (!attr[IPSET_ATTR_SETNAME]) {
1052                for (i = 0; i < inst->ip_set_max; i++) {
1053                        s = ip_set(inst, i);
1054                        if (s && (s->ref || s->ref_netlink)) {
1055                                ret = -IPSET_ERR_BUSY;
1056                                goto out;
1057                        }
1058                }
1059                inst->is_destroyed = true;
1060                read_unlock_bh(&ip_set_ref_lock);
1061                for (i = 0; i < inst->ip_set_max; i++) {
1062                        s = ip_set(inst, i);
1063                        if (s) {
1064                                ip_set(inst, i) = NULL;
1065                                ip_set_destroy_set(s);
1066                        }
1067                }
1068                /* Modified by ip_set_destroy() only, which is serialized */
1069                inst->is_destroyed = false;
1070        } else {
1071                s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1072                                    &i);
1073                if (!s) {
1074                        ret = -ENOENT;
1075                        goto out;
1076                } else if (s->ref || s->ref_netlink) {
1077                        ret = -IPSET_ERR_BUSY;
1078                        goto out;
1079                }
1080                ip_set(inst, i) = NULL;
1081                read_unlock_bh(&ip_set_ref_lock);
1082
1083                ip_set_destroy_set(s);
1084        }
1085        return 0;
1086out:
1087        read_unlock_bh(&ip_set_ref_lock);
1088        return ret;
1089}
1090
1091/* Flush sets */
1092
/* Flush one set by invoking its variant's flush callback while
 * holding the set's own spinlock with bottom halves disabled.
 */
static void
ip_set_flush_set(struct ip_set *set)
{
        pr_debug("set: %s\n",  set->name);

        spin_lock_bh(&set->lock);
        set->variant->flush(set);
        spin_unlock_bh(&set->lock);
}
1102
1103static int ip_set_flush(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1104                        const struct nlmsghdr *nlh,
1105                        const struct nlattr * const attr[],
1106                        struct netlink_ext_ack *extack)
1107{
1108        struct ip_set_net *inst = ip_set_pernet(net);
1109        struct ip_set *s;
1110        ip_set_id_t i;
1111
1112        if (unlikely(protocol_min_failed(attr)))
1113                return -IPSET_ERR_PROTOCOL;
1114
1115        if (!attr[IPSET_ATTR_SETNAME]) {
1116                for (i = 0; i < inst->ip_set_max; i++) {
1117                        s = ip_set(inst, i);
1118                        if (s)
1119                                ip_set_flush_set(s);
1120                }
1121        } else {
1122                s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1123                if (!s)
1124                        return -ENOENT;
1125
1126                ip_set_flush_set(s);
1127        }
1128
1129        return 0;
1130}
1131
1132/* Rename a set */
1133
/* Netlink attribute policy for commands that take two set names:
 * PROTOCOL is a u8; SETNAME and SETNAME2 are NUL-terminated strings of
 * at most IPSET_MAXNAMELEN - 1 characters each.
 */
static const struct nla_policy
ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
        [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
                                    .len = IPSET_MAXNAMELEN - 1 },
        [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
                                    .len = IPSET_MAXNAMELEN - 1 },
};
1142
1143static int ip_set_rename(struct net *net, struct sock *ctnl,
1144                         struct sk_buff *skb, const struct nlmsghdr *nlh,
1145                         const struct nlattr * const attr[],
1146                         struct netlink_ext_ack *extack)
1147{
1148        struct ip_set_net *inst = ip_set_pernet(net);
1149        struct ip_set *set, *s;
1150        const char *name2;
1151        ip_set_id_t i;
1152        int ret = 0;
1153
1154        if (unlikely(protocol_min_failed(attr) ||
1155                     !attr[IPSET_ATTR_SETNAME] ||
1156                     !attr[IPSET_ATTR_SETNAME2]))
1157                return -IPSET_ERR_PROTOCOL;
1158
1159        set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1160        if (!set)
1161                return -ENOENT;
1162
1163        write_lock_bh(&ip_set_ref_lock);
1164        if (set->ref != 0 || set->ref_netlink != 0) {
1165                ret = -IPSET_ERR_REFERENCED;
1166                goto out;
1167        }
1168
1169        name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
1170        for (i = 0; i < inst->ip_set_max; i++) {
1171                s = ip_set(inst, i);
1172                if (s && STRNCMP(s->name, name2)) {
1173                        ret = -IPSET_ERR_EXIST_SETNAME2;
1174                        goto out;
1175                }
1176        }
1177        strncpy(set->name, name2, IPSET_MAXNAMELEN);
1178
1179out:
1180        write_unlock_bh(&ip_set_ref_lock);
1181        return ret;
1182}
1183
1184/* Swap two sets so that name/index points to the other.
1185 * References and set names are also swapped.
1186 *
1187 * The commands are serialized by the nfnl mutex and references are
1188 * protected by the ip_set_ref_lock. The kernel interfaces
1189 * do not hold the mutex but the pointer settings are atomic
1190 * so the ip_set_list always contains valid pointers to the sets.
1191 */
1192
1193static int ip_set_swap(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1194                       const struct nlmsghdr *nlh,
1195                       const struct nlattr * const attr[],
1196                       struct netlink_ext_ack *extack)
1197{
1198        struct ip_set_net *inst = ip_set_pernet(net);
1199        struct ip_set *from, *to;
1200        ip_set_id_t from_id, to_id;
1201        char from_name[IPSET_MAXNAMELEN];
1202
1203        if (unlikely(protocol_min_failed(attr) ||
1204                     !attr[IPSET_ATTR_SETNAME] ||
1205                     !attr[IPSET_ATTR_SETNAME2]))
1206                return -IPSET_ERR_PROTOCOL;
1207
1208        from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1209                               &from_id);
1210        if (!from)
1211                return -ENOENT;
1212
1213        to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]),
1214                             &to_id);
1215        if (!to)
1216                return -IPSET_ERR_EXIST_SETNAME2;
1217
1218        /* Features must not change.
1219         * Not an artifical restriction anymore, as we must prevent
1220         * possible loops created by swapping in setlist type of sets.
1221         */
1222        if (!(from->type->features == to->type->features &&
1223              from->family == to->family))
1224                return -IPSET_ERR_TYPE_MISMATCH;
1225
1226        write_lock_bh(&ip_set_ref_lock);
1227
1228        if (from->ref_netlink || to->ref_netlink) {
1229                write_unlock_bh(&ip_set_ref_lock);
1230                return -EBUSY;
1231        }
1232
1233        strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1234        strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1235        strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1236
1237        swap(from->ref, to->ref);
1238        ip_set(inst, from_id) = to;
1239        ip_set(inst, to_id) = from;
1240        write_unlock_bh(&ip_set_ref_lock);
1241
1242        return 0;
1243}
1244
1245/* List/save set data */
1246
/* Dump kinds stored in the low 16 bits of cb->args[IPSET_CB_DUMP].
 * DUMP_INIT (0) doubles as the "not yet initialized" marker; DUMP_ALL
 * is switched to DUMP_LAST for the second pass over all sets.
 */
#define DUMP_INIT       0
#define DUMP_ALL        1
#define DUMP_ONE        2
#define DUMP_LAST       3

/* Split the packed dump argument: kind in the low 16 bits,
 * user-supplied flags in the bits above them.
 */
#define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
#define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
1254
1255static int
1256ip_set_dump_done(struct netlink_callback *cb)
1257{
1258        if (cb->args[IPSET_CB_ARG0]) {
1259                struct ip_set_net *inst =
1260                        (struct ip_set_net *)cb->args[IPSET_CB_NET];
1261                ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1262                struct ip_set *set = ip_set_ref_netlink(inst, index);
1263
1264                if (set->variant->uref)
1265                        set->variant->uref(set, cb, false);
1266                pr_debug("release set %s\n", set->name);
1267                __ip_set_put_netlink(set);
1268        }
1269        return 0;
1270}
1271
1272static inline void
1273dump_attrs(struct nlmsghdr *nlh)
1274{
1275        const struct nlattr *attr;
1276        int rem;
1277
1278        pr_debug("dump nlmsg\n");
1279        nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1280                pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1281        }
1282}
1283
1284static int
1285dump_init(struct netlink_callback *cb, struct ip_set_net *inst)
1286{
1287        struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1288        int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1289        struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1290        struct nlattr *attr = (void *)nlh + min_len;
1291        u32 dump_type;
1292        ip_set_id_t index;
1293        int ret;
1294
1295        ret = nla_parse_deprecated(cda, IPSET_ATTR_CMD_MAX, attr,
1296                                   nlh->nlmsg_len - min_len,
1297                                   ip_set_setname_policy, NULL);
1298        if (ret)
1299                return ret;
1300
1301        cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]);
1302        if (cda[IPSET_ATTR_SETNAME]) {
1303                struct ip_set *set;
1304
1305                set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]),
1306                                      &index);
1307                if (!set)
1308                        return -ENOENT;
1309
1310                dump_type = DUMP_ONE;
1311                cb->args[IPSET_CB_INDEX] = index;
1312        } else {
1313                dump_type = DUMP_ALL;
1314        }
1315
1316        if (cda[IPSET_ATTR_FLAGS]) {
1317                u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1318
1319                dump_type |= (f << 16);
1320        }
1321        cb->args[IPSET_CB_NET] = (unsigned long)inst;
1322        cb->args[IPSET_CB_DUMP] = dump_type;
1323
1324        return 0;
1325}
1326
/* Netlink dump callback: emit (part of) the listing into @skb.
 *
 * The dump is resumable; its state lives in cb->args:
 *   IPSET_CB_DUMP  - dump kind and flags; zero means "not initialized
 *                    yet", in which case dump_init() is run first
 *   IPSET_CB_INDEX - slot of the set currently being processed
 *   IPSET_CB_ARG0  - zero when a new set is to be started; while it is
 *                    non-zero the current set stays referenced and its
 *                    listing continues with variant->list()
 *                    (presumably variant->list() updates it to record
 *                    progress - confirm against the set types)
 *
 * When dumping all sets, a first pass (DUMP_ALL) skips types having
 * IPSET_DUMP_LAST and a second pass (DUMP_LAST) emits only those.
 *
 * Returns skb->len, or a negative error code when ret ended up < 0.
 */
static int
ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
{
        ip_set_id_t index = IPSET_INVALID_ID, max;
        struct ip_set *set = NULL;
        struct nlmsghdr *nlh = NULL;
        unsigned int flags = NETLINK_CB(cb->skb).portid ? NLM_F_MULTI : 0;
        struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
        u32 dump_type, dump_flags;
        bool is_destroyed;
        int ret = 0;

        if (!cb->args[IPSET_CB_DUMP]) {
                ret = dump_init(cb, inst);
                if (ret < 0) {
                        nlh = nlmsg_hdr(cb->skb);
                        /* We have to create and send the error message
                         * manually :-(
                         */
                        if (nlh->nlmsg_flags & NLM_F_ACK)
                                netlink_ack(cb->skb, nlh, ret, NULL);
                        return ret;
                }
        }

        /* Index beyond the array: nothing (more) to dump */
        if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max)
                goto out;

        dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]);
        dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]);
        /* A single-set dump only ever visits its own slot */
        max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1
                                    : inst->ip_set_max;
dump_last:
        pr_debug("dump type, flag: %u %u index: %ld\n",
                 dump_type, dump_flags, cb->args[IPSET_CB_INDEX]);
        for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) {
                index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
                write_lock_bh(&ip_set_ref_lock);
                set = ip_set(inst, index);
                is_destroyed = inst->is_destroyed;
                if (!set || is_destroyed) {
                        write_unlock_bh(&ip_set_ref_lock);
                        if (dump_type == DUMP_ONE) {
                                ret = -ENOENT;
                                goto out;
                        }
                        if (is_destroyed) {
                                /* All sets are just being destroyed */
                                ret = 0;
                                goto out;
                        }
                        continue;
                }
                /* When dumping all sets, we must dump "sorted"
                 * so that lists (unions of sets) are dumped last.
                 * The DUMP_ALL pass skips IPSET_DUMP_LAST types and
                 * the DUMP_LAST pass skips everything else.
                 */
                if (dump_type != DUMP_ONE &&
                    ((dump_type == DUMP_ALL) ==
                     !!(set->type->features & IPSET_DUMP_LAST))) {
                        write_unlock_bh(&ip_set_ref_lock);
                        continue;
                }
                pr_debug("List set: %s\n", set->name);
                if (!cb->args[IPSET_CB_ARG0]) {
                        /* Start listing: make sure set won't be destroyed */
                        pr_debug("reference set\n");
                        set->ref_netlink++;
                }
                write_unlock_bh(&ip_set_ref_lock);
                nlh = start_msg(skb, NETLINK_CB(cb->skb).portid,
                                cb->nlh->nlmsg_seq, flags,
                                IPSET_CMD_LIST);
                if (!nlh) {
                        ret = -EMSGSIZE;
                        goto release_refcount;
                }
                if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL,
                               cb->args[IPSET_CB_PROTO]) ||
                    nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
                        goto nla_put_failure;
                /* Names only: the message is complete for this set */
                if (dump_flags & IPSET_FLAG_LIST_SETNAME)
                        goto next_set;
                switch (cb->args[IPSET_CB_ARG0]) {
                case 0:
                        /* Core header data */
                        if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
                                           set->type->name) ||
                            nla_put_u8(skb, IPSET_ATTR_FAMILY,
                                       set->family) ||
                            nla_put_u8(skb, IPSET_ATTR_REVISION,
                                       set->revision))
                                goto nla_put_failure;
                        /* The set index is only sent to newer protocols */
                        if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN &&
                            nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index)))
                                goto nla_put_failure;
                        ret = set->variant->head(set, skb);
                        if (ret < 0)
                                goto release_refcount;
                        if (dump_flags & IPSET_FLAG_LIST_HEADER)
                                goto next_set;
                        if (set->variant->uref)
                                set->variant->uref(set, cb, true);
                        /* fall through */
                default:
                        ret = set->variant->list(set, skb, cb);
                        if (!cb->args[IPSET_CB_ARG0])
                                /* Set is done, proceed with next one */
                                goto next_set;
                        /* Listing not finished: leave the index as is */
                        goto release_refcount;
                }
        }
        /* If we dump all sets, continue with dumping last ones */
        if (dump_type == DUMP_ALL) {
                dump_type = DUMP_LAST;
                cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16);
                cb->args[IPSET_CB_INDEX] = 0;
                if (set && set->variant->uref)
                        set->variant->uref(set, cb, false);
                goto dump_last;
        }
        goto out;

nla_put_failure:
        ret = -EFAULT;
next_set:
        /* Advance past this set; DUMP_ONE parks the index beyond the
         * array so that the next invocation exits early.
         */
        if (dump_type == DUMP_ONE)
                cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID;
        else
                cb->args[IPSET_CB_INDEX]++;
release_refcount:
        /* If there was an error or set is done, release set */
        if (ret || !cb->args[IPSET_CB_ARG0]) {
                set = ip_set_ref_netlink(inst, index);
                if (set->variant->uref)
                        set->variant->uref(set, cb, false);
                pr_debug("release set %s\n", set->name);
                __ip_set_put_netlink(set);
                cb->args[IPSET_CB_ARG0] = 0;
        }
out:
        /* Finalize the message length if a message was started */
        if (nlh) {
                nlmsg_end(skb, nlh);
                pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
                dump_attrs(nlh);
        }

        return ret < 0 ? ret : skb->len;
}
1475
1476static int ip_set_dump(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1477                       const struct nlmsghdr *nlh,
1478                       const struct nlattr * const attr[],
1479                       struct netlink_ext_ack *extack)
1480{
1481        if (unlikely(protocol_min_failed(attr)))
1482                return -IPSET_ERR_PROTOCOL;
1483
1484        {
1485                struct netlink_dump_control c = {
1486                        .dump = ip_set_dump_start,
1487                        .done = ip_set_dump_done,
1488                };
1489                return netlink_dump_start(ctnl, skb, nlh, &c);
1490        }
1491}
1492
1493/* Add, del and test */
1494
/* Netlink attribute policy for the add/del/test commands: PROTOCOL
 * (u8), SETNAME (NUL-terminated, at most IPSET_MAXNAMELEN - 1
 * characters), LINENO (u32) and the nested DATA and ADT containers.
 * call_ad() also uses it to re-parse the copy of a failed request.
 */
static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
        [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
                                    .len = IPSET_MAXNAMELEN - 1 },
        [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
        [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
        [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
};
1503
1504static int
1505call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1506        struct nlattr *tb[], enum ipset_adt adt,
1507        u32 flags, bool use_lineno)
1508{
1509        int ret;
1510        u32 lineno = 0;
1511        bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1512
1513        do {
1514                spin_lock_bh(&set->lock);
1515                ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1516                spin_unlock_bh(&set->lock);
1517                retried = true;
1518        } while (ret == -EAGAIN &&
1519                 set->variant->resize &&
1520                 (ret = set->variant->resize(set, retried)) == 0);
1521
1522        if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1523                return 0;
1524        if (lineno && use_lineno) {
1525                /* Error in restore/batch mode: send back lineno */
1526                struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1527                struct sk_buff *skb2;
1528                struct nlmsgerr *errmsg;
1529                size_t payload = min(SIZE_MAX,
1530                                     sizeof(*errmsg) + nlmsg_len(nlh));
1531                int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1532                struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1533                struct nlattr *cmdattr;
1534                u32 *errline;
1535
1536                skb2 = nlmsg_new(payload, GFP_KERNEL);
1537                if (!skb2)
1538                        return -ENOMEM;
1539                rep = __nlmsg_put(skb2, NETLINK_CB(skb).portid,
1540                                  nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1541                errmsg = nlmsg_data(rep);
1542                errmsg->error = ret;
1543                memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1544                cmdattr = (void *)&errmsg->msg + min_len;
1545
1546                ret = nla_parse_deprecated(cda, IPSET_ATTR_CMD_MAX, cmdattr,
1547                                           nlh->nlmsg_len - min_len,
1548                                           ip_set_adt_policy, NULL);
1549
1550                if (ret) {
1551                        nlmsg_free(skb2);
1552                        return ret;
1553                }
1554                errline = nla_data(cda[IPSET_ATTR_LINENO]);
1555
1556                *errline = lineno;
1557
1558                netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid,
1559                                MSG_DONTWAIT);
1560                /* Signal netlink not to send its ACK/errmsg.  */
1561                return -EINTR;
1562        }
1563
1564        return ret;
1565}
1566
1567static int ip_set_ad(struct net *net, struct sock *ctnl,
1568                     struct sk_buff *skb,
1569                     enum ipset_adt adt,
1570                     const struct nlmsghdr *nlh,
1571                     const struct nlattr * const attr[],
1572                     struct netlink_ext_ack *extack)
1573{
1574        struct ip_set_net *inst = ip_set_pernet(net);
1575        struct ip_set *set;
1576        struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1577        const struct nlattr *nla;
1578        u32 flags = flag_exist(nlh);
1579        bool use_lineno;
1580        int ret = 0;
1581
1582        if (unlikely(protocol_min_failed(attr) ||
1583                     !attr[IPSET_ATTR_SETNAME] ||
1584                     !((attr[IPSET_ATTR_DATA] != NULL) ^
1585                       (attr[IPSET_ATTR_ADT] != NULL)) ||
1586                     (attr[IPSET_ATTR_DATA] &&
1587                      !flag_nested(attr[IPSET_ATTR_DATA])) ||
1588                     (attr[IPSET_ATTR_ADT] &&
1589                      (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1590                       !attr[IPSET_ATTR_LINENO]))))
1591                return -IPSET_ERR_PROTOCOL;
1592
1593        set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1594        if (!set)
1595                return -ENOENT;
1596
1597        use_lineno = !!attr[IPSET_ATTR_LINENO];
1598        if (attr[IPSET_ATTR_DATA]) {
1599                if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1600                        return -IPSET_ERR_PROTOCOL;
1601                ret = call_ad(ctnl, skb, set, tb, adt, flags,
1602                              use_lineno);
1603        } else {
1604                int nla_rem;
1605
1606                nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1607                        if (nla_type(nla) != IPSET_ATTR_DATA ||
1608                            !flag_nested(nla) ||
1609                            nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, nla, set->type->adt_policy, NULL))
1610                                return -IPSET_ERR_PROTOCOL;
1611                        ret = call_ad(ctnl, skb, set, tb, adt,
1612                                      flags, use_lineno);
1613                        if (ret < 0)
1614                                return ret;
1615                }
1616        }
1617        return ret;
1618}
1619
/* Netlink handler for the add command: forwards to ip_set_ad() with
 * IPSET_ADD and returns its result.
 */
static int ip_set_uadd(struct net *net, struct sock *ctnl,
                       struct sk_buff *skb, const struct nlmsghdr *nlh,
                       const struct nlattr * const attr[],
                       struct netlink_ext_ack *extack)
{
        return ip_set_ad(net, ctnl, skb,
                         IPSET_ADD, nlh, attr, extack);
}
1628
/* Netlink handler for the del command: forwards to ip_set_ad() with
 * IPSET_DEL and returns its result.
 */
static int ip_set_udel(struct net *net, struct sock *ctnl,
                       struct sk_buff *skb, const struct nlmsghdr *nlh,
                       const struct nlattr * const attr[],
                       struct netlink_ext_ack *extack)
{
        return ip_set_ad(net, ctnl, skb,
                         IPSET_DEL, nlh, attr, extack);
}
1637
1638static int ip_set_utest(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1639                        const struct nlmsghdr *nlh,
1640                        const struct nlattr * const attr[],
1641                        struct netlink_ext_ack *extack)
1642{
1643        struct ip_set_net *inst = ip_set_pernet(net);
1644        struct ip_set *set;
1645        struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1646        int ret = 0;
1647
1648        if (unlikely(protocol_min_failed(attr) ||
1649                     !attr[IPSET_ATTR_SETNAME] ||
1650                     !attr[IPSET_ATTR_DATA] ||
1651                     !flag_nested(attr[IPSET_ATTR_DATA])))
1652                return -IPSET_ERR_PROTOCOL;
1653
1654        set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1655        if (!set)
1656                return -ENOENT;
1657
1658        if (nla_parse_nested_deprecated(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], set->type->adt_policy, NULL))
1659                return -IPSET_ERR_PROTOCOL;
1660
1661        rcu_read_lock_bh();
1662        ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1663        rcu_read_unlock_bh();
1664        /* Userspace can't trigger element to be re-added */
1665        if (ret == -EAGAIN)
1666                ret = 1;
1667
1668        return ret > 0 ? 0 : -IPSET_ERR_EXIST;
1669}
1670
/* Get header data of a set */
1672
1673static int ip_set_header(struct net *net, struct sock *ctnl,
1674                         struct sk_buff *skb, const struct nlmsghdr *nlh,
1675                         const struct nlattr * const attr[],
1676                         struct netlink_ext_ack *extack)
1677{
1678        struct ip_set_net *inst = ip_set_pernet(net);
1679        const struct ip_set *set;
1680        struct sk_buff *skb2;
1681        struct nlmsghdr *nlh2;
1682        int ret = 0;
1683
1684        if (unlikely(protocol_min_failed(attr) ||
1685                     !attr[IPSET_ATTR_SETNAME]))
1686                return -IPSET_ERR_PROTOCOL;
1687
1688        set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1689        if (!set)
1690                return -ENOENT;
1691
1692        skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1693        if (!skb2)
1694                return -ENOMEM;
1695
1696        nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1697                         IPSET_CMD_HEADER);
1698        if (!nlh2)
1699                goto nlmsg_failure;
1700        if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1701            nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1702            nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1703            nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1704            nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1705                goto nla_put_failure;
1706        nlmsg_end(skb2, nlh2);
1707
1708        ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1709        if (ret < 0)
1710                return ret;
1711
1712        return 0;
1713
1714nla_put_failure:
1715        nlmsg_cancel(skb2, nlh2);
1716nlmsg_failure:
1717        kfree_skb(skb2);
1718        return -EMSGSIZE;
1719}
1720
1721/* Get type data */
1722
1723static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1724        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1725        [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1726                                    .len = IPSET_MAXNAMELEN - 1 },
1727        [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1728};
1729
1730static int ip_set_type(struct net *net, struct sock *ctnl, struct sk_buff *skb,
1731                       const struct nlmsghdr *nlh,
1732                       const struct nlattr * const attr[],
1733                       struct netlink_ext_ack *extack)
1734{
1735        struct sk_buff *skb2;
1736        struct nlmsghdr *nlh2;
1737        u8 family, min, max;
1738        const char *typename;
1739        int ret = 0;
1740
1741        if (unlikely(protocol_min_failed(attr) ||
1742                     !attr[IPSET_ATTR_TYPENAME] ||
1743                     !attr[IPSET_ATTR_FAMILY]))
1744                return -IPSET_ERR_PROTOCOL;
1745
1746        family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1747        typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1748        ret = find_set_type_minmax(typename, family, &min, &max);
1749        if (ret)
1750                return ret;
1751
1752        skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1753        if (!skb2)
1754                return -ENOMEM;
1755
1756        nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1757                         IPSET_CMD_TYPE);
1758        if (!nlh2)
1759                goto nlmsg_failure;
1760        if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1761            nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1762            nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1763            nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1764            nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1765                goto nla_put_failure;
1766        nlmsg_end(skb2, nlh2);
1767
1768        pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1769        ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1770        if (ret < 0)
1771                return ret;
1772
1773        return 0;
1774
1775nla_put_failure:
1776        nlmsg_cancel(skb2, nlh2);
1777nlmsg_failure:
1778        kfree_skb(skb2);
1779        return -EMSGSIZE;
1780}
1781
1782/* Get protocol version */
1783
1784static const struct nla_policy
1785ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1786        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1787};
1788
1789static int ip_set_protocol(struct net *net, struct sock *ctnl,
1790                           struct sk_buff *skb, const struct nlmsghdr *nlh,
1791                           const struct nlattr * const attr[],
1792                           struct netlink_ext_ack *extack)
1793{
1794        struct sk_buff *skb2;
1795        struct nlmsghdr *nlh2;
1796        int ret = 0;
1797
1798        if (unlikely(!attr[IPSET_ATTR_PROTOCOL]))
1799                return -IPSET_ERR_PROTOCOL;
1800
1801        skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1802        if (!skb2)
1803                return -ENOMEM;
1804
1805        nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1806                         IPSET_CMD_PROTOCOL);
1807        if (!nlh2)
1808                goto nlmsg_failure;
1809        if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1810                goto nla_put_failure;
1811        if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN))
1812                goto nla_put_failure;
1813        nlmsg_end(skb2, nlh2);
1814
1815        ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1816        if (ret < 0)
1817                return ret;
1818
1819        return 0;
1820
1821nla_put_failure:
1822        nlmsg_cancel(skb2, nlh2);
1823nlmsg_failure:
1824        kfree_skb(skb2);
1825        return -EMSGSIZE;
1826}
1827
1828/* Get set by name or index, from userspace */
1829
1830static int ip_set_byname(struct net *net, struct sock *ctnl,
1831                         struct sk_buff *skb, const struct nlmsghdr *nlh,
1832                         const struct nlattr * const attr[],
1833                         struct netlink_ext_ack *extack)
1834{
1835        struct ip_set_net *inst = ip_set_pernet(net);
1836        struct sk_buff *skb2;
1837        struct nlmsghdr *nlh2;
1838        ip_set_id_t id = IPSET_INVALID_ID;
1839        const struct ip_set *set;
1840        int ret = 0;
1841
1842        if (unlikely(protocol_failed(attr) ||
1843                     !attr[IPSET_ATTR_SETNAME]))
1844                return -IPSET_ERR_PROTOCOL;
1845
1846        set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id);
1847        if (id == IPSET_INVALID_ID)
1848                return -ENOENT;
1849
1850        skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1851        if (!skb2)
1852                return -ENOMEM;
1853
1854        nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1855                         IPSET_CMD_GET_BYNAME);
1856        if (!nlh2)
1857                goto nlmsg_failure;
1858        if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1859            nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1860            nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id)))
1861                goto nla_put_failure;
1862        nlmsg_end(skb2, nlh2);
1863
1864        ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1865        if (ret < 0)
1866                return ret;
1867
1868        return 0;
1869
1870nla_put_failure:
1871        nlmsg_cancel(skb2, nlh2);
1872nlmsg_failure:
1873        kfree_skb(skb2);
1874        return -EMSGSIZE;
1875}
1876
1877static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = {
1878        [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1879        [IPSET_ATTR_INDEX]      = { .type = NLA_U16 },
1880};
1881
1882static int ip_set_byindex(struct net *net, struct sock *ctnl,
1883                          struct sk_buff *skb, const struct nlmsghdr *nlh,
1884                          const struct nlattr * const attr[],
1885                          struct netlink_ext_ack *extack)
1886{
1887        struct ip_set_net *inst = ip_set_pernet(net);
1888        struct sk_buff *skb2;
1889        struct nlmsghdr *nlh2;
1890        ip_set_id_t id = IPSET_INVALID_ID;
1891        const struct ip_set *set;
1892        int ret = 0;
1893
1894        if (unlikely(protocol_failed(attr) ||
1895                     !attr[IPSET_ATTR_INDEX]))
1896                return -IPSET_ERR_PROTOCOL;
1897
1898        id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]);
1899        if (id >= inst->ip_set_max)
1900                return -ENOENT;
1901        set = ip_set(inst, id);
1902        if (set == NULL)
1903                return -ENOENT;
1904
1905        skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1906        if (!skb2)
1907                return -ENOMEM;
1908
1909        nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0,
1910                         IPSET_CMD_GET_BYINDEX);
1911        if (!nlh2)
1912                goto nlmsg_failure;
1913        if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1914            nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name))
1915                goto nla_put_failure;
1916        nlmsg_end(skb2, nlh2);
1917
1918        ret = netlink_unicast(ctnl, skb2, NETLINK_CB(skb).portid, MSG_DONTWAIT);
1919        if (ret < 0)
1920                return ret;
1921
1922        return 0;
1923
1924nla_put_failure:
1925        nlmsg_cancel(skb2, nlh2);
1926nlmsg_failure:
1927        kfree_skb(skb2);
1928        return -EMSGSIZE;
1929}
1930
1931static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1932        [IPSET_CMD_NONE]        = {
1933                .call           = ip_set_none,
1934                .attr_count     = IPSET_ATTR_CMD_MAX,
1935        },
1936        [IPSET_CMD_CREATE]      = {
1937                .call           = ip_set_create,
1938                .attr_count     = IPSET_ATTR_CMD_MAX,
1939                .policy         = ip_set_create_policy,
1940        },
1941        [IPSET_CMD_DESTROY]     = {
1942                .call           = ip_set_destroy,
1943                .attr_count     = IPSET_ATTR_CMD_MAX,
1944                .policy         = ip_set_setname_policy,
1945        },
1946        [IPSET_CMD_FLUSH]       = {
1947                .call           = ip_set_flush,
1948                .attr_count     = IPSET_ATTR_CMD_MAX,
1949                .policy         = ip_set_setname_policy,
1950        },
1951        [IPSET_CMD_RENAME]      = {
1952                .call           = ip_set_rename,
1953                .attr_count     = IPSET_ATTR_CMD_MAX,
1954                .policy         = ip_set_setname2_policy,
1955        },
1956        [IPSET_CMD_SWAP]        = {
1957                .call           = ip_set_swap,
1958                .attr_count     = IPSET_ATTR_CMD_MAX,
1959                .policy         = ip_set_setname2_policy,
1960        },
1961        [IPSET_CMD_LIST]        = {
1962                .call           = ip_set_dump,
1963                .attr_count     = IPSET_ATTR_CMD_MAX,
1964                .policy         = ip_set_setname_policy,
1965        },
1966        [IPSET_CMD_SAVE]        = {
1967                .call           = ip_set_dump,
1968                .attr_count     = IPSET_ATTR_CMD_MAX,
1969                .policy         = ip_set_setname_policy,
1970        },
1971        [IPSET_CMD_ADD] = {
1972                .call           = ip_set_uadd,
1973                .attr_count     = IPSET_ATTR_CMD_MAX,
1974                .policy         = ip_set_adt_policy,
1975        },
1976        [IPSET_CMD_DEL] = {
1977                .call           = ip_set_udel,
1978                .attr_count     = IPSET_ATTR_CMD_MAX,
1979                .policy         = ip_set_adt_policy,
1980        },
1981        [IPSET_CMD_TEST]        = {
1982                .call           = ip_set_utest,
1983                .attr_count     = IPSET_ATTR_CMD_MAX,
1984                .policy         = ip_set_adt_policy,
1985        },
1986        [IPSET_CMD_HEADER]      = {
1987                .call           = ip_set_header,
1988                .attr_count     = IPSET_ATTR_CMD_MAX,
1989                .policy         = ip_set_setname_policy,
1990        },
1991        [IPSET_CMD_TYPE]        = {
1992                .call           = ip_set_type,
1993                .attr_count     = IPSET_ATTR_CMD_MAX,
1994                .policy         = ip_set_type_policy,
1995        },
1996        [IPSET_CMD_PROTOCOL]    = {
1997                .call           = ip_set_protocol,
1998                .attr_count     = IPSET_ATTR_CMD_MAX,
1999                .policy         = ip_set_protocol_policy,
2000        },
2001        [IPSET_CMD_GET_BYNAME]  = {
2002                .call           = ip_set_byname,
2003                .attr_count     = IPSET_ATTR_CMD_MAX,
2004                .policy         = ip_set_setname_policy,
2005        },
2006        [IPSET_CMD_GET_BYINDEX] = {
2007                .call           = ip_set_byindex,
2008                .attr_count     = IPSET_ATTR_CMD_MAX,
2009                .policy         = ip_set_index_policy,
2010        },
2011};
2012
2013static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
2014        .name           = "ip_set",
2015        .subsys_id      = NFNL_SUBSYS_IPSET,
2016        .cb_count       = IPSET_MSG_MAX,
2017        .cb             = ip_set_netlink_subsys_cb,
2018};
2019
2020/* Interface to iptables/ip6tables */
2021
2022static int
2023ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
2024{
2025        unsigned int *op;
2026        void *data;
2027        int copylen = *len, ret = 0;
2028        struct net *net = sock_net(sk);
2029        struct ip_set_net *inst = ip_set_pernet(net);
2030
2031        if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
2032                return -EPERM;
2033        if (optval != SO_IP_SET)
2034                return -EBADF;
2035        if (*len < sizeof(unsigned int))
2036                return -EINVAL;
2037
2038        data = vmalloc(*len);
2039        if (!data)
2040                return -ENOMEM;
2041        if (copy_from_user(data, user, *len) != 0) {
2042                ret = -EFAULT;
2043                goto done;
2044        }
2045        op = data;
2046
2047        if (*op < IP_SET_OP_VERSION) {
2048                /* Check the version at the beginning of operations */
2049                struct ip_set_req_version *req_version = data;
2050
2051                if (*len < sizeof(struct ip_set_req_version)) {
2052                        ret = -EINVAL;
2053                        goto done;
2054                }
2055
2056                if (req_version->version < IPSET_PROTOCOL_MIN) {
2057                        ret = -EPROTO;
2058                        goto done;
2059                }
2060        }
2061
2062        switch (*op) {
2063        case IP_SET_OP_VERSION: {
2064                struct ip_set_req_version *req_version = data;
2065
2066                if (*len != sizeof(struct ip_set_req_version)) {
2067                        ret = -EINVAL;
2068                        goto done;
2069                }
2070
2071                req_version->version = IPSET_PROTOCOL;
2072                ret = copy_to_user(user, req_version,
2073                                   sizeof(struct ip_set_req_version));
2074                goto done;
2075        }
2076        case IP_SET_OP_GET_BYNAME: {
2077                struct ip_set_req_get_set *req_get = data;
2078                ip_set_id_t id;
2079
2080                if (*len != sizeof(struct ip_set_req_get_set)) {
2081                        ret = -EINVAL;
2082                        goto done;
2083                }
2084                req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2085                nfnl_lock(NFNL_SUBSYS_IPSET);
2086                find_set_and_id(inst, req_get->set.name, &id);
2087                req_get->set.index = id;
2088                nfnl_unlock(NFNL_SUBSYS_IPSET);
2089                goto copy;
2090        }
2091        case IP_SET_OP_GET_FNAME: {
2092                struct ip_set_req_get_set_family *req_get = data;
2093                ip_set_id_t id;
2094
2095                if (*len != sizeof(struct ip_set_req_get_set_family)) {
2096                        ret = -EINVAL;
2097                        goto done;
2098                }
2099                req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2100                nfnl_lock(NFNL_SUBSYS_IPSET);
2101                find_set_and_id(inst, req_get->set.name, &id);
2102                req_get->set.index = id;
2103                if (id != IPSET_INVALID_ID)
2104                        req_get->family = ip_set(inst, id)->family;
2105                nfnl_unlock(NFNL_SUBSYS_IPSET);
2106                goto copy;
2107        }
2108        case IP_SET_OP_GET_BYINDEX: {
2109                struct ip_set_req_get_set *req_get = data;
2110                struct ip_set *set;
2111
2112                if (*len != sizeof(struct ip_set_req_get_set) ||
2113                    req_get->set.index >= inst->ip_set_max) {
2114                        ret = -EINVAL;
2115                        goto done;
2116                }
2117                nfnl_lock(NFNL_SUBSYS_IPSET);
2118                set = ip_set(inst, req_get->set.index);
2119                ret = strscpy(req_get->set.name, set ? set->name : "",
2120                              IPSET_MAXNAMELEN);
2121                nfnl_unlock(NFNL_SUBSYS_IPSET);
2122                if (ret < 0)
2123                        goto done;
2124                goto copy;
2125        }
2126        default:
2127                ret = -EBADMSG;
2128                goto done;
2129        }       /* end of switch(op) */
2130
2131copy:
2132        ret = copy_to_user(user, data, copylen);
2133
2134done:
2135        vfree(data);
2136        if (ret > 0)
2137                ret = 0;
2138        return ret;
2139}
2140
2141static struct nf_sockopt_ops so_set __read_mostly = {
2142        .pf             = PF_INET,
2143        .get_optmin     = SO_IP_SET,
2144        .get_optmax     = SO_IP_SET + 1,
2145        .get            = ip_set_sockfn_get,
2146        .owner          = THIS_MODULE,
2147};
2148
2149static int __net_init
2150ip_set_net_init(struct net *net)
2151{
2152        struct ip_set_net *inst = ip_set_pernet(net);
2153        struct ip_set **list;
2154
2155        inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX;
2156        if (inst->ip_set_max >= IPSET_INVALID_ID)
2157                inst->ip_set_max = IPSET_INVALID_ID - 1;
2158
2159        list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL);
2160        if (!list)
2161                return -ENOMEM;
2162        inst->is_deleted = false;
2163        inst->is_destroyed = false;
2164        rcu_assign_pointer(inst->ip_set_list, list);
2165        return 0;
2166}
2167
2168static void __net_exit
2169ip_set_net_exit(struct net *net)
2170{
2171        struct ip_set_net *inst = ip_set_pernet(net);
2172
2173        struct ip_set *set = NULL;
2174        ip_set_id_t i;
2175
2176        inst->is_deleted = true; /* flag for ip_set_nfnl_put */
2177
2178        nfnl_lock(NFNL_SUBSYS_IPSET);
2179        for (i = 0; i < inst->ip_set_max; i++) {
2180                set = ip_set(inst, i);
2181                if (set) {
2182                        ip_set(inst, i) = NULL;
2183                        ip_set_destroy_set(set);
2184                }
2185        }
2186        nfnl_unlock(NFNL_SUBSYS_IPSET);
2187        kvfree(rcu_dereference_protected(inst->ip_set_list, 1));
2188}
2189
2190static struct pernet_operations ip_set_net_ops = {
2191        .init   = ip_set_net_init,
2192        .exit   = ip_set_net_exit,
2193        .id     = &ip_set_net_id,
2194        .size   = sizeof(struct ip_set_net),
2195};
2196
2197static int __init
2198ip_set_init(void)
2199{
2200        int ret = register_pernet_subsys(&ip_set_net_ops);
2201
2202        if (ret) {
2203                pr_err("ip_set: cannot register pernet_subsys.\n");
2204                return ret;
2205        }
2206
2207        ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
2208        if (ret != 0) {
2209                pr_err("ip_set: cannot register with nfnetlink.\n");
2210                unregister_pernet_subsys(&ip_set_net_ops);
2211                return ret;
2212        }
2213
2214        ret = nf_register_sockopt(&so_set);
2215        if (ret != 0) {
2216                pr_err("SO_SET registry failed: %d\n", ret);
2217                nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2218                unregister_pernet_subsys(&ip_set_net_ops);
2219                return ret;
2220        }
2221
2222        return 0;
2223}
2224
2225static void __exit
2226ip_set_fini(void)
2227{
2228        nf_unregister_sockopt(&so_set);
2229        nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2230
2231        unregister_pernet_subsys(&ip_set_net_ops);
2232        pr_debug("these are the famous last words\n");
2233}
2234
2235module_init(ip_set_init);
2236module_exit(ip_set_fini);
2237
2238MODULE_DESCRIPTION("ip_set: protocol " __stringify(IPSET_PROTOCOL));
2239