linux/net/ipv4/inet_diag.c
<<
>>
Prefs
   1/*
   2 * inet_diag.c  Module for monitoring INET transport protocols sockets.
   3 *
   4 * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
   5 *
   6 *      This program is free software; you can redistribute it and/or
   7 *      modify it under the terms of the GNU General Public License
   8 *      as published by the Free Software Foundation; either version
   9 *      2 of the License, or (at your option) any later version.
  10 */
  11
  12#include <linux/kernel.h>
  13#include <linux/module.h>
  14#include <linux/types.h>
  15#include <linux/fcntl.h>
  16#include <linux/random.h>
  17#include <linux/slab.h>
  18#include <linux/cache.h>
  19#include <linux/init.h>
  20#include <linux/time.h>
  21
  22#include <net/icmp.h>
  23#include <net/tcp.h>
  24#include <net/ipv6.h>
  25#include <net/inet_common.h>
  26#include <net/inet_connection_sock.h>
  27#include <net/inet_hashtables.h>
  28#include <net/inet_timewait_sock.h>
  29#include <net/inet6_hashtables.h>
  30#include <net/bpf_sk_storage.h>
  31#include <net/netlink.h>
  32
  33#include <linux/inet.h>
  34#include <linux/stddef.h>
  35
  36#include <linux/inet_diag.h>
  37#include <linux/sock_diag.h>
  38
/* Per-protocol diag handlers, indexed by protocol number.
 * inet_diag_lock_handler() reads entries after taking
 * inet_diag_table_mutex; inet_sk_attr_size() and inet_sk_diag_fill()
 * index it directly (presumably with the mutex already held by the
 * caller -- TODO confirm).
 */
static const struct inet_diag_handler **inet_diag_table;

/* Flattened per-socket view consumed by the filter bytecode interpreter
 * inet_diag_bc_run(); filled by inet_diag_bc_sk().
 */
struct inet_diag_entry {
	const __be32 *saddr;	/* local address words (sk_rcv_saddr or sk_v6_rcv_saddr) */
	const __be32 *daddr;	/* remote address words (sk_daddr or sk_v6_daddr) */
	u16 sport;		/* local port, host byte order (inet_num) */
	u16 dport;		/* remote port, converted with ntohs() */
	u16 family;		/* sk->sk_family */
	u16 userlocks;		/* sk_userlocks for full sockets, otherwise 0 */
	u32 ifindex;		/* sk->sk_bound_dev_if */
	u32 mark;		/* sk_mark, ir_mark for NEW_SYN_RECV, otherwise 0 */
};

/* Taken by inet_diag_lock_handler() and released by
 * inet_diag_unlock_handler() around use of an inet_diag_table entry.
 */
