linux/net/ipv4/inet_diag.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-or-later
   2/*
   3 * inet_diag.c  Module for monitoring INET transport protocols sockets.
   4 *
   5 * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
   6 */
   7
   8#include <linux/kernel.h>
   9#include <linux/module.h>
  10#include <linux/types.h>
  11#include <linux/fcntl.h>
  12#include <linux/random.h>
  13#include <linux/slab.h>
  14#include <linux/cache.h>
  15#include <linux/init.h>
  16#include <linux/time.h>
  17
  18#include <net/icmp.h>
  19#include <net/tcp.h>
  20#include <net/ipv6.h>
  21#include <net/inet_common.h>
  22#include <net/inet_connection_sock.h>
  23#include <net/inet_hashtables.h>
  24#include <net/inet_timewait_sock.h>
  25#include <net/inet6_hashtables.h>
  26#include <net/netlink.h>
  27
  28#include <linux/inet.h>
  29#include <linux/stddef.h>
  30
  31#include <linux/inet_diag.h>
  32#include <linux/sock_diag.h>
  33
/* Registered diag handlers, indexed by protocol number; a NULL slot means no
 * handler is registered for that protocol.
 */
static const struct inet_diag_handler **inet_diag_table;
  35
/* Per-socket values that the filter bytecode interpreter (inet_diag_bc_run())
 * compares against. Filled in by inet_diag_bc_sk().
 */
struct inet_diag_entry {
        const __be32 *saddr;    /* local address words (IPv4 or IPv6) */
        const __be32 *daddr;    /* remote address words (IPv4 or IPv6) */
        u16 sport;              /* local port */
        u16 dport;              /* remote port, host byte order */
        u16 family;             /* socket address family */
        u16 userlocks;          /* sk_userlocks, or 0 for non-full sockets */
        u32 ifindex;            /* bound device index */
        u32 mark;               /* socket / request mark */
};
  46
/* Taken by inet_diag_lock_handler() and released by
 * inet_diag_unlock_handler() around use of an inet_diag_table entry.
 */
static DEFINE_MUTEX(inet_diag_table_mutex);
  48
  49static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
  50{
  51        if (!inet_diag_table[proto])
  52                sock_load_diag_module(AF_INET, proto);
  53
  54        mutex_lock(&inet_diag_table_mutex);
  55        if (!inet_diag_table[proto])
  56                return ERR_PTR(-ENOENT);
  57
  58        return inet_diag_table[proto];
  59}
  60
/*
 * Release inet_diag_table_mutex taken by inet_diag_lock_handler().
 * @handler is not examined here, so an ERR_PTR() value may be passed.
 */