static DEFINE_MUTEX(inet_diag_table_mutex);
  53
  54static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
  55{
  56        if (proto < 0 || proto >= IPPROTO_MAX) {
  57                mutex_lock(&inet_diag_table_mutex);
  58                return ERR_PTR(-ENOENT);
  59        }
  60
  61        if (!inet_diag_table[proto])
  62                sock_load_diag_module(AF_INET, proto);
  63
  64        mutex_lock(&inet_diag_table_mutex);
  65        if (!inet_diag_table[proto])
  66                return ERR_PTR(-ENOENT);
  67
  68        return inet_diag_table[proto];
  69}
  70
  71static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
  72{
  73        mutex_unlock(&inet_diag_table_mutex);
  74}
  75
  76void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
  77{
  78        r->idiag_family = sk->sk_family;
  79
  80        r->id.idiag_sport = htons(sk->sk_num);
  81        r->id.idiag_dport = sk->sk_dport;
  82        r->id.idiag_if = sk->sk_bound_dev_if;
  83        sock_diag_save_cookie(sk, r->id.idiag_cookie);
  84
  85#if IS_ENABLED(CONFIG_IPV6)
  86        if (sk->sk_family == AF_INET6) {
  87                *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
  88                *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
  89        } else
  90#endif
  91        {
  92        memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
  93        memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
  94
  95        r->id.idiag_src[0] = sk->sk_rcv_saddr;
  96        r->id.idiag_dst[0] = sk->sk_daddr;
  97        }
  98}
  99EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
 100
 101static size_t inet_sk_attr_size(struct sock *sk,
 102                                const struct inet_diag_req_v2 *req,
 103                                bool net_admin)
 104{
 105        const struct inet_diag_handler *handler;
 106        size_t aux = 0;
 107
 108        handler = inet_diag_table[req->sdiag_protocol];
 109        if (handler && handler->idiag_get_aux_size)
 110                aux = handler->idiag_get_aux_size(sk, net_admin);
 111
 112        return    nla_total_size(sizeof(struct tcp_info))
 113                + nla_total_size(sizeof(struct inet_diag_msg))
 114                + inet_diag_msg_attrs_size()
 115                + nla_total_size(sizeof(struct inet_diag_meminfo))
 116                + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
 117                + nla_total_size(TCP_CA_NAME_MAX)
 118                + nla_total_size(sizeof(struct tcpvegas_info))
 119                + aux
 120                + 64;
 121}
 122
 123int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
 124                             struct inet_diag_msg *r, int ext,
 125                             struct user_namespace *user_ns,
 126                             bool net_admin)
 127{
 128        const struct inet_sock *inet = inet_sk(sk);
 129
 130        if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
 131                goto errout;
 132
 133        /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
 134         * hence this needs to be included regardless of socket family.
 135         */
 136        if (ext & (1 << (INET_DIAG_TOS - 1)))
 137                if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
 138                        goto errout;
 139
 140#if IS_ENABLED(CONFIG_IPV6)
 141        if (r->idiag_family == AF_INET6) {
 142                if (ext & (1 << (INET_DIAG_TCLASS - 1)))
 143                        if (nla_put_u8(skb, INET_DIAG_TCLASS,
 144                                       inet6_sk(sk)->tclass) < 0)
 145                                goto errout;
 146
 147                if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
 148                    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
 149                        goto errout;
 150        }
 151#endif
 152
 153        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
 154                goto errout;
 155
 156        if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
 157            ext & (1 << (INET_DIAG_TCLASS - 1))) {
 158                u32 classid = 0;
 159
 160#ifdef CONFIG_SOCK_CGROUP_DATA
 161                classid = sock_cgroup_classid(&sk->sk_cgrp_data);
 162#endif
 163                /* Fallback to socket priority if class id isn't set.
 164                 * Classful qdiscs use it as direct reference to class.
 165                 * For cgroup2 classid is always zero.
 166                 */
 167                if (!classid)
 168                        classid = sk->sk_priority;
 169
 170                if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
 171                        goto errout;
 172        }
 173
 174        r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
 175        r->idiag_inode = sock_i_ino(sk);
 176
 177        return 0;
 178errout:
 179        return 1;
 180}
 181EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
 182
 183static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
 184                                 struct nlattr **req_nlas)
 185{
 186        struct nlattr *nla;
 187        int remaining;
 188
 189        nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
 190                int type = nla_type(nla);
 191
 192                if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
 193                        return -EINVAL;
 194
 195                if (type < __INET_DIAG_REQ_MAX)
 196                        req_nlas[type] = nla;
 197        }
 198        return 0;
 199}
 200
 201static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
 202                                  const struct inet_diag_dump_data *data)
 203{
 204        int retval;
 205
 206        if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
 207                retval = nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
 208        else
 209                retval = req->sdiag_protocol;
 210        return retval == IPPROTO_MPTCP ? IPPROTO_MPTCP_KERN : retval;
 211}
 212
 213#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
 214
 215int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
 216                      struct sk_buff *skb, struct netlink_callback *cb,
 217                      const struct inet_diag_req_v2 *req,
 218                      u16 nlmsg_flags, bool net_admin)
 219{
 220        const struct tcp_congestion_ops *ca_ops;
 221        const struct inet_diag_handler *handler;
 222        struct inet_diag_dump_data *cb_data;
 223        int ext = req->idiag_ext;
 224        struct inet_diag_msg *r;
 225        struct nlmsghdr  *nlh;
 226        struct nlattr *attr;
 227        void *info = NULL;
 228
 229        cb_data = cb->data;
 230        handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
 231        BUG_ON(!handler);
 232
 233        nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
 234                        cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
 235        if (!nlh)
 236                return -EMSGSIZE;
 237
 238        r = nlmsg_data(nlh);
 239        BUG_ON(!sk_fullsock(sk));
 240
 241        inet_diag_msg_common_fill(r, sk);
 242        r->idiag_state = sk->sk_state;
 243        r->idiag_timer = 0;
 244        r->idiag_retrans = 0;
 245
 246        if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
 247                                     sk_user_ns(NETLINK_CB(cb->skb).sk),
 248                                     net_admin))
 249                goto errout;
 250
 251        if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
 252                struct inet_diag_meminfo minfo = {
 253                        .idiag_rmem = sk_rmem_alloc_get(sk),
 254                        .idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
 255                        .idiag_fmem = sk->sk_forward_alloc,
 256                        .idiag_tmem = sk_wmem_alloc_get(sk),
 257                };
 258
 259                if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
 260                        goto errout;
 261        }
 262
 263        if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
 264                if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
 265                        goto errout;
 266
 267        /*
 268         * RAW sockets might have user-defined protocols assigned,
 269         * so report the one supplied on socket creation.
 270         */
 271        if (sk->sk_type == SOCK_RAW) {
 272                if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
 273                        goto errout;
 274        }
 275
 276        if (!icsk) {
 277                handler->idiag_get_info(sk, r, NULL);
 278                goto out;
 279        }
 280
 281        if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
 282            icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
 283            icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
 284                r->idiag_timer = 1;
 285                r->idiag_retrans = icsk->icsk_retransmits;
 286                r->idiag_expires =
 287                        jiffies_to_msecs(icsk->icsk_timeout - jiffies);
 288        } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
 289                r->idiag_timer = 4;
 290                r->idiag_retrans = icsk->icsk_probes_out;
 291                r->idiag_expires =
 292                        jiffies_to_msecs(icsk->icsk_timeout - jiffies);
 293        } else if (timer_pending(&sk->sk_timer)) {
 294                r->idiag_timer = 2;
 295                r->idiag_retrans = icsk->icsk_probes_out;
 296                r->idiag_expires =
 297                        jiffies_to_msecs(sk->sk_timer.expires - jiffies);
 298        } else {
 299                r->idiag_timer = 0;
 300                r->idiag_expires = 0;
 301        }
 302
 303        if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
 304                attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
 305                                         handler->idiag_info_size,
 306                                         INET_DIAG_PAD);
 307                if (!attr)
 308                        goto errout;
 309
 310                info = nla_data(attr);
 311        }
 312
 313        if (ext & (1 << (INET_DIAG_CONG - 1))) {
 314                int err = 0;
 315
 316                rcu_read_lock();
 317                ca_ops = READ_ONCE(icsk->icsk_ca_ops);
 318                if (ca_ops)
 319                        err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
 320                rcu_read_unlock();
 321                if (err < 0)
 322                        goto errout;
 323        }
 324
 325        handler->idiag_get_info(sk, r, info);
 326
 327        if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
 328                if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
 329                        goto errout;
 330
 331        if (sk->sk_state < TCP_TIME_WAIT) {
 332                union tcp_cc_info info;
 333                size_t sz = 0;
 334                int attr;
 335
 336                rcu_read_lock();
 337                ca_ops = READ_ONCE(icsk->icsk_ca_ops);
 338                if (ca_ops && ca_ops->get_info)
 339                        sz = ca_ops->get_info(sk, ext, &attr, &info);
 340                rcu_read_unlock();
 341                if (sz && nla_put(skb, attr, sz, &info) < 0)
 342                        goto errout;
 343        }
 344
 345        /* Keep it at the end for potential retry with a larger skb,
 346         * or else do best-effort fitting, which is only done for the
 347         * first_nlmsg.
 348         */
 349        if (cb_data->bpf_stg_diag) {
 350                bool first_nlmsg = ((unsigned char *)nlh == skb->data);
 351                unsigned int prev_min_dump_alloc;
 352                unsigned int total_nla_size = 0;
 353                unsigned int msg_len;
 354                int err;
 355
 356                msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
 357                err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
 358                                              INET_DIAG_SK_BPF_STORAGES,
 359                                              &total_nla_size);
 360
 361                if (!err)
 362                        goto out;
 363
 364                total_nla_size += msg_len;
 365                prev_min_dump_alloc = cb->min_dump_alloc;
 366                if (total_nla_size > prev_min_dump_alloc)
 367                        cb->min_dump_alloc = min_t(u32, total_nla_size,
 368                                                   MAX_DUMP_ALLOC_SIZE);
 369
 370                if (!first_nlmsg)
 371                        goto errout;
 372
 373                if (cb->min_dump_alloc > prev_min_dump_alloc)
 374                        /* Retry with pskb_expand_head() with
 375                         * __GFP_DIRECT_RECLAIM
 376                         */
 377                        goto errout;
 378
 379                WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
 380
 381                /* Send what we have for this sk
 382                 * and move on to the next sk in the following
 383                 * dump()
 384                 */
 385        }
 386
 387out:
 388        nlmsg_end(skb, nlh);
 389        return 0;
 390
 391errout:
 392        nlmsg_cancel(skb, nlh);
 393        return -EMSGSIZE;
 394}
 395EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
 396
 397static int inet_twsk_diag_fill(struct sock *sk,
 398                               struct sk_buff *skb,
 399                               struct netlink_callback *cb,
 400                               u16 nlmsg_flags)
 401{
 402        struct inet_timewait_sock *tw = inet_twsk(sk);
 403        struct inet_diag_msg *r;
 404        struct nlmsghdr *nlh;
 405        long tmo;
 406
 407        nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
 408                        cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
 409                        sizeof(*r), nlmsg_flags);
 410        if (!nlh)
 411                return -EMSGSIZE;
 412
 413        r = nlmsg_data(nlh);
 414        BUG_ON(tw->tw_state != TCP_TIME_WAIT);
 415
 416        tmo = tw->tw_timer.expires - jiffies;
 417        if (tmo < 0)
 418                tmo = 0;
 419
 420        inet_diag_msg_common_fill(r, sk);
 421        r->idiag_retrans      = 0;
 422
 423        r->idiag_state        = tw->tw_substate;
 424        r->idiag_timer        = 3;
 425        r->idiag_expires      = jiffies_to_msecs(tmo);
 426        r->idiag_rqueue       = 0;
 427        r->idiag_wqueue       = 0;
 428        r->idiag_uid          = 0;
 429        r->idiag_inode        = 0;
 430
 431        nlmsg_end(skb, nlh);
 432        return 0;
 433}
 434
 435static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
 436                              struct netlink_callback *cb,
 437                              u16 nlmsg_flags, bool net_admin)
 438{
 439        struct request_sock *reqsk = inet_reqsk(sk);
 440        struct inet_diag_msg *r;
 441        struct nlmsghdr *nlh;
 442        long tmo;
 443
 444        nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
 445                        cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
 446        if (!nlh)
 447                return -EMSGSIZE;
 448
 449        r = nlmsg_data(nlh);
 450        inet_diag_msg_common_fill(r, sk);
 451        r->idiag_state = TCP_SYN_RECV;
 452        r->idiag_timer = 1;
 453        r->idiag_retrans = reqsk->num_retrans;
 454
 455        BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
 456                     offsetof(struct sock, sk_cookie));
 457
 458        tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
 459        r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
 460        r->idiag_rqueue = 0;
 461        r->idiag_wqueue = 0;
 462        r->idiag_uid    = 0;
 463        r->idiag_inode  = 0;
 464
 465        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
 466                                     inet_rsk(reqsk)->ir_mark)) {
 467                nlmsg_cancel(skb, nlh);
 468                return -EMSGSIZE;
 469        }
 470
 471        nlmsg_end(skb, nlh);
 472        return 0;
 473}
 474
 475static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 476                        struct netlink_callback *cb,
 477                        const struct inet_diag_req_v2 *r,
 478                        u16 nlmsg_flags, bool net_admin)
 479{
 480        if (sk->sk_state == TCP_TIME_WAIT)
 481                return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);
 482
 483        if (sk->sk_state == TCP_NEW_SYN_RECV)
 484                return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
 485
 486        return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
 487                                 net_admin);
 488}
 489
 490struct sock *inet_diag_find_one_icsk(struct net *net,
 491                                     struct inet_hashinfo *hashinfo,
 492                                     const struct inet_diag_req_v2 *req)
 493{
 494        struct sock *sk;
 495
 496        rcu_read_lock();
 497        if (req->sdiag_family == AF_INET)
 498                sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
 499                                 req->id.idiag_dport, req->id.idiag_src[0],
 500                                 req->id.idiag_sport, req->id.idiag_if);
 501#if IS_ENABLED(CONFIG_IPV6)
 502        else if (req->sdiag_family == AF_INET6) {
 503                if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
 504                    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
 505                        sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
 506                                         req->id.idiag_dport, req->id.idiag_src[3],
 507                                         req->id.idiag_sport, req->id.idiag_if);
 508                else
 509                        sk = inet6_lookup(net, hashinfo, NULL, 0,
 510                                          (struct in6_addr *)req->id.idiag_dst,
 511                                          req->id.idiag_dport,
 512                                          (struct in6_addr *)req->id.idiag_src,
 513                                          req->id.idiag_sport,
 514                                          req->id.idiag_if);
 515        }
 516#endif
 517        else {
 518                rcu_read_unlock();
 519                return ERR_PTR(-EINVAL);
 520        }
 521        rcu_read_unlock();
 522        if (!sk)
 523                return ERR_PTR(-ENOENT);
 524
 525        if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
 526                sock_gen_put(sk);
 527                return ERR_PTR(-ENOENT);
 528        }
 529
 530        return sk;
 531}
 532EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
 533
 534int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
 535                            struct netlink_callback *cb,
 536                            const struct inet_diag_req_v2 *req)
 537{
 538        struct sk_buff *in_skb = cb->skb;
 539        bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
 540        struct net *net = sock_net(in_skb->sk);
 541        struct sk_buff *rep;
 542        struct sock *sk;
 543        int err;
 544
 545        sk = inet_diag_find_one_icsk(net, hashinfo, req);
 546        if (IS_ERR(sk))
 547                return PTR_ERR(sk);
 548
 549        rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
 550        if (!rep) {
 551                err = -ENOMEM;
 552                goto out;
 553        }
 554
 555        err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
 556        if (err < 0) {
 557                WARN_ON(err == -EMSGSIZE);
 558                nlmsg_free(rep);
 559                goto out;
 560        }
 561        err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid,
 562                              MSG_DONTWAIT);
 563        if (err > 0)
 564                err = 0;
 565
 566out:
 567        if (sk)
 568                sock_gen_put(sk);
 569
 570        return err;
 571}
 572EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
 573
 574static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
 575                               const struct nlmsghdr *nlh,
 576                               int hdrlen,
 577                               const struct inet_diag_req_v2 *req)
 578{
 579        const struct inet_diag_handler *handler;
 580        struct inet_diag_dump_data dump_data;
 581        int err, protocol;
 582
 583        memset(&dump_data, 0, sizeof(dump_data));
 584        err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
 585        if (err)
 586                return err;
 587
 588        protocol = inet_diag_get_protocol(req, &dump_data);
 589
 590        handler = inet_diag_lock_handler(protocol);
 591        if (IS_ERR(handler)) {
 592                err = PTR_ERR(handler);
 593        } else if (cmd == SOCK_DIAG_BY_FAMILY) {
 594                struct netlink_callback cb = {
 595                        .nlh = nlh,
 596                        .skb = in_skb,
 597                        .data = &dump_data,
 598                };
 599                err = handler->dump_one(&cb, req);
 600        } else if (cmd == SOCK_DESTROY && handler->destroy) {
 601                err = handler->destroy(in_skb, req);
 602        } else {
 603                err = -EOPNOTSUPP;
 604        }
 605        inet_diag_unlock_handler(handler);
 606
 607        return err;
 608}
 609
 610static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
 611{
 612        int words = bits >> 5;
 613
 614        bits &= 0x1f;
 615
 616        if (words) {
 617                if (memcmp(a1, a2, words << 2))
 618                        return 0;
 619        }
 620        if (bits) {
 621                __be32 w1, w2;
 622                __be32 mask;
 623
 624                w1 = a1[words];
 625                w2 = a2[words];
 626
 627                mask = htonl((0xffffffff) << (32 - bits));
 628
 629                if ((w1 ^ w2) & mask)
 630                        return 0;
 631        }
 632
 633        return 1;
 634}
 635
/*
 * Evaluate the filter bytecode carried in @_bc against @entry.
 *
 * The program is a packed sequence of struct inet_diag_bc_op.  Each op
 * tests one condition.  It then advances op->yes bytes when the condition
 * holds, or op->no bytes when it does not.  Operands follow the op header
 * in memory:
 *  - port comparisons: a second inet_diag_bc_op whose ->no is the port;
 *  - S_COND/D_COND: a struct inet_diag_hostcond;
 *  - DEV_COND: a u32 ifindex;
 *  - MARK_COND: a struct inet_diag_markcond.
 * Op codes not listed in the switch leave yes == 1.
 *
 * NOTE(review): nothing here bounds-checks op->yes/op->no or the operand
 * sizes.  This presumably relies on the program having been accepted by
 * inet_diag_bc_audit() beforehand -- confirm at the call sites.
 *
 * Return: 1 if execution lands exactly on the end of the program (the entry
 * matches), 0 if it steps past the end.
 */
static int inet_diag_bc_run(const struct nlattr *_bc,
			    const struct inet_diag_entry *entry)
{
	const void *bc = nla_data(_bc);
	int len = nla_len(_bc);

	while (len > 0) {
		int yes = 1;
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_NOP:
			break;
		case INET_DIAG_BC_JMP:
			/* JMP always follows the ->no offset. */
			yes = 0;
			break;
		case INET_DIAG_BC_S_EQ:
			yes = entry->sport == op[1].no;
			break;
		case INET_DIAG_BC_S_GE:
			yes = entry->sport >= op[1].no;
			break;
		case INET_DIAG_BC_S_LE:
			yes = entry->sport <= op[1].no;
			break;
		case INET_DIAG_BC_D_EQ:
			yes = entry->dport == op[1].no;
			break;
		case INET_DIAG_BC_D_GE:
			yes = entry->dport >= op[1].no;
			break;
		case INET_DIAG_BC_D_LE:
			yes = entry->dport <= op[1].no;
			break;
		case INET_DIAG_BC_AUTO:
			/* True when the local port was not explicitly bound. */
			yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
			break;
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND: {
			const struct inet_diag_hostcond *cond;
			const __be32 *addr;

			cond = (const struct inet_diag_hostcond *)(op + 1);
			/* A port of -1 acts as a wildcard. */
			if (cond->port != -1 &&
			    cond->port != (op->code == INET_DIAG_BC_S_COND ?
					     entry->sport : entry->dport)) {
				yes = 0;
				break;
			}

			if (op->code == INET_DIAG_BC_S_COND)
				addr = entry->saddr;
			else
				addr = entry->daddr;

			if (cond->family != AF_UNSPEC &&
			    cond->family != entry->family) {
				/* An AF_INET condition can still match an
				 * IPv6 socket whose address is IPv4-mapped
				 * (::ffff:a.b.c.d): compare against word 3.
				 */
				if (entry->family == AF_INET6 &&
				    cond->family == AF_INET) {
					if (addr[0] == 0 && addr[1] == 0 &&
					    addr[2] == htonl(0xffff) &&
					    bitstring_match(addr + 3,
							    cond->addr,
							    cond->prefix_len))
						break;
				}
				yes = 0;
				break;
			}

			/* A zero-length prefix matches any address. */
			if (cond->prefix_len == 0)
				break;
			if (bitstring_match(addr, cond->addr,
					    cond->prefix_len))
				break;
			yes = 0;
			break;
		}
		case INET_DIAG_BC_DEV_COND: {
			u32 ifindex;

			ifindex = *((const u32 *)(op + 1));
			if (ifindex != entry->ifindex)
				yes = 0;
			break;
		}
		case INET_DIAG_BC_MARK_COND: {
			struct inet_diag_markcond *cond;

			cond = (struct inet_diag_markcond *)(op + 1);
			if ((entry->mark & cond->mask) != cond->mark)
				yes = 0;
			break;
		}
		}

		/* Take the branch selected above. */
		if (yes) {
			len -= op->yes;
			bc += op->yes;
		} else {
			len -= op->no;
			bc += op->no;
		}
	}
	return len == 0;
}
 742
 743/* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
 744 */
 745static void entry_fill_addrs(struct inet_diag_entry *entry,
 746                             const struct sock *sk)
 747{
 748#if IS_ENABLED(CONFIG_IPV6)
 749        if (sk->sk_family == AF_INET6) {
 750                entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
 751                entry->daddr = sk->sk_v6_daddr.s6_addr32;
 752        } else
 753#endif
 754        {
 755                entry->saddr = &sk->sk_rcv_saddr;
 756                entry->daddr = &sk->sk_daddr;
 757        }
 758}
 759
 760int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
 761{
 762        struct inet_sock *inet = inet_sk(sk);
 763        struct inet_diag_entry entry;
 764
 765        if (!bc)
 766                return 1;
 767
 768        entry.family = sk->sk_family;
 769        entry_fill_addrs(&entry, sk);
 770        entry.sport = inet->inet_num;
 771        entry.dport = ntohs(inet->inet_dport);
 772        entry.ifindex = sk->sk_bound_dev_if;
 773        entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
 774        if (sk_fullsock(sk))
 775                entry.mark = sk->sk_mark;
 776        else if (sk->sk_state == TCP_NEW_SYN_RECV)
 777                entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
 778        else
 779                entry.mark = 0;
 780
 781        return inet_diag_bc_run(bc, &entry);
 782}
 783EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
 784
 785static int valid_cc(const void *bc, int len, int cc)
 786{
 787        while (len >= 0) {
 788                const struct inet_diag_bc_op *op = bc;
 789
 790                if (cc > len)
 791                        return 0;
 792                if (cc == len)
 793                        return 1;
 794                if (op->yes < 4 || op->yes & 3)
 795                        return 0;
 796                len -= op->yes;
 797                bc  += op->yes;
 798        }
 799        return 0;
 800}
 801
 802/* data is u32 ifindex */
 803static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
 804                          int *min_len)
 805{
 806        /* Check ifindex space. */
 807        *min_len += sizeof(u32);
 808        if (len < *min_len)
 809                return false;
 810
 811        return true;
 812}
 813/* Validate an inet_diag_hostcond. */
 814static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
 815                           int *min_len)
 816{
 817        struct inet_diag_hostcond *cond;
 818        int addr_len;
 819
 820        /* Check hostcond space. */
 821        *min_len += sizeof(struct inet_diag_hostcond);
 822        if (len < *min_len)
 823                return false;
 824        cond = (struct inet_diag_hostcond *)(op + 1);
 825
 826        /* Check address family and address length. */
 827        switch (cond->family) {
 828        case AF_UNSPEC:
 829                addr_len = 0;
 830                break;
 831        case AF_INET:
 832                addr_len = sizeof(struct in_addr);
 833                break;
 834        case AF_INET6:
 835                addr_len = sizeof(struct in6_addr);
 836                break;
 837        default:
 838                return false;
 839        }
 840        *min_len += addr_len;
 841        if (len < *min_len)
 842                return false;
 843
 844        /* Check prefix length (in bits) vs address length (in bytes). */
 845        if (cond->prefix_len > 8 * addr_len)
 846                return false;
 847
 848        return true;
 849}
 850
 851/* Validate a port comparison operator. */
 852static bool valid_port_comparison(const struct inet_diag_bc_op *op,
 853                                  int len, int *min_len)
 854{
 855        /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
 856        *min_len += sizeof(struct inet_diag_bc_op);
 857        if (len < *min_len)
 858                return false;
 859        return true;
 860}
 861
 862static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
 863                           int *min_len)
 864{
 865        *min_len += sizeof(struct inet_diag_markcond);
 866        return len >= *min_len;
 867}
 868