static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
{
        mutex_unlock(&inet_diag_table_mutex);
}
  65
  66void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
  67{
  68        r->idiag_family = sk->sk_family;
  69
  70        r->id.idiag_sport = htons(sk->sk_num);
  71        r->id.idiag_dport = sk->sk_dport;
  72        r->id.idiag_if = sk->sk_bound_dev_if;
  73        sock_diag_save_cookie(sk, r->id.idiag_cookie);
  74
  75#if IS_ENABLED(CONFIG_IPV6)
  76        if (sk->sk_family == AF_INET6) {
  77                *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
  78                *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
  79        } else
  80#endif
  81        {
  82        memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
  83        memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
  84
  85        r->id.idiag_src[0] = sk->sk_rcv_saddr;
  86        r->id.idiag_dst[0] = sk->sk_daddr;
  87        }
  88}
  89EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
  90
  91static size_t inet_sk_attr_size(struct sock *sk,
  92                                const struct inet_diag_req_v2 *req,
  93                                bool net_admin)
  94{
  95        const struct inet_diag_handler *handler;
  96        size_t aux = 0;
  97
  98        handler = inet_diag_table[req->sdiag_protocol];
  99        if (handler && handler->idiag_get_aux_size)
 100                aux = handler->idiag_get_aux_size(sk, net_admin);
 101
 102        return    nla_total_size(sizeof(struct tcp_info))
 103                + nla_total_size(1) /* INET_DIAG_SHUTDOWN */
 104                + nla_total_size(1) /* INET_DIAG_TOS */
 105                + nla_total_size(1) /* INET_DIAG_TCLASS */
 106                + nla_total_size(4) /* INET_DIAG_MARK */
 107                + nla_total_size(4) /* INET_DIAG_CLASS_ID */
 108                + nla_total_size(sizeof(struct inet_diag_meminfo))
 109                + nla_total_size(sizeof(struct inet_diag_msg))
 110                + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
 111                + nla_total_size(TCP_CA_NAME_MAX)
 112                + nla_total_size(sizeof(struct tcpvegas_info))
 113                + aux
 114                + 64;
 115}
 116
 117int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
 118                             struct inet_diag_msg *r, int ext,
 119                             struct user_namespace *user_ns,
 120                             bool net_admin)
 121{
 122        const struct inet_sock *inet = inet_sk(sk);
 123
 124        if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
 125                goto errout;
 126
 127        /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
 128         * hence this needs to be included regardless of socket family.
 129         */
 130        if (ext & (1 << (INET_DIAG_TOS - 1)))
 131                if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
 132                        goto errout;
 133
 134#if IS_ENABLED(CONFIG_IPV6)
 135        if (r->idiag_family == AF_INET6) {
 136                if (ext & (1 << (INET_DIAG_TCLASS - 1)))
 137                        if (nla_put_u8(skb, INET_DIAG_TCLASS,
 138                                       inet6_sk(sk)->tclass) < 0)
 139                                goto errout;
 140
 141                if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
 142                    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
 143                        goto errout;
 144        }
 145#endif
 146
 147        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
 148                goto errout;
 149
 150        r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
 151        r->idiag_inode = sock_i_ino(sk);
 152
 153        return 0;
 154errout:
 155        return 1;
 156}
 157EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
 158
 159int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
 160                      struct sk_buff *skb, const struct inet_diag_req_v2 *req,
 161                      struct user_namespace *user_ns,
 162                      u32 portid, u32 seq, u16 nlmsg_flags,
 163                      const struct nlmsghdr *unlh,
 164                      bool net_admin)
 165{
 166        const struct tcp_congestion_ops *ca_ops;
 167        const struct inet_diag_handler *handler;
 168        int ext = req->idiag_ext;
 169        struct inet_diag_msg *r;
 170        struct nlmsghdr  *nlh;
 171        struct nlattr *attr;
 172        void *info = NULL;
 173
 174        handler = inet_diag_table[req->sdiag_protocol];
 175        BUG_ON(!handler);
 176
 177        nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
 178                        nlmsg_flags);
 179        if (!nlh)
 180                return -EMSGSIZE;
 181
 182        r = nlmsg_data(nlh);
 183        BUG_ON(!sk_fullsock(sk));
 184
 185        inet_diag_msg_common_fill(r, sk);
 186        r->idiag_state = sk->sk_state;
 187        r->idiag_timer = 0;
 188        r->idiag_retrans = 0;
 189
 190        if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin))
 191                goto errout;
 192
 193        if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
 194                struct inet_diag_meminfo minfo = {
 195                        .idiag_rmem = sk_rmem_alloc_get(sk),
 196                        .idiag_wmem = sk->sk_wmem_queued,
 197                        .idiag_fmem = sk->sk_forward_alloc,
 198                        .idiag_tmem = sk_wmem_alloc_get(sk),
 199                };
 200
 201                if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
 202                        goto errout;
 203        }
 204
 205        if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
 206                if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
 207                        goto errout;
 208
 209        /*
 210         * RAW sockets might have user-defined protocols assigned,
 211         * so report the one supplied on socket creation.
 212         */
 213        if (sk->sk_type == SOCK_RAW) {
 214                if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
 215                        goto errout;
 216        }
 217
 218        if (!icsk) {
 219                handler->idiag_get_info(sk, r, NULL);
 220                goto out;
 221        }
 222
 223        if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
 224            icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
 225            icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
 226                r->idiag_timer = 1;
 227                r->idiag_retrans = icsk->icsk_retransmits;
 228                r->idiag_expires =
 229                        jiffies_to_msecs(icsk->icsk_timeout - jiffies);
 230        } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
 231                r->idiag_timer = 4;
 232                r->idiag_retrans = icsk->icsk_probes_out;
 233                r->idiag_expires =
 234                        jiffies_to_msecs(icsk->icsk_timeout - jiffies);
 235        } else if (timer_pending(&sk->sk_timer)) {
 236                r->idiag_timer = 2;
 237                r->idiag_retrans = icsk->icsk_probes_out;
 238                r->idiag_expires =
 239                        jiffies_to_msecs(sk->sk_timer.expires - jiffies);
 240        } else {
 241                r->idiag_timer = 0;
 242                r->idiag_expires = 0;
 243        }
 244
 245        if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
 246                attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
 247                                         handler->idiag_info_size,
 248                                         INET_DIAG_PAD);
 249                if (!attr)
 250                        goto errout;
 251
 252                info = nla_data(attr);
 253        }
 254
 255        if (ext & (1 << (INET_DIAG_CONG - 1))) {
 256                int err = 0;
 257
 258                rcu_read_lock();
 259                ca_ops = READ_ONCE(icsk->icsk_ca_ops);
 260                if (ca_ops)
 261                        err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
 262                rcu_read_unlock();
 263                if (err < 0)
 264                        goto errout;
 265        }
 266
 267        handler->idiag_get_info(sk, r, info);
 268
 269        if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
 270                if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
 271                        goto errout;
 272
 273        if (sk->sk_state < TCP_TIME_WAIT) {
 274                union tcp_cc_info info;
 275                size_t sz = 0;
 276                int attr;
 277
 278                rcu_read_lock();
 279                ca_ops = READ_ONCE(icsk->icsk_ca_ops);
 280                if (ca_ops && ca_ops->get_info)
 281                        sz = ca_ops->get_info(sk, ext, &attr, &info);
 282                rcu_read_unlock();
 283                if (sz && nla_put(skb, attr, sz, &info) < 0)
 284                        goto errout;
 285        }
 286
 287        if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
 288            ext & (1 << (INET_DIAG_TCLASS - 1))) {
 289                u32 classid = 0;
 290
 291#ifdef CONFIG_SOCK_CGROUP_DATA
 292                classid = sock_cgroup_classid(&sk->sk_cgrp_data);
 293#endif
 294                /* Fallback to socket priority if class id isn't set.
 295                 * Classful qdiscs use it as direct reference to class.
 296                 * For cgroup2 classid is always zero.
 297                 */
 298                if (!classid)
 299                        classid = sk->sk_priority;
 300
 301                if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
 302                        goto errout;
 303        }
 304
 305out:
 306        nlmsg_end(skb, nlh);
 307        return 0;
 308
 309errout:
 310        nlmsg_cancel(skb, nlh);
 311        return -EMSGSIZE;
 312}
 313EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
 314
 315static int inet_csk_diag_fill(struct sock *sk,
 316                              struct sk_buff *skb,
 317                              const struct inet_diag_req_v2 *req,
 318                              struct user_namespace *user_ns,
 319                              u32 portid, u32 seq, u16 nlmsg_flags,
 320                              const struct nlmsghdr *unlh,
 321                              bool net_admin)
 322{
 323        return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns,
 324                                 portid, seq, nlmsg_flags, unlh, net_admin);
 325}
 326
 327static int inet_twsk_diag_fill(struct sock *sk,
 328                               struct sk_buff *skb,
 329                               u32 portid, u32 seq, u16 nlmsg_flags,
 330                               const struct nlmsghdr *unlh)
 331{
 332        struct inet_timewait_sock *tw = inet_twsk(sk);
 333        struct inet_diag_msg *r;
 334        struct nlmsghdr *nlh;
 335        long tmo;
 336
 337        nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
 338                        nlmsg_flags);
 339        if (!nlh)
 340                return -EMSGSIZE;
 341
 342        r = nlmsg_data(nlh);
 343        BUG_ON(tw->tw_state != TCP_TIME_WAIT);
 344
 345        tmo = tw->tw_timer.expires - jiffies;
 346        if (tmo < 0)
 347                tmo = 0;
 348
 349        inet_diag_msg_common_fill(r, sk);
 350        r->idiag_retrans      = 0;
 351
 352        r->idiag_state        = tw->tw_substate;
 353        r->idiag_timer        = 3;
 354        r->idiag_expires      = jiffies_to_msecs(tmo);
 355        r->idiag_rqueue       = 0;
 356        r->idiag_wqueue       = 0;
 357        r->idiag_uid          = 0;
 358        r->idiag_inode        = 0;
 359
 360        nlmsg_end(skb, nlh);
 361        return 0;
 362}
 363
 364static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
 365                              u32 portid, u32 seq, u16 nlmsg_flags,
 366                              const struct nlmsghdr *unlh, bool net_admin)
 367{
 368        struct request_sock *reqsk = inet_reqsk(sk);
 369        struct inet_diag_msg *r;
 370        struct nlmsghdr *nlh;
 371        long tmo;
 372
 373        nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
 374                        nlmsg_flags);
 375        if (!nlh)
 376                return -EMSGSIZE;
 377
 378        r = nlmsg_data(nlh);
 379        inet_diag_msg_common_fill(r, sk);
 380        r->idiag_state = TCP_SYN_RECV;
 381        r->idiag_timer = 1;
 382        r->idiag_retrans = reqsk->num_retrans;
 383
 384        BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
 385                     offsetof(struct sock, sk_cookie));
 386
 387        tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
 388        r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
 389        r->idiag_rqueue = 0;
 390        r->idiag_wqueue = 0;
 391        r->idiag_uid    = 0;
 392        r->idiag_inode  = 0;
 393
 394        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
 395                                     inet_rsk(reqsk)->ir_mark))
 396                return -EMSGSIZE;
 397
 398        nlmsg_end(skb, nlh);
 399        return 0;
 400}
 401
 402static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 403                        const struct inet_diag_req_v2 *r,
 404                        struct user_namespace *user_ns,
 405                        u32 portid, u32 seq, u16 nlmsg_flags,
 406                        const struct nlmsghdr *unlh, bool net_admin)
 407{
 408        if (sk->sk_state == TCP_TIME_WAIT)
 409                return inet_twsk_diag_fill(sk, skb, portid, seq,
 410                                           nlmsg_flags, unlh);
 411
 412        if (sk->sk_state == TCP_NEW_SYN_RECV)
 413                return inet_req_diag_fill(sk, skb, portid, seq,
 414                                          nlmsg_flags, unlh, net_admin);
 415
 416        return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
 417                                  nlmsg_flags, unlh, net_admin);
 418}
 419
 420struct sock *inet_diag_find_one_icsk(struct net *net,
 421                                     struct inet_hashinfo *hashinfo,
 422                                     const struct inet_diag_req_v2 *req)
 423{
 424        struct sock *sk;
 425
 426        rcu_read_lock();
 427        if (req->sdiag_family == AF_INET)
 428                sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
 429                                 req->id.idiag_dport, req->id.idiag_src[0],
 430                                 req->id.idiag_sport, req->id.idiag_if);
 431#if IS_ENABLED(CONFIG_IPV6)
 432        else if (req->sdiag_family == AF_INET6) {
 433                if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
 434                    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
 435                        sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
 436                                         req->id.idiag_dport, req->id.idiag_src[3],
 437                                         req->id.idiag_sport, req->id.idiag_if);
 438                else
 439                        sk = inet6_lookup(net, hashinfo, NULL, 0,
 440                                          (struct in6_addr *)req->id.idiag_dst,
 441                                          req->id.idiag_dport,
 442                                          (struct in6_addr *)req->id.idiag_src,
 443                                          req->id.idiag_sport,
 444                                          req->id.idiag_if);
 445        }
 446#endif
 447        else {
 448                rcu_read_unlock();
 449                return ERR_PTR(-EINVAL);
 450        }
 451        rcu_read_unlock();
 452        if (!sk)
 453                return ERR_PTR(-ENOENT);
 454
 455        if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
 456                sock_gen_put(sk);
 457                return ERR_PTR(-ENOENT);
 458        }
 459
 460        return sk;
 461}
 462EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
 463
 464int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
 465                            struct sk_buff *in_skb,
 466                            const struct nlmsghdr *nlh,
 467                            const struct inet_diag_req_v2 *req)
 468{
 469        bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
 470        struct net *net = sock_net(in_skb->sk);
 471        struct sk_buff *rep;
 472        struct sock *sk;
 473        int err;
 474
 475        sk = inet_diag_find_one_icsk(net, hashinfo, req);
 476        if (IS_ERR(sk))
 477                return PTR_ERR(sk);
 478
 479        rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
 480        if (!rep) {
 481                err = -ENOMEM;
 482                goto out;
 483        }
 484
 485        err = sk_diag_fill(sk, rep, req,
 486                           sk_user_ns(NETLINK_CB(in_skb).sk),
 487                           NETLINK_CB(in_skb).portid,
 488                           nlh->nlmsg_seq, 0, nlh, net_admin);
 489        if (err < 0) {
 490                WARN_ON(err == -EMSGSIZE);
 491                nlmsg_free(rep);
 492                goto out;
 493        }
 494        err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
 495                              MSG_DONTWAIT);
 496        if (err > 0)
 497                err = 0;
 498
 499out:
 500        if (sk)
 501                sock_gen_put(sk);
 502
 503        return err;
 504}
 505EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
 506
 507static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
 508                               const struct nlmsghdr *nlh,
 509                               const struct inet_diag_req_v2 *req)
 510{
 511        const struct inet_diag_handler *handler;
 512        int err;
 513
 514        handler = inet_diag_lock_handler(req->sdiag_protocol);
 515        if (IS_ERR(handler))
 516                err = PTR_ERR(handler);
 517        else if (cmd == SOCK_DIAG_BY_FAMILY)
 518                err = handler->dump_one(in_skb, nlh, req);
 519        else if (cmd == SOCK_DESTROY && handler->destroy)
 520                err = handler->destroy(in_skb, req);
 521        else
 522                err = -EOPNOTSUPP;
 523        inet_diag_unlock_handler(handler);
 524
 525        return err;
 526}
 527
 528static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
 529{
 530        int words = bits >> 5;
 531
 532        bits &= 0x1f;
 533
 534        if (words) {
 535                if (memcmp(a1, a2, words << 2))
 536                        return 0;
 537        }
 538        if (bits) {
 539                __be32 w1, w2;
 540                __be32 mask;
 541
 542                w1 = a1[words];
 543                w2 = a2[words];
 544
 545                mask = htonl((0xffffffff) << (32 - bits));
 546
 547                if ((w1 ^ w2) & mask)
 548                        return 0;
 549        }
 550
 551        return 1;
 552}
 553
 554static int inet_diag_bc_run(const struct nlattr *_bc,
 555                            const struct inet_diag_entry *entry)
 556{
 557        const void *bc = nla_data(_bc);
 558        int len = nla_len(_bc);
 559
 560        while (len > 0) {
 561                int yes = 1;
 562                const struct inet_diag_bc_op *op = bc;
 563
 564                switch (op->code) {
 565                case INET_DIAG_BC_NOP:
 566                        break;
 567                case INET_DIAG_BC_JMP:
 568                        yes = 0;
 569                        break;
 570                case INET_DIAG_BC_S_EQ:
 571                        yes = entry->sport == op[1].no;
 572                        break;
 573                case INET_DIAG_BC_S_GE:
 574                        yes = entry->sport >= op[1].no;
 575                        break;
 576                case INET_DIAG_BC_S_LE:
 577                        yes = entry->sport <= op[1].no;
 578                        break;
 579                case INET_DIAG_BC_D_EQ:
 580                        yes = entry->dport == op[1].no;
 581                        break;
 582                case INET_DIAG_BC_D_GE:
 583                        yes = entry->dport >= op[1].no;
 584                        break;
 585                case INET_DIAG_BC_D_LE:
 586                        yes = entry->dport <= op[1].no;
 587                        break;
 588                case INET_DIAG_BC_AUTO:
 589                        yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
 590                        break;
 591                case INET_DIAG_BC_S_COND:
 592                case INET_DIAG_BC_D_COND: {
 593                        const struct inet_diag_hostcond *cond;
 594                        const __be32 *addr;
 595
 596                        cond = (const struct inet_diag_hostcond *)(op + 1);
 597                        if (cond->port != -1 &&
 598                            cond->port != (op->code == INET_DIAG_BC_S_COND ?
 599                                             entry->sport : entry->dport)) {
 600                                yes = 0;
 601                                break;
 602                        }
 603
 604                        if (op->code == INET_DIAG_BC_S_COND)
 605                                addr = entry->saddr;
 606                        else
 607                                addr = entry->daddr;
 608
 609                        if (cond->family != AF_UNSPEC &&
 610                            cond->family != entry->family) {
 611                                if (entry->family == AF_INET6 &&
 612                                    cond->family == AF_INET) {
 613                                        if (addr[0] == 0 && addr[1] == 0 &&
 614                                            addr[2] == htonl(0xffff) &&
 615                                            bitstring_match(addr + 3,
 616                                                            cond->addr,
 617                                                            cond->prefix_len))
 618                                                break;
 619                                }
 620                                yes = 0;
 621                                break;
 622                        }
 623
 624                        if (cond->prefix_len == 0)
 625                                break;
 626                        if (bitstring_match(addr, cond->addr,
 627                                            cond->prefix_len))
 628                                break;
 629                        yes = 0;
 630                        break;
 631                }
 632                case INET_DIAG_BC_DEV_COND: {
 633                        u32 ifindex;
 634
 635                        ifindex = *((const u32 *)(op + 1));
 636                        if (ifindex != entry->ifindex)
 637                                yes = 0;
 638                        break;
 639                }
 640                case INET_DIAG_BC_MARK_COND: {
 641                        struct inet_diag_markcond *cond;
 642
 643                        cond = (struct inet_diag_markcond *)(op + 1);
 644                        if ((entry->mark & cond->mask) != cond->mark)
 645                                yes = 0;
 646                        break;
 647                }
 648                }
 649
 650                if (yes) {
 651                        len -= op->yes;
 652                        bc += op->yes;
 653                } else {
 654                        len -= op->no;
 655                        bc += op->no;
 656                }
 657        }
 658        return len == 0;
 659}
 660
 661/* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
 662 */
 663static void entry_fill_addrs(struct inet_diag_entry *entry,
 664                             const struct sock *sk)
 665{
 666#if IS_ENABLED(CONFIG_IPV6)
 667        if (sk->sk_family == AF_INET6) {
 668                entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
 669                entry->daddr = sk->sk_v6_daddr.s6_addr32;
 670        } else
 671#endif
 672        {
 673                entry->saddr = &sk->sk_rcv_saddr;
 674                entry->daddr = &sk->sk_daddr;
 675        }
 676}
 677
 678int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
 679{
 680        struct inet_sock *inet = inet_sk(sk);
 681        struct inet_diag_entry entry;
 682
 683        if (!bc)
 684                return 1;
 685
 686        entry.family = sk->sk_family;
 687        entry_fill_addrs(&entry, sk);
 688        entry.sport = inet->inet_num;
 689        entry.dport = ntohs(inet->inet_dport);
 690        entry.ifindex = sk->sk_bound_dev_if;
 691        entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
 692        if (sk_fullsock(sk))
 693                entry.mark = sk->sk_mark;
 694        else if (sk->sk_state == TCP_NEW_SYN_RECV)
 695                entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
 696        else
 697                entry.mark = 0;
 698
 699        return inet_diag_bc_run(bc, &entry);
 700}
 701EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
 702
 703static int valid_cc(const void *bc, int len, int cc)
 704{
 705        while (len >= 0) {
 706                const struct inet_diag_bc_op *op = bc;
 707
 708                if (cc > len)
 709                        return 0;
 710                if (cc == len)
 711                        return 1;
 712                if (op->yes < 4 || op->yes & 3)
 713                        return 0;
 714                len -= op->yes;
 715                bc  += op->yes;
 716        }
 717        return 0;
 718}
 719
 720/* data is u32 ifindex */
 721static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
 722                          int *min_len)
 723{
 724        /* Check ifindex space. */
 725        *min_len += sizeof(u32);
 726        if (len < *min_len)
 727                return false;
 728
 729        return true;
 730}
 731/* Validate an inet_diag_hostcond. */
 732static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
 733                           int *min_len)
 734{
 735        struct inet_diag_hostcond *cond;
 736        int addr_len;
 737
 738        /* Check hostcond space. */
 739        *min_len += sizeof(struct inet_diag_hostcond);
 740        if (len < *min_len)
 741                return false;
 742        cond = (struct inet_diag_hostcond *)(op + 1);
 743
 744        /* Check address family and address length. */
 745        switch (cond->family) {
 746        case AF_UNSPEC:
 747                addr_len = 0;
 748                break;
 749        case AF_INET:
 750                addr_len = sizeof(struct in_addr);
 751                break;
 752        case AF_INET6:
 753                addr_len = sizeof(struct in6_addr);
 754                break;
 755        default:
 756                return false;
 757        }
 758        *min_len += addr_len;
 759        if (len < *min_len)
 760                return false;
 761
 762        /* Check prefix length (in bits) vs address length (in bytes). */
 763        if (cond->prefix_len > 8 * addr_len)
 764                return false;
 765
 766        return true;
 767}
 768
 769/* Validate a port comparison operator. */
 770static bool valid_port_comparison(const struct inet_diag_bc_op *op,
 771                                  int len, int *min_len)
 772{
 773        /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
 774        *min_len += sizeof(struct inet_diag_bc_op);
 775        if (len < *min_len)
 776                return false;
 777        return true;
 778}
 779
 780static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
 781                           int *min_len)
 782{
 783        *min_len += sizeof(struct inet_diag_markcond);
 784        return len >= *min_len;
 785}
 786
 787static int inet_diag_bc_audit(const struct nlattr *attr,
 788                              const struct sk_buff *skb)
 789{
 790        bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
 791        const void *bytecode, *bc;
 792        int bytecode_len, len;
 793
 794        if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
 795                return -EINVAL;
 796
 797        bytecode = bc = nla_data(attr);
 798        len = bytecode_len = nla_len(attr);
 799
 800        while (len > 0) {
 801                int min_len = sizeof(struct inet_diag_bc_op);
 802                const struct inet_diag_bc_op *op = bc;
 803
 804                switch (op->code) {
 805                case INET_DIAG_BC_S_COND:
 806                case INET_DIAG_BC_D_COND:
 807                        if (!valid_hostcond(bc, len, &min_len))
 808                                return -EINVAL;
 809                        break;
 810                case INET_DIAG_BC_DEV_COND:
 811                        if (!valid_devcond(bc, len, &min_len))
 812                                return -EINVAL;
 813                        break;
 814                case INET_DIAG_BC_S_EQ:
 815                case INET_DIAG_BC_S_GE:
 816                case INET_DIAG_BC_S_LE:
 817                case INET_DIAG_BC_D_EQ:
 818                case INET_DIAG_BC_D_GE:
 819                case INET_DIAG_BC_D_LE:
 820                        if (!valid_port_comparison(bc, len, &min_len))
 821                                return -EINVAL;
 822                        break;
 823                case INET_DIAG_BC_MARK_COND:
 824                        if (!net_admin)
 825                                return -EPERM;
 826                        if (!valid_markcond(bc, len, &min_len))
 827                                return -EINVAL;
 828                        break;
 829                case INET_DIAG_BC_AUTO:
 830                case INET_DIAG_BC_JMP:
 831                case INET_DIAG_BC_NOP:
 832                        break;
 833                default:
 834                        return -EINVAL;
 835                }
 836
 837                if (op->code != INET_DIAG_BC_NOP) {
 838                        if (op->no < min_len || op->no > len + 4 || op->no & 3)
 839                                return -EINVAL;
 840                        if (op->no < len &&
 841                            !valid_cc(bytecode, bytecode_len, len - op->no))
 842                                return -EINVAL;
 843                }
 844
 845                if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
 846                        return -EINVAL;
 847                bc  += op->yes;
 848                len -= op->yes;
 849        }
 850        return len == 0 ? 0 : -EINVAL;
 851}
 852
 853static int inet_csk_diag_dump(struct sock *sk,
 854                              struct sk_buff *skb,
 855                              struct netlink_callback *cb,
 856                              const struct inet_diag_req_v2 *r,
 857                              const struct nlattr *bc,
 858                              bool net_admin)
 859{
 860        if (!inet_diag_bc_sk(bc, sk))
 861                return 0;
 862
 863        return inet_csk_diag_fill(sk, skb, r,
 864                                  sk_user_ns(NETLINK_CB(cb->skb).sk),
 865                                  NETLINK_CB(cb->skb).portid,
 866                                  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
 867                                  net_admin);
 868}
 869
 870static void twsk_build_assert(void)
 871{
 872        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
 873                     offsetof(struct sock, sk_family));
 874
 875        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
 876                     offsetof(struct inet_sock, inet_num));
 877
 878        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
 879                     offsetof(struct inet_sock, inet_dport));
 880
 881        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
 882                     offsetof(struct inet_sock, inet_rcv_saddr));
 883
 884        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
 885                     offsetof(struct inet_sock, inet_daddr));
 886
 887#if IS_ENABLED(CONFIG_IPV6)
 888        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
 889                     offsetof(struct sock, sk_v6_rcv_saddr));
 890
 891        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
 892                     offsetof(struct sock, sk_v6_daddr));
 893#endif
 894}
 895
 896void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 897                         struct netlink_callback *cb,
 898                         const struct inet_diag_req_v2 *r, struct nlattr *bc)
 899{
 900        bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
 901        struct net *net = sock_net(skb->sk);
 902        u32 idiag_states = r->idiag_states;
 903        int i, num, s_i, s_num;
 904        struct sock *sk;
 905
 906        if (idiag_states & TCPF_SYN_RECV)
 907                idiag_states |= TCPF_NEW_SYN_RECV;
 908        s_i = cb->args[1];
 909        s_num = num = cb->args[2];
 910
 911        if (cb->args[0] == 0) {
 912                if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
 913                        goto skip_listen_ht;
 914
 915                for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
 916                        struct inet_listen_hashbucket *ilb;
 917
 918                        num = 0;
 919                        ilb = &hashinfo->listening_hash[i];
 920                        spin_lock(&ilb->lock);
 921                        sk_for_each(sk, &ilb->head) {
 922                                struct inet_sock *inet = inet_sk(sk);
 923
 924                                if (!net_eq(sock_net(sk), net))
 925                                        continue;
 926
 927                                if (num < s_num) {
 928                                        num++;
 929                                        continue;
 930                                }
 931
 932                                if (r->sdiag_family != AF_UNSPEC &&
 933                                    sk->sk_family != r->sdiag_family)
 934                                        goto next_listen;
 935
 936                                if (r->id.idiag_sport != inet->inet_sport &&
 937                                    r->id.idiag_sport)
 938                                        goto next_listen;
 939
 940                                if (inet_csk_diag_dump(sk, skb, cb, r,
 941                                                       bc, net_admin) < 0) {
 942                                        spin_unlock(&ilb->lock);
 943                                        goto done;
 944                                }
 945
 946next_listen:
 947                                ++num;
 948                        }
 949                        spin_unlock(&ilb->lock);
 950
 951                        s_num = 0;
 952                }
 953skip_listen_ht:
 954                cb->args[0] = 1;
 955                s_i = num = s_num = 0;
 956        }
 957
 958        if (!(idiag_states & ~TCPF_LISTEN))
 959                goto out;
 960
 961#define SKARR_SZ 16
 962        for (i = s_i; i <= hashinfo->ehash_mask; i++) {
 963                struct inet_ehash_bucket *head = &hashinfo->ehash[i];
 964                spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
 965                struct hlist_nulls_node *node;
 966                struct sock *sk_arr[SKARR_SZ];
 967                int num_arr[SKARR_SZ];
 968                int idx, accum, res;
 969
 970                if (hlist_nulls_empty(&head->chain))
 971                        continue;
 972
 973                if (i > s_i)
 974                        s_num = 0;
 975
 976next_chunk:
 977                num = 0;
 978                accum = 0;
 979                spin_lock_bh(lock);
 980                sk_nulls_for_each(sk, node, &head->chain) {
 981                        int state;
 982
 983                        if (!net_eq(sock_net(sk), net))
 984                                continue;
 985                        if (num < s_num)
 986                                goto next_normal;
 987                        state = (sk->sk_state == TCP_TIME_WAIT) ?
 988                                inet_twsk(sk)->tw_substate : sk->sk_state;
 989                        if (!(idiag_states & (1 << state)))
 990                                goto next_normal;
 991                        if (r->sdiag_family != AF_UNSPEC &&
 992                            sk->sk_family != r->sdiag_family)
 993                                goto next_normal;
 994                        if (r->id.idiag_sport != htons(sk->sk_num) &&
 995                            r->id.idiag_sport)
 996                                goto next_normal;
 997                        if (r->id.idiag_dport != sk->sk_dport &&
 998                            r->id.idiag_dport)
 999                                goto next_normal;
1000                        twsk_build_assert();
1001
1002                        if (!inet_diag_bc_sk(bc, sk))
1003                                goto next_normal;
1004
1005                        if (!refcount_inc_not_zero(&sk->sk_refcnt))
1006                                goto next_normal;
1007
1008                        num_arr[accum] = num;
1009                        sk_arr[accum] = sk;
1010                        if (++accum == SKARR_SZ)
1011                                break;
1012next_normal:
1013                        ++num;
1014                }
1015                spin_unlock_bh(lock);
1016                res = 0;
1017                for (idx = 0; idx < accum; idx++) {
1018                        if (res >= 0) {
1019                                res = sk_diag_fill(sk_arr[idx], skb, r,
1020                                           sk_user_ns(NETLINK_CB(cb->skb).sk),
1021                                           NETLINK_CB(cb->skb).portid,
1022                                           cb->nlh->nlmsg_seq, NLM_F_MULTI,
1023                                           cb->nlh, net_admin);
1024                                if (res < 0)
1025                                        num = num_arr[idx];
1026                        }
1027                        sock_gen_put(sk_arr[idx]);
1028                }
1029                if (res < 0)
1030                        break;
1031                cond_resched();
1032                if (accum == SKARR_SZ) {
1033                        s_num = num + 1;
1034                        goto next_chunk;
1035                }
1036        }
1037
1038done:
1039        cb->args[1] = i;
1040        cb->args[2] = num;
1041out:
1042        ;
1043}
1044EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1045
1046static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
1047                            const struct inet_diag_req_v2 *r,
1048                            struct nlattr *bc)
1049{
1050        const struct inet_diag_handler *handler;
1051        int err = 0;
1052
1053        handler = inet_diag_lock_handler(r->sdiag_protocol);
1054        if (!IS_ERR(handler))
1055                handler->dump(skb, cb, r, bc);
1056        else
1057                err = PTR_ERR(handler);
1058        inet_diag_unlock_handler(handler);
1059
1060        return err ? : skb->len;
1061}
1062
1063static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
1064{
1065        int hdrlen = sizeof(struct inet_diag_req_v2);
1066        struct nlattr *bc = NULL;
1067
1068        if (nlmsg_attrlen(cb->nlh, hdrlen))
1069                bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1070
1071        return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);
1072}
1073
/*
 * Translate a legacy netlink message type into an IP protocol number.
 * Any type other than TCPDIAG_GETSOCK or DCCPDIAG_GETSOCK maps to 0.
 */
static int inet_diag_type2proto(int type)
{
	if (type == TCPDIAG_GETSOCK)
		return IPPROTO_TCP;
	if (type == DCCPDIAG_GETSOCK)
		return IPPROTO_DCCP;
	return 0;
}
1085
1086static int inet_diag_dump_compat(struct sk_buff *skb,
1087                                 struct netlink_callback *cb)
1088{
1089        struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1090        int hdrlen = sizeof(struct inet_diag_req);
1091        struct inet_diag_req_v2 req;
1092        struct nlattr *bc = NULL;
1093
1094        req.sdiag_family = AF_UNSPEC; /* compatibility */
1095        req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1096        req.idiag_ext = rc->idiag_ext;
1097        req.idiag_states = rc->idiag_states;
1098        req.id = rc->id;
1099
1100        if (nlmsg_attrlen(cb->nlh, hdrlen))
1101                bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
1102
1103        return __inet_diag_dump(skb, cb, &req, bc);
1104}
1105
1106static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1107                                      const struct nlmsghdr *nlh)
1108{
1109        struct inet_diag_req *rc = nlmsg_data(nlh);
1110        struct inet_diag_req_v2 req;
1111
1112        req.sdiag_family = rc->idiag_family;
1113        req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1114        req.idiag_ext = rc->idiag_ext;
1115        req.idiag_states = rc->idiag_states;
1116        req.id = rc->id;
1117
1118        return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req);
1119}
1120
1121static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1122{
1123        int hdrlen = sizeof(struct inet_diag_req);
1124        struct net *net = sock_net(skb->sk);
1125
1126        if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1127            nlmsg_len(nlh) < hdrlen)
1128                return -EINVAL;
1129
1130        if (nlh->nlmsg_flags & NLM_F_DUMP) {
1131                if (nlmsg_attrlen(nlh, hdrlen)) {
1132                        struct nlattr *attr;
1133                        int err;
1134
1135                        attr = nlmsg_find_attr(nlh, hdrlen,
1136                                               INET_DIAG_REQ_BYTECODE);
1137                        err = inet_diag_bc_audit(attr, skb);
1138                        if (err)
1139                                return err;
1140                }
1141                {
1142                        struct netlink_dump_control c = {
1143                                .dump = inet_diag_dump_compat,
1144                        };
1145                        return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1146                }
1147        }
1148
1149        return inet_diag_get_exact_compat(skb, nlh);
1150}
1151
1152static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1153{
1154        int hdrlen = sizeof(struct inet_diag_req_v2);
1155        struct net *net = sock_net(skb->sk);
1156
1157        if (nlmsg_len(h) < hdrlen)
1158                return -EINVAL;
1159
1160        if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1161            h->nlmsg_flags & NLM_F_DUMP) {
1162                if (nlmsg_attrlen(h, hdrlen)) {
1163                        struct nlattr *attr;
1164                        int err;
1165
1166                        attr = nlmsg_find_attr(h, hdrlen,
1167                                               INET_DIAG_REQ_BYTECODE);
1168                        err = inet_diag_bc_audit(attr, skb);
1169                        if (err)
1170                                return err;
1171                }
1172                {
1173                        struct netlink_dump_control c = {
1174                                .dump = inet_diag_dump,
1175                        };
1176                        return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1177                }
1178        }
1179
1180        return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
1181}
1182
1183static
1184int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1185{
1186        const struct inet_diag_handler *handler;
1187        struct nlmsghdr *nlh;
1188        struct nlattr *attr;
1189        struct inet_diag_msg *r;
1190        void *info = NULL;
1191        int err = 0;
1192
1193        nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1194        if (!nlh)
1195                return -ENOMEM;
1196
1197        r = nlmsg_data(nlh);
1198        memset(r, 0, sizeof(*r));
1199        inet_diag_msg_common_fill(r, sk);
1200        if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1201                r->id.idiag_sport = inet_sk(sk)->inet_sport;
1202        r->idiag_state = sk->sk_state;
1203
1204        if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1205                nlmsg_cancel(skb, nlh);
1206                return err;
1207        }
1208
1209        handler = inet_diag_lock_handler(sk->sk_protocol);
1210        if (IS_ERR(handler)) {
1211                inet_diag_unlock_handler(handler);
1212                nlmsg_cancel(skb, nlh);
1213                return PTR_ERR(handler);
1214        }
1215
1216        attr = handler->idiag_info_size
1217                ? nla_reserve_64bit(skb, INET_DIAG_INFO,
1218                                    handler->idiag_info_size,
1219                                    INET_DIAG_PAD)
1220                : NULL;
1221        if (attr)
1222                info = nla_data(attr);
1223
1224        handler->idiag_get_info(sk, r, info);
1225        inet_diag_unlock_handler(handler);
1226
1227        nlmsg_end(skb, nlh);
1228        return 0;
1229}
1230
1231static const struct sock_diag_handler inet_diag_handler = {
1232        .family = AF_INET,
1233        .dump = inet_diag_handler_cmd,
1234        .get_info = inet_diag_handler_get_info,
1235        .destroy = inet_diag_handler_cmd,
1236};
1237
1238static const struct sock_diag_handler inet6_diag_handler = {
1239        .family = AF_INET6,
1240        .dump = inet_diag_handler_cmd,
1241        .get_info = inet_diag_handler_get_info,
1242        .destroy = inet_diag_handler_cmd,
1243};
1244
1245int inet_diag_register(const struct inet_diag_handler *h)
1246{
1247        const __u16 type = h->idiag_type;
1248        int err = -EINVAL;
1249
1250        if (type >= IPPROTO_MAX)
1251                goto out;
1252
1253        mutex_lock(&inet_diag_table_mutex);
1254        err = -EEXIST;
1255        if (!inet_diag_table[type]) {
1256                inet_diag_table[type] = h;
1257                err = 0;
1258        }
1259        mutex_unlock(&inet_diag_table_mutex);
1260out:
1261        return err;
1262}
1263EXPORT_SYMBOL_GPL(inet_diag_register);
1264
1265void inet_diag_unregister(const struct inet_diag_handler *h)
1266{
1267        const __u16 type = h->idiag_type;
1268
1269        if (type >= IPPROTO_MAX)
1270                return;
1271
1272        mutex_lock(&inet_diag_table_mutex);
1273        inet_diag_table[type] = NULL;
1274        mutex_unlock(&inet_diag_table_mutex);
1275}
1276EXPORT_SYMBOL_GPL(inet_diag_unregister);
1277
1278static int __init inet_diag_init(void)
1279{
1280        const int inet_diag_table_size = (IPPROTO_MAX *
1281                                          sizeof(struct inet_diag_handler *));
1282        int err = -ENOMEM;
1283
1284        inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1285        if (!inet_diag_table)
1286                goto out;
1287
1288        err = sock_diag_register(&inet_diag_handler);
1289        if (err)
1290                goto out_free_nl;
1291
1292        err = sock_diag_register(&inet6_diag_handler);
1293        if (err)
1294                goto out_free_inet;
1295
1296        sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1297out:
1298        return err;
1299
1300out_free_inet:
1301        sock_diag_unregister(&inet_diag_handler);
1302out_free_nl:
1303        kfree(inet_diag_table);
1304        goto out;
1305}
1306
1307static void __exit inet_diag_exit(void)
1308{
1309        sock_diag_unregister(&inet6_diag_handler);
1310        sock_diag_unregister(&inet_diag_handler);
1311        sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
1312        kfree(inet_diag_table);
1313}
1314
1315module_init(inet_diag_init);
1316module_exit(inet_diag_exit);
1317MODULE_LICENSE("GPL");
1318MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
1319MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1320