/*
 * Validate the filter bytecode carried in @attr.
 *
 * The program is a sequence of struct inet_diag_bc_op, each possibly
 * followed by operand data. op->yes and op->no are byte offsets, measured
 * from the start of the op, to the next instruction on each branch.
 * @skb is the request, used only for the CAP_NET_ADMIN check.
 *
 * Returns 0 if the program is well formed, -EPERM if it uses a mark
 * condition without CAP_NET_ADMIN, and -EINVAL for any other defect.
 */
static int inet_diag_bc_audit(const struct nlattr *attr,
			      const struct sk_buff *skb)
{
	bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
	const void *bytecode, *bc;
	int bytecode_len, len;

	/* The program must hold at least one instruction header. */
	if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
		return -EINVAL;

	bytecode = bc = nla_data(attr);
	len = bytecode_len = nla_len(attr);

	while (len > 0) {
		/*
		 * Bytes this instruction occupies (header plus operands);
		 * the per-opcode validators below grow it.
		 */
		int min_len = sizeof(struct inet_diag_bc_op);
		const struct inet_diag_bc_op *op = bc;

		switch (op->code) {
		case INET_DIAG_BC_S_COND:
		case INET_DIAG_BC_D_COND:
			if (!valid_hostcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_DEV_COND:
			if (!valid_devcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_S_EQ:
		case INET_DIAG_BC_S_GE:
		case INET_DIAG_BC_S_LE:
		case INET_DIAG_BC_D_EQ:
		case INET_DIAG_BC_D_GE:
		case INET_DIAG_BC_D_LE:
			if (!valid_port_comparison(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_MARK_COND:
			/* Filtering on socket marks requires CAP_NET_ADMIN. */
			if (!net_admin)
				return -EPERM;
			if (!valid_markcond(bc, len, &min_len))
				return -EINVAL;
			break;
		case INET_DIAG_BC_AUTO:
		case INET_DIAG_BC_JMP:
		case INET_DIAG_BC_NOP:
			break;
		default:
			/* Unknown opcodes are rejected. */
			return -EINVAL;
		}

		/*
		 * For every op except NOP, the 'no' offset must satisfy three
		 * conditions. It must be 4-byte aligned, it must skip at least
		 * this op and its operands, and it must land no more than 4
		 * bytes past the end. If it lands inside the program, it must
		 * also hit an instruction boundary on the 'yes' chain
		 * (valid_cc()).
		 */
		if (op->code != INET_DIAG_BC_NOP) {
			if (op->no < min_len || op->no > len + 4 || op->no & 3)
				return -EINVAL;
			if (op->no < len &&
			    !valid_cc(bytecode, bytecode_len, len - op->no))
				return -EINVAL;
		}

		/* The same bounds apply to 'yes'; then advance along it. */
		if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
			return -EINVAL;
		bc  += op->yes;
		len -= op->yes;
	}
	/* The 'yes' chain must end exactly at the end of the program. */
	return len == 0 ? 0 : -EINVAL;
}
 934
 935static void twsk_build_assert(void)
 936{
 937        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
 938                     offsetof(struct sock, sk_family));
 939
 940        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
 941                     offsetof(struct inet_sock, inet_num));
 942
 943        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
 944                     offsetof(struct inet_sock, inet_dport));
 945
 946        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
 947                     offsetof(struct inet_sock, inet_rcv_saddr));
 948
 949        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
 950                     offsetof(struct inet_sock, inet_daddr));
 951
 952#if IS_ENABLED(CONFIG_IPV6)
 953        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
 954                     offsetof(struct sock, sk_v6_rcv_saddr));
 955
 956        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
 957                     offsetof(struct sock, sk_v6_daddr));
 958#endif
 959}
 960
/*
 * Dump the sockets in @hashinfo that match request @r, and the optional
 * bytecode filter stored in cb->data, into @skb as NLM_F_MULTI messages.
 *
 * A dump may span several calls, so progress is saved in @cb:
 *   cb->args[0]: 0 while the listening hash still needs walking, 1 after;
 *   cb->args[1]: hash bucket index to resume from;
 *   cb->args[2]: number of sockets from this netns already visited in
 *                that bucket.
 */
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
			 struct netlink_callback *cb,
			 const struct inet_diag_req_v2 *r)
{
	bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
	struct inet_diag_dump_data *cb_data = cb->data;
	struct net *net = sock_net(skb->sk);
	u32 idiag_states = r->idiag_states;
	int i, num, s_i, s_num;
	struct nlattr *bc;
	struct sock *sk;

	bc = cb_data->inet_diag_nla_bc;
	/* A SYN_RECV request also matches TCP_NEW_SYN_RECV entries. */
	if (idiag_states & TCPF_SYN_RECV)
		idiag_states |= TCPF_NEW_SYN_RECV;
	s_i = cb->args[1];
	s_num = num = cb->args[2];

	/* Phase 1: listening sockets. */
	if (cb->args[0] == 0) {
		/*
		 * Skip listeners unless LISTEN was requested. Also skip them
		 * when a destination-port filter is set (presumably because a
		 * listener has no peer port to match).
		 */
		if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
			goto skip_listen_ht;

		for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
			struct inet_listen_hashbucket *ilb;

			num = 0;
			ilb = &hashinfo->listening_hash[i];
			spin_lock(&ilb->lock);
			sk_for_each(sk, &ilb->head) {
				struct inet_sock *inet = inet_sk(sk);

				/* Other namespaces are neither counted nor dumped. */
				if (!net_eq(sock_net(sk), net))
					continue;

				/* Skip entries already emitted by an earlier call. */
				if (num < s_num) {
					num++;
					continue;
				}

				if (r->sdiag_family != AF_UNSPEC &&
				    sk->sk_family != r->sdiag_family)
					goto next_listen;

				if (r->id.idiag_sport != inet->inet_sport &&
				    r->id.idiag_sport)
					goto next_listen;

				if (!inet_diag_bc_sk(bc, sk))
					goto next_listen;

				/*
				 * @skb is full: stop here without advancing
				 * num, so the next call resumes at this socket.
				 */
				if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
						      cb, r, NLM_F_MULTI,
						      net_admin) < 0) {
					spin_unlock(&ilb->lock);
					goto done;
				}

next_listen:
				++num;
			}
			spin_unlock(&ilb->lock);

			s_num = 0;
		}
skip_listen_ht:
		cb->args[0] = 1;
		s_i = num = s_num = 0;
	}

	/* Nothing but LISTEN requested: the established hash is not walked. */
	if (!(idiag_states & ~TCPF_LISTEN))
		goto out;

	/*
	 * Phase 2: established hash. Matching sockets are collected in
	 * batches of up to SKARR_SZ under the bucket lock, each with a
	 * reference held. They are filled into @skb after the lock is dropped.
	 */
#define SKARR_SZ 16
	for (i = s_i; i <= hashinfo->ehash_mask; i++) {
		struct inet_ehash_bucket *head = &hashinfo->ehash[i];
		spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
		struct hlist_nulls_node *node;
		struct sock *sk_arr[SKARR_SZ];
		int num_arr[SKARR_SZ];
		int idx, accum, res;

		if (hlist_nulls_empty(&head->chain))
			continue;

		/* The saved position only applies to the resumed bucket. */
		if (i > s_i)
			s_num = 0;

next_chunk:
		num = 0;
		accum = 0;
		spin_lock_bh(lock);
		sk_nulls_for_each(sk, node, &head->chain) {
			int state;

			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
				goto next_normal;
			/* TIME_WAIT entries are matched by their substate. */
			state = (sk->sk_state == TCP_TIME_WAIT) ?
				inet_twsk(sk)->tw_substate : sk->sk_state;
			if (!(idiag_states & (1 << state)))
				goto next_normal;
			if (r->sdiag_family != AF_UNSPEC &&
			    sk->sk_family != r->sdiag_family)
				goto next_normal;
			if (r->id.idiag_sport != htons(sk->sk_num) &&
			    r->id.idiag_sport)
				goto next_normal;
			if (r->id.idiag_dport != sk->sk_dport &&
			    r->id.idiag_dport)
				goto next_normal;
			/*
			 * Compile-time check that the fields read above sit at
			 * the same offsets in timewait sockets.
			 */
			twsk_build_assert();

			if (!inet_diag_bc_sk(bc, sk))
				goto next_normal;

			/* Skip sockets whose refcount already dropped to zero. */
			if (!refcount_inc_not_zero(&sk->sk_refcnt))
				goto next_normal;

			num_arr[accum] = num;
			sk_arr[accum] = sk;
			/* Batch full: num is left at this socket's position. */
			if (++accum == SKARR_SZ)
				break;
next_normal:
			++num;
		}
		spin_unlock_bh(lock);
		res = 0;
		for (idx = 0; idx < accum; idx++) {
			/*
			 * Stop filling after the first failure. num is then
			 * rewound to the failed socket so the next call
			 * retries it. Every collected reference is still
			 * dropped.
			 */
			if (res >= 0) {
				res = sk_diag_fill(sk_arr[idx], skb, cb, r,
						   NLM_F_MULTI, net_admin);
				if (res < 0)
					num = num_arr[idx];
			}
			sock_gen_put(sk_arr[idx]);
		}
		if (res < 0)
			break;
		cond_resched();
		/* The batch filled up, so rescan this bucket past the last socket taken. */
		if (accum == SKARR_SZ) {
			s_num = num + 1;
			goto next_chunk;
		}
	}

done:
	/* Save the resume position for the next call. */
	cb->args[1] = i;
	cb->args[2] = num;
out:
	;
}
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1114
1115static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
1116                            const struct inet_diag_req_v2 *r)
1117{
1118        struct inet_diag_dump_data *cb_data = cb->data;
1119        const struct inet_diag_handler *handler;
1120        u32 prev_min_dump_alloc;
1121        int protocol, err = 0;
1122
1123        protocol = inet_diag_get_protocol(r, cb_data);
1124
1125again:
1126        prev_min_dump_alloc = cb->min_dump_alloc;
1127        handler = inet_diag_lock_handler(protocol);
1128        if (!IS_ERR(handler))
1129                handler->dump(skb, cb, r);
1130        else
1131                err = PTR_ERR(handler);
1132        inet_diag_unlock_handler(handler);
1133
1134        /* The skb is not large enough to fit one sk info and
1135         * inet_sk_diag_fill() has requested for a larger skb.
1136         */
1137        if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
1138                err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
1139                if (!err)
1140                        goto again;
1141        }
1142
1143        return err ? : skb->len;
1144}
1145
1146static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
1147{
1148        return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
1149}
1150
1151static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
1152{
1153        const struct nlmsghdr *nlh = cb->nlh;
1154        struct inet_diag_dump_data *cb_data;
1155        struct sk_buff *skb = cb->skb;
1156        struct nlattr *nla;
1157        int err;
1158
1159        cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
1160        if (!cb_data)
1161                return -ENOMEM;
1162
1163        err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
1164        if (err) {
1165                kfree(cb_data);
1166                return err;
1167        }
1168        nla = cb_data->inet_diag_nla_bc;
1169        if (nla) {
1170                err = inet_diag_bc_audit(nla, skb);
1171                if (err) {
1172                        kfree(cb_data);
1173                        return err;
1174                }
1175        }
1176
1177        nla = cb_data->inet_diag_nla_bpf_stgs;
1178        if (nla) {
1179                struct bpf_sk_storage_diag *bpf_stg_diag;
1180
1181                bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
1182                if (IS_ERR(bpf_stg_diag)) {
1183                        kfree(cb_data);
1184                        return PTR_ERR(bpf_stg_diag);
1185                }
1186                cb_data->bpf_stg_diag = bpf_stg_diag;
1187        }
1188
1189        cb->data = cb_data;
1190        return 0;
1191}
1192
1193static int inet_diag_dump_start(struct netlink_callback *cb)
1194{
1195        return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
1196}
1197
1198static int inet_diag_dump_start_compat(struct netlink_callback *cb)
1199{
1200        return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
1201}
1202
1203static int inet_diag_dump_done(struct netlink_callback *cb)
1204{
1205        struct inet_diag_dump_data *cb_data = cb->data;
1206
1207        bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
1208        kfree(cb->data);
1209
1210        return 0;
1211}
1212
1213static int inet_diag_type2proto(int type)
1214{
1215        switch (type) {
1216        case TCPDIAG_GETSOCK:
1217                return IPPROTO_TCP;
1218        case DCCPDIAG_GETSOCK:
1219                return IPPROTO_DCCP;
1220        default:
1221                return 0;
1222        }
1223}
1224
1225static int inet_diag_dump_compat(struct sk_buff *skb,
1226                                 struct netlink_callback *cb)
1227{
1228        struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1229        struct inet_diag_req_v2 req;
1230
1231        req.sdiag_family = AF_UNSPEC; /* compatibility */
1232        req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1233        req.idiag_ext = rc->idiag_ext;
1234        req.idiag_states = rc->idiag_states;
1235        req.id = rc->id;
1236
1237        return __inet_diag_dump(skb, cb, &req);
1238}
1239
1240static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1241                                      const struct nlmsghdr *nlh)
1242{
1243        struct inet_diag_req *rc = nlmsg_data(nlh);
1244        struct inet_diag_req_v2 req;
1245
1246        req.sdiag_family = rc->idiag_family;
1247        req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1248        req.idiag_ext = rc->idiag_ext;
1249        req.idiag_states = rc->idiag_states;
1250        req.id = rc->id;
1251
1252        return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
1253                                   sizeof(struct inet_diag_req), &req);
1254}
1255
1256static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1257{
1258        int hdrlen = sizeof(struct inet_diag_req);
1259        struct net *net = sock_net(skb->sk);
1260
1261        if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1262            nlmsg_len(nlh) < hdrlen)
1263                return -EINVAL;
1264
1265        if (nlh->nlmsg_flags & NLM_F_DUMP) {
1266                struct netlink_dump_control c = {
1267                        .start = inet_diag_dump_start_compat,
1268                        .done = inet_diag_dump_done,
1269                        .dump = inet_diag_dump_compat,
1270                };
1271                return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1272        }
1273
1274        return inet_diag_get_exact_compat(skb, nlh);
1275}
1276
1277static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1278{
1279        int hdrlen = sizeof(struct inet_diag_req_v2);
1280        struct net *net = sock_net(skb->sk);
1281
1282        if (nlmsg_len(h) < hdrlen)
1283                return -EINVAL;
1284
1285        if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1286            h->nlmsg_flags & NLM_F_DUMP) {
1287                struct netlink_dump_control c = {
1288                        .start = inet_diag_dump_start,
1289                        .done = inet_diag_dump_done,
1290                        .dump = inet_diag_dump,
1291                };
1292                return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1293        }
1294
1295        return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
1296                                   nlmsg_data(h));
1297}
1298
1299static
1300int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1301{
1302        const struct inet_diag_handler *handler;
1303        struct nlmsghdr *nlh;
1304        struct nlattr *attr;
1305        struct inet_diag_msg *r;
1306        void *info = NULL;
1307        int err = 0;
1308
1309        nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1310        if (!nlh)
1311                return -ENOMEM;
1312
1313        r = nlmsg_data(nlh);
1314        memset(r, 0, sizeof(*r));
1315        inet_diag_msg_common_fill(r, sk);
1316        if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1317                r->id.idiag_sport = inet_sk(sk)->inet_sport;
1318        r->idiag_state = sk->sk_state;
1319
1320        if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1321                nlmsg_cancel(skb, nlh);
1322                return err;
1323        }
1324
1325        handler = inet_diag_lock_handler(sk->sk_protocol);
1326        if (IS_ERR(handler)) {
1327                inet_diag_unlock_handler(handler);
1328                nlmsg_cancel(skb, nlh);
1329                return PTR_ERR(handler);
1330        }
1331
1332        attr = handler->idiag_info_size
1333                ? nla_reserve_64bit(skb, INET_DIAG_INFO,
1334                                    handler->idiag_info_size,
1335                                    INET_DIAG_PAD)
1336                : NULL;
1337        if (attr)
1338                info = nla_data(attr);
1339
1340        handler->idiag_get_info(sk, r, info);
1341        inet_diag_unlock_handler(handler);
1342
1343        nlmsg_end(skb, nlh);
1344        return 0;
1345}
1346
1347static const struct sock_diag_handler inet_diag_handler = {
1348        .family = AF_INET,
1349        .dump = inet_diag_handler_cmd,
1350        .get_info = inet_diag_handler_get_info,
1351        .destroy = inet_diag_handler_cmd,
1352};
1353
1354static const struct sock_diag_handler inet6_diag_handler = {
1355        .family = AF_INET6,
1356        .dump = inet_diag_handler_cmd,
1357        .get_info = inet_diag_handler_get_info,
1358        .destroy = inet_diag_handler_cmd,
1359};
1360
1361int inet_diag_register(const struct inet_diag_handler *h)
1362{
1363        const __u16 type = h->idiag_type;
1364        int err = -EINVAL;
1365
1366        if (type >= IPPROTO_MAX)
1367                goto out;
1368
1369        mutex_lock(&inet_diag_table_mutex);
1370        err = -EEXIST;
1371        if (!inet_diag_table[type]) {
1372                inet_diag_table[type] = h;
1373                err = 0;
1374        }
1375        mutex_unlock(&inet_diag_table_mutex);
1376out:
1377        return err;
1378}
1379EXPORT_SYMBOL_GPL(inet_diag_register);
1380
1381void inet_diag_unregister(const struct inet_diag_handler *h)
1382{
1383        const __u16 type = h->idiag_type;
1384
1385        if (type >= IPPROTO_MAX)
1386                return;
1387
1388        mutex_lock(&inet_diag_table_mutex);
1389        inet_diag_table[type] = NULL;
1390        mutex_unlock(&inet_diag_table_mutex);
1391}
1392EXPORT_SYMBOL_GPL(inet_diag_unregister);
1393
1394static int __init inet_diag_init(void)
1395{
1396        const int inet_diag_table_size = (IPPROTO_MAX *
1397                                          sizeof(struct inet_diag_handler *));
1398        int err = -ENOMEM;
1399
1400        inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1401        if (!inet_diag_table)
1402                goto out;
1403
1404        err = sock_diag_register(&inet_diag_handler);
1405        if (err)
1406                goto out_free_nl;
1407
1408        err = sock_diag_register(&inet6_diag_handler);
1409        if (err)
1410                goto out_free_inet;
1411
1412        sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1413out:
1414        return err;
1415
1416out_free_inet:
1417        sock_diag_unregister(&inet_diag_handler);
1418out_free_nl:
1419        kfree(inet_diag_table);
1420        goto out;
1421}
1422
/*
 * Module exit. Detaches the AF_INET6 and AF_INET sock_diag handlers and
 * the legacy compat entry point, then frees the per-protocol handler table.
 */
static void __exit inet_diag_exit(void)
{
	sock_diag_unregister(&inet6_diag_handler);
	sock_diag_unregister(&inet_diag_handler);
	sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
	kfree(inet_diag_table);
}
1430
module_init(inet_diag_init);
module_exit(inet_diag_exit);
MODULE_LICENSE("GPL");
/*
 * Aliases for NETLINK_SOCK_DIAG with address families 2 (AF_INET) and
 * 10 (AF_INET6). Presumably they let sock_diag load this module on demand
 * for those families; TODO confirm against sock_diag's module loading.
 */
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1436