linux/net/ipv4/inet_diag.c
<<
>>
Prefs
   1/*
   2 * inet_diag.c  Module for monitoring INET transport protocols sockets.
   3 *
   4 * Authors:     Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
   5 *
   6 *      This program is free software; you can redistribute it and/or
   7 *      modify it under the terms of the GNU General Public License
   8 *      as published by the Free Software Foundation; either version
   9 *      2 of the License, or (at your option) any later version.
  10 */
  11
  12#include <linux/kernel.h>
  13#include <linux/module.h>
  14#include <linux/types.h>
  15#include <linux/fcntl.h>
  16#include <linux/random.h>
  17#include <linux/slab.h>
  18#include <linux/cache.h>
  19#include <linux/init.h>
  20#include <linux/time.h>
  21
  22#include <net/icmp.h>
  23#include <net/tcp.h>
  24#include <net/ipv6.h>
  25#include <net/inet_common.h>
  26#include <net/inet_connection_sock.h>
  27#include <net/inet_hashtables.h>
  28#include <net/inet_timewait_sock.h>
  29#include <net/inet6_hashtables.h>
  30#include <net/bpf_sk_storage.h>
  31#include <net/netlink.h>
  32
  33#include <linux/inet.h>
  34#include <linux/stddef.h>
  35
  36#include <linux/inet_diag.h>
  37#include <linux/sock_diag.h>
  38
  39static const struct inet_diag_handler **inet_diag_table;
  40
  41struct inet_diag_entry {
  42        const __be32 *saddr;
  43        const __be32 *daddr;
  44        u16 sport;
  45        u16 dport;
  46        u16 family;
  47        u16 userlocks;
  48        u32 ifindex;
  49        u32 mark;
  50};
  51
  52static DEFINE_MUTEX(inet_diag_table_mutex);
  53
  54static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
  55{
  56        if (proto < 0 || proto >= IPPROTO_MAX) {
  57                mutex_lock(&inet_diag_table_mutex);
  58                return ERR_PTR(-ENOENT);
  59        }
  60
  61        if (!inet_diag_table[proto])
  62                sock_load_diag_module(AF_INET, proto);
  63
  64        mutex_lock(&inet_diag_table_mutex);
  65        if (!inet_diag_table[proto])
  66                return ERR_PTR(-ENOENT);
  67
  68        return inet_diag_table[proto];
  69}
  70
  71static void inet_diag_unlock_handler(const struct inet_diag_handler *handler)
  72{
  73        mutex_unlock(&inet_diag_table_mutex);
  74}
  75
  76void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk)
  77{
  78        r->idiag_family = sk->sk_family;
  79
  80        r->id.idiag_sport = htons(sk->sk_num);
  81        r->id.idiag_dport = sk->sk_dport;
  82        r->id.idiag_if = sk->sk_bound_dev_if;
  83        sock_diag_save_cookie(sk, r->id.idiag_cookie);
  84
  85#if IS_ENABLED(CONFIG_IPV6)
  86        if (sk->sk_family == AF_INET6) {
  87                *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr;
  88                *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr;
  89        } else
  90#endif
  91        {
  92        memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src));
  93        memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst));
  94
  95        r->id.idiag_src[0] = sk->sk_rcv_saddr;
  96        r->id.idiag_dst[0] = sk->sk_daddr;
  97        }
  98}
  99EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill);
 100
 101static size_t inet_sk_attr_size(struct sock *sk,
 102                                const struct inet_diag_req_v2 *req,
 103                                bool net_admin)
 104{
 105        const struct inet_diag_handler *handler;
 106        size_t aux = 0;
 107
 108        handler = inet_diag_table[req->sdiag_protocol];
 109        if (handler && handler->idiag_get_aux_size)
 110                aux = handler->idiag_get_aux_size(sk, net_admin);
 111
 112        return    nla_total_size(sizeof(struct tcp_info))
 113                + nla_total_size(sizeof(struct inet_diag_msg))
 114                + inet_diag_msg_attrs_size()
 115                + nla_total_size(sizeof(struct inet_diag_meminfo))
 116                + nla_total_size(SK_MEMINFO_VARS * sizeof(u32))
 117                + nla_total_size(TCP_CA_NAME_MAX)
 118                + nla_total_size(sizeof(struct tcpvegas_info))
 119                + aux
 120                + 64;
 121}
 122
 123int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
 124                             struct inet_diag_msg *r, int ext,
 125                             struct user_namespace *user_ns,
 126                             bool net_admin)
 127{
 128        const struct inet_sock *inet = inet_sk(sk);
 129
 130        if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
 131                goto errout;
 132
 133        /* IPv6 dual-stack sockets use inet->tos for IPv4 connections,
 134         * hence this needs to be included regardless of socket family.
 135         */
 136        if (ext & (1 << (INET_DIAG_TOS - 1)))
 137                if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0)
 138                        goto errout;
 139
 140#if IS_ENABLED(CONFIG_IPV6)
 141        if (r->idiag_family == AF_INET6) {
 142                if (ext & (1 << (INET_DIAG_TCLASS - 1)))
 143                        if (nla_put_u8(skb, INET_DIAG_TCLASS,
 144                                       inet6_sk(sk)->tclass) < 0)
 145                                goto errout;
 146
 147                if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) &&
 148                    nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk)))
 149                        goto errout;
 150        }
 151#endif
 152
 153        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark))
 154                goto errout;
 155
 156        if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) ||
 157            ext & (1 << (INET_DIAG_TCLASS - 1))) {
 158                u32 classid = 0;
 159
 160#ifdef CONFIG_SOCK_CGROUP_DATA
 161                classid = sock_cgroup_classid(&sk->sk_cgrp_data);
 162#endif
 163                /* Fallback to socket priority if class id isn't set.
 164                 * Classful qdiscs use it as direct reference to class.
 165                 * For cgroup2 classid is always zero.
 166                 */
 167                if (!classid)
 168                        classid = sk->sk_priority;
 169
 170                if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid))
 171                        goto errout;
 172        }
 173
 174        r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
 175        r->idiag_inode = sock_i_ino(sk);
 176
 177        return 0;
 178errout:
 179        return 1;
 180}
 181EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
 182
 183static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen,
 184                                 struct nlattr **req_nlas)
 185{
 186        struct nlattr *nla;
 187        int remaining;
 188
 189        nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) {
 190                int type = nla_type(nla);
 191
 192                if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32))
 193                        return -EINVAL;
 194
 195                if (type < __INET_DIAG_REQ_MAX)
 196                        req_nlas[type] = nla;
 197        }
 198        return 0;
 199}
 200
 201static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req,
 202                                  const struct inet_diag_dump_data *data)
 203{
 204        int retval;
 205
 206        if (data->req_nlas[INET_DIAG_REQ_PROTOCOL])
 207                retval = nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]);
 208        else
 209                retval = req->sdiag_protocol;
 210        return retval == IPPROTO_MPTCP ? IPPROTO_MPTCP_KERN : retval;
 211}
 212
 213#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
 214
 215int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
 216                      struct sk_buff *skb, struct netlink_callback *cb,
 217                      const struct inet_diag_req_v2 *req,
 218                      u16 nlmsg_flags, bool net_admin)
 219{
 220        const struct tcp_congestion_ops *ca_ops;
 221        const struct inet_diag_handler *handler;
 222        struct inet_diag_dump_data *cb_data;
 223        int ext = req->idiag_ext;
 224        struct inet_diag_msg *r;
 225        struct nlmsghdr  *nlh;
 226        struct nlattr *attr;
 227        void *info = NULL;
 228
 229        cb_data = cb->data;
 230        handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)];
 231        BUG_ON(!handler);
 232
 233        nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
 234                        cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
 235        if (!nlh)
 236                return -EMSGSIZE;
 237
 238        r = nlmsg_data(nlh);
 239        BUG_ON(!sk_fullsock(sk));
 240
 241        inet_diag_msg_common_fill(r, sk);
 242        r->idiag_state = sk->sk_state;
 243        r->idiag_timer = 0;
 244        r->idiag_retrans = 0;
 245        r->idiag_expires = 0;
 246
 247        if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
 248                                     sk_user_ns(NETLINK_CB(cb->skb).sk),
 249                                     net_admin))
 250                goto errout;
 251
 252        if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
 253                struct inet_diag_meminfo minfo = {
 254                        .idiag_rmem = sk_rmem_alloc_get(sk),
 255                        .idiag_wmem = READ_ONCE(sk->sk_wmem_queued),
 256                        .idiag_fmem = sk->sk_forward_alloc,
 257                        .idiag_tmem = sk_wmem_alloc_get(sk),
 258                };
 259
 260                if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0)
 261                        goto errout;
 262        }
 263
 264        if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))
 265                if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO))
 266                        goto errout;
 267
 268        /*
 269         * RAW sockets might have user-defined protocols assigned,
 270         * so report the one supplied on socket creation.
 271         */
 272        if (sk->sk_type == SOCK_RAW) {
 273                if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))
 274                        goto errout;
 275        }
 276
 277        if (!icsk) {
 278                handler->idiag_get_info(sk, r, NULL);
 279                goto out;
 280        }
 281
 282        if (icsk->icsk_pending == ICSK_TIME_RETRANS ||
 283            icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT ||
 284            icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) {
 285                r->idiag_timer = 1;
 286                r->idiag_retrans = icsk->icsk_retransmits;
 287                r->idiag_expires =
 288                        jiffies_to_msecs(icsk->icsk_timeout - jiffies);
 289        } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) {
 290                r->idiag_timer = 4;
 291                r->idiag_retrans = icsk->icsk_probes_out;
 292                r->idiag_expires =
 293                        jiffies_to_msecs(icsk->icsk_timeout - jiffies);
 294        } else if (timer_pending(&sk->sk_timer)) {
 295                r->idiag_timer = 2;
 296                r->idiag_retrans = icsk->icsk_probes_out;
 297                r->idiag_expires =
 298                        jiffies_to_msecs(sk->sk_timer.expires - jiffies);
 299        }
 300
 301        if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) {
 302                attr = nla_reserve_64bit(skb, INET_DIAG_INFO,
 303                                         handler->idiag_info_size,
 304                                         INET_DIAG_PAD);
 305                if (!attr)
 306                        goto errout;
 307
 308                info = nla_data(attr);
 309        }
 310
 311        if (ext & (1 << (INET_DIAG_CONG - 1))) {
 312                int err = 0;
 313
 314                rcu_read_lock();
 315                ca_ops = READ_ONCE(icsk->icsk_ca_ops);
 316                if (ca_ops)
 317                        err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name);
 318                rcu_read_unlock();
 319                if (err < 0)
 320                        goto errout;
 321        }
 322
 323        handler->idiag_get_info(sk, r, info);
 324
 325        if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux)
 326                if (handler->idiag_get_aux(sk, net_admin, skb) < 0)
 327                        goto errout;
 328
 329        if (sk->sk_state < TCP_TIME_WAIT) {
 330                union tcp_cc_info info;
 331                size_t sz = 0;
 332                int attr;
 333
 334                rcu_read_lock();
 335                ca_ops = READ_ONCE(icsk->icsk_ca_ops);
 336                if (ca_ops && ca_ops->get_info)
 337                        sz = ca_ops->get_info(sk, ext, &attr, &info);
 338                rcu_read_unlock();
 339                if (sz && nla_put(skb, attr, sz, &info) < 0)
 340                        goto errout;
 341        }
 342
 343        /* Keep it at the end for potential retry with a larger skb,
 344         * or else do best-effort fitting, which is only done for the
 345         * first_nlmsg.
 346         */
 347        if (cb_data->bpf_stg_diag) {
 348                bool first_nlmsg = ((unsigned char *)nlh == skb->data);
 349                unsigned int prev_min_dump_alloc;
 350                unsigned int total_nla_size = 0;
 351                unsigned int msg_len;
 352                int err;
 353
 354                msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
 355                err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
 356                                              INET_DIAG_SK_BPF_STORAGES,
 357                                              &total_nla_size);
 358
 359                if (!err)
 360                        goto out;
 361
 362                total_nla_size += msg_len;
 363                prev_min_dump_alloc = cb->min_dump_alloc;
 364                if (total_nla_size > prev_min_dump_alloc)
 365                        cb->min_dump_alloc = min_t(u32, total_nla_size,
 366                                                   MAX_DUMP_ALLOC_SIZE);
 367
 368                if (!first_nlmsg)
 369                        goto errout;
 370
 371                if (cb->min_dump_alloc > prev_min_dump_alloc)
 372                        /* Retry with pskb_expand_head() with
 373                         * __GFP_DIRECT_RECLAIM
 374                         */
 375                        goto errout;
 376
 377                WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
 378
 379                /* Send what we have for this sk
 380                 * and move on to the next sk in the following
 381                 * dump()
 382                 */
 383        }
 384
 385out:
 386        nlmsg_end(skb, nlh);
 387        return 0;
 388
 389errout:
 390        nlmsg_cancel(skb, nlh);
 391        return -EMSGSIZE;
 392}
 393EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
 394
 395static int inet_twsk_diag_fill(struct sock *sk,
 396                               struct sk_buff *skb,
 397                               struct netlink_callback *cb,
 398                               u16 nlmsg_flags, bool net_admin)
 399{
 400        struct inet_timewait_sock *tw = inet_twsk(sk);
 401        struct inet_diag_msg *r;
 402        struct nlmsghdr *nlh;
 403        long tmo;
 404
 405        nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
 406                        cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
 407                        sizeof(*r), nlmsg_flags);
 408        if (!nlh)
 409                return -EMSGSIZE;
 410
 411        r = nlmsg_data(nlh);
 412        BUG_ON(tw->tw_state != TCP_TIME_WAIT);
 413
 414        tmo = tw->tw_timer.expires - jiffies;
 415        if (tmo < 0)
 416                tmo = 0;
 417
 418        inet_diag_msg_common_fill(r, sk);
 419        r->idiag_retrans      = 0;
 420
 421        r->idiag_state        = tw->tw_substate;
 422        r->idiag_timer        = 3;
 423        r->idiag_expires      = jiffies_to_msecs(tmo);
 424        r->idiag_rqueue       = 0;
 425        r->idiag_wqueue       = 0;
 426        r->idiag_uid          = 0;
 427        r->idiag_inode        = 0;
 428
 429        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
 430                                     tw->tw_mark)) {
 431                nlmsg_cancel(skb, nlh);
 432                return -EMSGSIZE;
 433        }
 434
 435        nlmsg_end(skb, nlh);
 436        return 0;
 437}
 438
 439static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
 440                              struct netlink_callback *cb,
 441                              u16 nlmsg_flags, bool net_admin)
 442{
 443        struct request_sock *reqsk = inet_reqsk(sk);
 444        struct inet_diag_msg *r;
 445        struct nlmsghdr *nlh;
 446        long tmo;
 447
 448        nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
 449                        cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
 450        if (!nlh)
 451                return -EMSGSIZE;
 452
 453        r = nlmsg_data(nlh);
 454        inet_diag_msg_common_fill(r, sk);
 455        r->idiag_state = TCP_SYN_RECV;
 456        r->idiag_timer = 1;
 457        r->idiag_retrans = reqsk->num_retrans;
 458
 459        BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
 460                     offsetof(struct sock, sk_cookie));
 461
 462        tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies;
 463        r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
 464        r->idiag_rqueue = 0;
 465        r->idiag_wqueue = 0;
 466        r->idiag_uid    = 0;
 467        r->idiag_inode  = 0;
 468
 469        if (net_admin && nla_put_u32(skb, INET_DIAG_MARK,
 470                                     inet_rsk(reqsk)->ir_mark)) {
 471                nlmsg_cancel(skb, nlh);
 472                return -EMSGSIZE;
 473        }
 474
 475        nlmsg_end(skb, nlh);
 476        return 0;
 477}
 478
 479static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 480                        struct netlink_callback *cb,
 481                        const struct inet_diag_req_v2 *r,
 482                        u16 nlmsg_flags, bool net_admin)
 483{
 484        if (sk->sk_state == TCP_TIME_WAIT)
 485                return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
 486
 487        if (sk->sk_state == TCP_NEW_SYN_RECV)
 488                return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
 489
 490        return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
 491                                 net_admin);
 492}
 493
 494struct sock *inet_diag_find_one_icsk(struct net *net,
 495                                     struct inet_hashinfo *hashinfo,
 496                                     const struct inet_diag_req_v2 *req)
 497{
 498        struct sock *sk;
 499
 500        rcu_read_lock();
 501        if (req->sdiag_family == AF_INET)
 502                sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0],
 503                                 req->id.idiag_dport, req->id.idiag_src[0],
 504                                 req->id.idiag_sport, req->id.idiag_if);
 505#if IS_ENABLED(CONFIG_IPV6)
 506        else if (req->sdiag_family == AF_INET6) {
 507                if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) &&
 508                    ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src))
 509                        sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3],
 510                                         req->id.idiag_dport, req->id.idiag_src[3],
 511                                         req->id.idiag_sport, req->id.idiag_if);
 512                else
 513                        sk = inet6_lookup(net, hashinfo, NULL, 0,
 514                                          (struct in6_addr *)req->id.idiag_dst,
 515                                          req->id.idiag_dport,
 516                                          (struct in6_addr *)req->id.idiag_src,
 517                                          req->id.idiag_sport,
 518                                          req->id.idiag_if);
 519        }
 520#endif
 521        else {
 522                rcu_read_unlock();
 523                return ERR_PTR(-EINVAL);
 524        }
 525        rcu_read_unlock();
 526        if (!sk)
 527                return ERR_PTR(-ENOENT);
 528
 529        if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) {
 530                sock_gen_put(sk);
 531                return ERR_PTR(-ENOENT);
 532        }
 533
 534        return sk;
 535}
 536EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
 537
 538int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
 539                            struct netlink_callback *cb,
 540                            const struct inet_diag_req_v2 *req)
 541{
 542        struct sk_buff *in_skb = cb->skb;
 543        bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
 544        struct net *net = sock_net(in_skb->sk);
 545        struct sk_buff *rep;
 546        struct sock *sk;
 547        int err;
 548
 549        sk = inet_diag_find_one_icsk(net, hashinfo, req);
 550        if (IS_ERR(sk))
 551                return PTR_ERR(sk);
 552
 553        rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL);
 554        if (!rep) {
 555                err = -ENOMEM;
 556                goto out;
 557        }
 558
 559        err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
 560        if (err < 0) {
 561                WARN_ON(err == -EMSGSIZE);
 562                nlmsg_free(rep);
 563                goto out;
 564        }
 565        err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid);
 566
 567out:
 568        if (sk)
 569                sock_gen_put(sk);
 570
 571        return err;
 572}
 573EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk);
 574
 575static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
 576                               const struct nlmsghdr *nlh,
 577                               int hdrlen,
 578                               const struct inet_diag_req_v2 *req)
 579{
 580        const struct inet_diag_handler *handler;
 581        struct inet_diag_dump_data dump_data;
 582        int err, protocol;
 583
 584        memset(&dump_data, 0, sizeof(dump_data));
 585        err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas);
 586        if (err)
 587                return err;
 588
 589        protocol = inet_diag_get_protocol(req, &dump_data);
 590
 591        handler = inet_diag_lock_handler(protocol);
 592        if (IS_ERR(handler)) {
 593                err = PTR_ERR(handler);
 594        } else if (cmd == SOCK_DIAG_BY_FAMILY) {
 595                struct netlink_callback cb = {
 596                        .nlh = nlh,
 597                        .skb = in_skb,
 598                        .data = &dump_data,
 599                };
 600                err = handler->dump_one(&cb, req);
 601        } else if (cmd == SOCK_DESTROY && handler->destroy) {
 602                err = handler->destroy(in_skb, req);
 603        } else {
 604                err = -EOPNOTSUPP;
 605        }
 606        inet_diag_unlock_handler(handler);
 607
 608        return err;
 609}
 610
 611static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits)
 612{
 613        int words = bits >> 5;
 614
 615        bits &= 0x1f;
 616
 617        if (words) {
 618                if (memcmp(a1, a2, words << 2))
 619                        return 0;
 620        }
 621        if (bits) {
 622                __be32 w1, w2;
 623                __be32 mask;
 624
 625                w1 = a1[words];
 626                w2 = a2[words];
 627
 628                mask = htonl((0xffffffff) << (32 - bits));
 629
 630                if ((w1 ^ w2) & mask)
 631                        return 0;
 632        }
 633
 634        return 1;
 635}
 636
 637static int inet_diag_bc_run(const struct nlattr *_bc,
 638                            const struct inet_diag_entry *entry)
 639{
 640        const void *bc = nla_data(_bc);
 641        int len = nla_len(_bc);
 642
 643        while (len > 0) {
 644                int yes = 1;
 645                const struct inet_diag_bc_op *op = bc;
 646
 647                switch (op->code) {
 648                case INET_DIAG_BC_NOP:
 649                        break;
 650                case INET_DIAG_BC_JMP:
 651                        yes = 0;
 652                        break;
 653                case INET_DIAG_BC_S_EQ:
 654                        yes = entry->sport == op[1].no;
 655                        break;
 656                case INET_DIAG_BC_S_GE:
 657                        yes = entry->sport >= op[1].no;
 658                        break;
 659                case INET_DIAG_BC_S_LE:
 660                        yes = entry->sport <= op[1].no;
 661                        break;
 662                case INET_DIAG_BC_D_EQ:
 663                        yes = entry->dport == op[1].no;
 664                        break;
 665                case INET_DIAG_BC_D_GE:
 666                        yes = entry->dport >= op[1].no;
 667                        break;
 668                case INET_DIAG_BC_D_LE:
 669                        yes = entry->dport <= op[1].no;
 670                        break;
 671                case INET_DIAG_BC_AUTO:
 672                        yes = !(entry->userlocks & SOCK_BINDPORT_LOCK);
 673                        break;
 674                case INET_DIAG_BC_S_COND:
 675                case INET_DIAG_BC_D_COND: {
 676                        const struct inet_diag_hostcond *cond;
 677                        const __be32 *addr;
 678
 679                        cond = (const struct inet_diag_hostcond *)(op + 1);
 680                        if (cond->port != -1 &&
 681                            cond->port != (op->code == INET_DIAG_BC_S_COND ?
 682                                             entry->sport : entry->dport)) {
 683                                yes = 0;
 684                                break;
 685                        }
 686
 687                        if (op->code == INET_DIAG_BC_S_COND)
 688                                addr = entry->saddr;
 689                        else
 690                                addr = entry->daddr;
 691
 692                        if (cond->family != AF_UNSPEC &&
 693                            cond->family != entry->family) {
 694                                if (entry->family == AF_INET6 &&
 695                                    cond->family == AF_INET) {
 696                                        if (addr[0] == 0 && addr[1] == 0 &&
 697                                            addr[2] == htonl(0xffff) &&
 698                                            bitstring_match(addr + 3,
 699                                                            cond->addr,
 700                                                            cond->prefix_len))
 701                                                break;
 702                                }
 703                                yes = 0;
 704                                break;
 705                        }
 706
 707                        if (cond->prefix_len == 0)
 708                                break;
 709                        if (bitstring_match(addr, cond->addr,
 710                                            cond->prefix_len))
 711                                break;
 712                        yes = 0;
 713                        break;
 714                }
 715                case INET_DIAG_BC_DEV_COND: {
 716                        u32 ifindex;
 717
 718                        ifindex = *((const u32 *)(op + 1));
 719                        if (ifindex != entry->ifindex)
 720                                yes = 0;
 721                        break;
 722                }
 723                case INET_DIAG_BC_MARK_COND: {
 724                        struct inet_diag_markcond *cond;
 725
 726                        cond = (struct inet_diag_markcond *)(op + 1);
 727                        if ((entry->mark & cond->mask) != cond->mark)
 728                                yes = 0;
 729                        break;
 730                }
 731                }
 732
 733                if (yes) {
 734                        len -= op->yes;
 735                        bc += op->yes;
 736                } else {
 737                        len -= op->no;
 738                        bc += op->no;
 739                }
 740        }
 741        return len == 0;
 742}
 743
 744/* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV)
 745 */
 746static void entry_fill_addrs(struct inet_diag_entry *entry,
 747                             const struct sock *sk)
 748{
 749#if IS_ENABLED(CONFIG_IPV6)
 750        if (sk->sk_family == AF_INET6) {
 751                entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32;
 752                entry->daddr = sk->sk_v6_daddr.s6_addr32;
 753        } else
 754#endif
 755        {
 756                entry->saddr = &sk->sk_rcv_saddr;
 757                entry->daddr = &sk->sk_daddr;
 758        }
 759}
 760
 761int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
 762{
 763        struct inet_sock *inet = inet_sk(sk);
 764        struct inet_diag_entry entry;
 765
 766        if (!bc)
 767                return 1;
 768
 769        entry.family = sk->sk_family;
 770        entry_fill_addrs(&entry, sk);
 771        entry.sport = inet->inet_num;
 772        entry.dport = ntohs(inet->inet_dport);
 773        entry.ifindex = sk->sk_bound_dev_if;
 774        entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
 775        if (sk_fullsock(sk))
 776                entry.mark = sk->sk_mark;
 777        else if (sk->sk_state == TCP_NEW_SYN_RECV)
 778                entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark;
 779        else if (sk->sk_state == TCP_TIME_WAIT)
 780                entry.mark = inet_twsk(sk)->tw_mark;
 781        else
 782                entry.mark = 0;
 783
 784        return inet_diag_bc_run(bc, &entry);
 785}
 786EXPORT_SYMBOL_GPL(inet_diag_bc_sk);
 787
 788static int valid_cc(const void *bc, int len, int cc)
 789{
 790        while (len >= 0) {
 791                const struct inet_diag_bc_op *op = bc;
 792
 793                if (cc > len)
 794                        return 0;
 795                if (cc == len)
 796                        return 1;
 797                if (op->yes < 4 || op->yes & 3)
 798                        return 0;
 799                len -= op->yes;
 800                bc  += op->yes;
 801        }
 802        return 0;
 803}
 804
 805/* data is u32 ifindex */
 806static bool valid_devcond(const struct inet_diag_bc_op *op, int len,
 807                          int *min_len)
 808{
 809        /* Check ifindex space. */
 810        *min_len += sizeof(u32);
 811        if (len < *min_len)
 812                return false;
 813
 814        return true;
 815}
 816/* Validate an inet_diag_hostcond. */
 817static bool valid_hostcond(const struct inet_diag_bc_op *op, int len,
 818                           int *min_len)
 819{
 820        struct inet_diag_hostcond *cond;
 821        int addr_len;
 822
 823        /* Check hostcond space. */
 824        *min_len += sizeof(struct inet_diag_hostcond);
 825        if (len < *min_len)
 826                return false;
 827        cond = (struct inet_diag_hostcond *)(op + 1);
 828
 829        /* Check address family and address length. */
 830        switch (cond->family) {
 831        case AF_UNSPEC:
 832                addr_len = 0;
 833                break;
 834        case AF_INET:
 835                addr_len = sizeof(struct in_addr);
 836                break;
 837        case AF_INET6:
 838                addr_len = sizeof(struct in6_addr);
 839                break;
 840        default:
 841                return false;
 842        }
 843        *min_len += addr_len;
 844        if (len < *min_len)
 845                return false;
 846
 847        /* Check prefix length (in bits) vs address length (in bytes). */
 848        if (cond->prefix_len > 8 * addr_len)
 849                return false;
 850
 851        return true;
 852}
 853
 854/* Validate a port comparison operator. */
 855static bool valid_port_comparison(const struct inet_diag_bc_op *op,
 856                                  int len, int *min_len)
 857{
 858        /* Port comparisons put the port in a follow-on inet_diag_bc_op. */
 859        *min_len += sizeof(struct inet_diag_bc_op);
 860        if (len < *min_len)
 861                return false;
 862        return true;
 863}
 864
 865static bool valid_markcond(const struct inet_diag_bc_op *op, int len,
 866                           int *min_len)
 867{
 868        *min_len += sizeof(struct inet_diag_markcond);
 869        return len >= *min_len;
 870}
 871
 /*
  * inet_diag_bc_audit - validate user-supplied filter bytecode.
  * @attr: attribute holding the bytecode program (may be NULL)
  * @skb:  request skb; used to check CAP_NET_ADMIN via netlink_net_capable()
  *
  * The program is a sequence of struct inet_diag_bc_op headers.  Each header
  * may be followed by an opcode-specific payload, whose size the valid_*()
  * helpers add to min_len.  op->yes and op->no are forward byte offsets from
  * the current op.
  *
  * Return:
  *   0       - the program is well formed.
  *   -EPERM  - a MARK_COND op is used by a requester without CAP_NET_ADMIN.
  *   -EINVAL - any other defect: the attribute is missing or too short, an
  *             opcode is unknown, a payload does not fit, or an offset is
  *             out of range, not 4-aligned, or does not land on an op
  *             boundary.  Also returned when the program is not consumed
  *             exactly.
  */
 872static int inet_diag_bc_audit(const struct nlattr *attr,
 873                              const struct sk_buff *skb)
 874{
 875        bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN);
 876        const void *bytecode, *bc;
 877        int bytecode_len, len;
 878
 879        if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op))
 880                return -EINVAL;
 881
 882        bytecode = bc = nla_data(attr);
 883        len = bytecode_len = nla_len(attr);
 884
 885        while (len > 0) {
                    /* Bytes this op occupies; valid_*() add the payload size. */
 886                int min_len = sizeof(struct inet_diag_bc_op);
 887                const struct inet_diag_bc_op *op = bc;
 888
 889                switch (op->code) {
 890                case INET_DIAG_BC_S_COND:
 891                case INET_DIAG_BC_D_COND:
 892                        if (!valid_hostcond(bc, len, &min_len))
 893                                return -EINVAL;
 894                        break;
 895                case INET_DIAG_BC_DEV_COND:
 896                        if (!valid_devcond(bc, len, &min_len))
 897                                return -EINVAL;
 898                        break;
 899                case INET_DIAG_BC_S_EQ:
 900                case INET_DIAG_BC_S_GE:
 901                case INET_DIAG_BC_S_LE:
 902                case INET_DIAG_BC_D_EQ:
 903                case INET_DIAG_BC_D_GE:
 904                case INET_DIAG_BC_D_LE:
 905                        if (!valid_port_comparison(bc, len, &min_len))
 906                                return -EINVAL;
 907                        break;
 908                case INET_DIAG_BC_MARK_COND:
                            /* Filtering on socket marks is privileged. */
 909                        if (!net_admin)
 910                                return -EPERM;
 911                        if (!valid_markcond(bc, len, &min_len))
 912                                return -EINVAL;
 913                        break;
 914                case INET_DIAG_BC_AUTO:
 915                case INET_DIAG_BC_JMP:
 916                case INET_DIAG_BC_NOP:
 917                        break;
 918                default:
 919                        return -EINVAL;
 920                }
 921
                    /*
                     * op->no is not validated for NOP ops.  NOTE(review): presumably
                     * because the interpreter never follows it for NOP; confirm
                     * against inet_diag_bc_run().
                     */
 922                if (op->code != INET_DIAG_BC_NOP) {
                            /*
                             * The jump must skip at least this op and its payload,
                             * stay 4-aligned, and land no more than 4 bytes past the
                             * end of the program.
                             */
 923                        if (op->no < min_len || op->no > len + 4 || op->no & 3)
 924                                return -EINVAL;
                            /*
                             * A target inside the program must be the start of an op
                             * reachable by following the "yes" chain from the start.
                             */
 925                        if (op->no < len &&
 926                            !valid_cc(bytecode, bytecode_len, len - op->no))
 927                                return -EINVAL;
 928                }
 929
                    /*
                     * The audit walks the program along op->yes, so it must cover
                     * this op and its payload under the same bounds as op->no.
                     */
 930                if (op->yes < min_len || op->yes > len + 4 || op->yes & 3)
 931                        return -EINVAL;
 932                bc  += op->yes;
 933                len -= op->yes;
 934        }
            /* The "yes" chain must end exactly at the end of the program. */
 935        return len == 0 ? 0 : -EINVAL;
 936}
 937
 938static void twsk_build_assert(void)
 939{
 940        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
 941                     offsetof(struct sock, sk_family));
 942
 943        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) !=
 944                     offsetof(struct inet_sock, inet_num));
 945
 946        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) !=
 947                     offsetof(struct inet_sock, inet_dport));
 948
 949        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) !=
 950                     offsetof(struct inet_sock, inet_rcv_saddr));
 951
 952        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) !=
 953                     offsetof(struct inet_sock, inet_daddr));
 954
 955#if IS_ENABLED(CONFIG_IPV6)
 956        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) !=
 957                     offsetof(struct sock, sk_v6_rcv_saddr));
 958
 959        BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) !=
 960                     offsetof(struct sock, sk_v6_daddr));
 961#endif
 962}
 963
 /*
  * inet_diag_dump_icsk - dump sockets from @hashinfo that match request @r.
  * @hashinfo: socket tables to walk (listening_hash, then ehash)
  * @skb:      netlink message buffer being filled
  * @cb:       dump state.  cb->data is the struct inet_diag_dump_data for
  *            this dump and holds the optional filter bytecode
  *            (NOTE(review): presumably set up by __inet_diag_dump_start())
  * @r:        request header (family, state mask, ports)
  *
  * The dump is resumable across calls via cb->args[]:
  *   args[0]: 0 while the listening table is still being walked, 1 after.
  *   args[1]: bucket index to resume from.
  *   args[2]: number of sockets in that bucket already walked, counting
  *            only sockets in the requester's netns.
  * When @skb fills up, the position is saved and the function returns.  The
  * next call restarts at the socket that did not fit.
  *
  * Locking:
  *   - Listening buckets are walked under ilb->lock and filled in place.
  *   - Each ehash chain is walked under its spin_lock_bh() lock.  Up to
  *     SKARR_SZ matching sockets are collected with a reference held.
  *   - The collected sockets are filled after the lock is dropped and then
  *     released with sock_gen_put().
  */
 964void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 965                         struct netlink_callback *cb,
 966                         const struct inet_diag_req_v2 *r)
 967{
 968        bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
 969        struct inet_diag_dump_data *cb_data = cb->data;
 970        struct net *net = sock_net(skb->sk);
 971        u32 idiag_states = r->idiag_states;
 972        int i, num, s_i, s_num;
 973        struct nlattr *bc;
 974        struct sock *sk;
 975
 976        bc = cb_data->inet_diag_nla_bc;
            /* A SYN_RECV request also matches TCP_NEW_SYN_RECV sockets. */
 977        if (idiag_states & TCPF_SYN_RECV)
 978                idiag_states |= TCPF_NEW_SYN_RECV;
 979        s_i = cb->args[1];
 980        s_num = num = cb->args[2];
 981
 982        if (cb->args[0] == 0) {
                    /*
                     * Walk the listening table only when LISTEN is requested and
                     * no destination-port filter is set.  NOTE(review): the dport
                     * test presumably reflects that listeners have no remote port.
                     */
 983                if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport)
 984                        goto skip_listen_ht;
 985
 986                for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
 987                        struct inet_listen_hashbucket *ilb;
 988
 989                        num = 0;
 990                        ilb = &hashinfo->listening_hash[i];
 991                        spin_lock(&ilb->lock);
 992                        sk_for_each(sk, &ilb->head) {
 993                                struct inet_sock *inet = inet_sk(sk);
 994
 995                                if (!net_eq(sock_net(sk), net))
 996                                        continue;
 997
                                    /* Already walked in a previous call. */
 998                                if (num < s_num) {
 999                                        num++;
1000                                        continue;
1001                                }
1002
1003                                if (r->sdiag_family != AF_UNSPEC &&
1004                                    sk->sk_family != r->sdiag_family)
1005                                        goto next_listen;
1006
1007                                if (r->id.idiag_sport != inet->inet_sport &&
1008                                    r->id.idiag_sport)
1009                                        goto next_listen;
1010
1011                                if (!inet_diag_bc_sk(bc, sk))
1012                                        goto next_listen;
1013
                                    /*
                                     * skb full: save the position.  num is not advanced,
                                     * so this socket is retried on the next call.
                                     */
1014                                if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
1015                                                      cb, r, NLM_F_MULTI,
1016                                                      net_admin) < 0) {
1017                                        spin_unlock(&ilb->lock);
1018                                        goto done;
1019                                }
1020
1021next_listen:
1022                                ++num;
1023                        }
1024                        spin_unlock(&ilb->lock);
1025
1026                        s_num = 0;
1027                }
1028skip_listen_ht:
                    /* Listening phase done; the ehash walk starts at bucket 0. */
1029                cb->args[0] = 1;
1030                s_i = num = s_num = 0;
1031        }
1032
            /* Only LISTEN was requested: there is nothing to do in ehash. */
1033        if (!(idiag_states & ~TCPF_LISTEN))
1034                goto out;
1035
        /* Max sockets pinned per ehash chunk before filling them unlocked. */
1036#define SKARR_SZ 16
1037        for (i = s_i; i <= hashinfo->ehash_mask; i++) {
1038                struct inet_ehash_bucket *head = &hashinfo->ehash[i];
1039                spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
1040                struct hlist_nulls_node *node;
1041                struct sock *sk_arr[SKARR_SZ];
1042                int num_arr[SKARR_SZ];
1043                int idx, accum, res;
1044
1045                if (hlist_nulls_empty(&head->chain))
1046                        continue;
1047
                    /* The saved offset applies only to the bucket we stopped in. */
1048                if (i > s_i)
1049                        s_num = 0;
1050
1051next_chunk:
1052                num = 0;
1053                accum = 0;
1054                spin_lock_bh(lock);
1055                sk_nulls_for_each(sk, node, &head->chain) {
1056                        int state;
1057
1058                        if (!net_eq(sock_net(sk), net))
1059                                continue;
1060                        if (num < s_num)
1061                                goto next_normal;
                            /* Time-wait sockets are matched on their tw_substate. */
1062                        state = (sk->sk_state == TCP_TIME_WAIT) ?
1063                                inet_twsk(sk)->tw_substate : sk->sk_state;
1064                        if (!(idiag_states & (1 << state)))
1065                                goto next_normal;
1066                        if (r->sdiag_family != AF_UNSPEC &&
1067                            sk->sk_family != r->sdiag_family)
1068                                goto next_normal;
1069                        if (r->id.idiag_sport != htons(sk->sk_num) &&
1070                            r->id.idiag_sport)
1071                                goto next_normal;
1072                        if (r->id.idiag_dport != sk->sk_dport &&
1073                            r->id.idiag_dport)
1074                                goto next_normal;
                            /*
                             * Compile-time only.  The sk_* reads above may hit a
                             * time-wait socket, which relies on the layouts matching.
                             */
1075                        twsk_build_assert();
1076
1077                        if (!inet_diag_bc_sk(bc, sk))
1078                                goto next_normal;
1079
                            /*
                             * Skip sockets whose refcount already reached zero.  The
                             * reference taken here keeps the socket valid after the
                             * chain lock is dropped.
                             */
1080                        if (!refcount_inc_not_zero(&sk->sk_refcnt))
1081                                goto next_normal;
1082
1083                        num_arr[accum] = num;
1084                        sk_arr[accum] = sk;
                            /*
                             * Batch full: stop without advancing num.  The next chunk
                             * resumes at num + 1.
                             */
1085                        if (++accum == SKARR_SZ)
1086                                break;
1087next_normal:
1088                        ++num;
1089                }
1090                spin_unlock_bh(lock);
                    /*
                     * Fill the batch without the lock.  After the first failure
                     * (skb full), remember that socket's index so the next call
                     * restarts there.  The reference on every entry is dropped
                     * either way.
                     */
1091                res = 0;
1092                for (idx = 0; idx < accum; idx++) {
1093                        if (res >= 0) {
1094                                res = sk_diag_fill(sk_arr[idx], skb, cb, r,
1095                                                   NLM_F_MULTI, net_admin);
1096                                if (res < 0)
1097                                        num = num_arr[idx];
1098                        }
1099                        sock_gen_put(sk_arr[idx]);
1100                }
1101                if (res < 0)
1102                        break;
1103                cond_resched();
                    /* The chain may hold more matches: rescan it past this batch. */
1104                if (accum == SKARR_SZ) {
1105                        s_num = num + 1;
1106                        goto next_chunk;
1107                }
1108        }
1109
1110done:
            /* Save the resume cursor for the next call. */
1111        cb->args[1] = i;
1112        cb->args[2] = num;
1113out:
1114        ;
1115}
1116EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
1117
1118static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
1119                            const struct inet_diag_req_v2 *r)
1120{
1121        struct inet_diag_dump_data *cb_data = cb->data;
1122        const struct inet_diag_handler *handler;
1123        u32 prev_min_dump_alloc;
1124        int protocol, err = 0;
1125
1126        protocol = inet_diag_get_protocol(r, cb_data);
1127
1128again:
1129        prev_min_dump_alloc = cb->min_dump_alloc;
1130        handler = inet_diag_lock_handler(protocol);
1131        if (!IS_ERR(handler))
1132                handler->dump(skb, cb, r);
1133        else
1134                err = PTR_ERR(handler);
1135        inet_diag_unlock_handler(handler);
1136
1137        /* The skb is not large enough to fit one sk info and
1138         * inet_sk_diag_fill() has requested for a larger skb.
1139         */
1140        if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
1141                err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
1142                if (!err)
1143                        goto again;
1144        }
1145
1146        return err ? : skb->len;
1147}
1148
1149static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
1150{
1151        return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
1152}
1153
1154static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
1155{
1156        const struct nlmsghdr *nlh = cb->nlh;
1157        struct inet_diag_dump_data *cb_data;
1158        struct sk_buff *skb = cb->skb;
1159        struct nlattr *nla;
1160        int err;
1161
1162        cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
1163        if (!cb_data)
1164                return -ENOMEM;
1165
1166        err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas);
1167        if (err) {
1168                kfree(cb_data);
1169                return err;
1170        }
1171        nla = cb_data->inet_diag_nla_bc;
1172        if (nla) {
1173                err = inet_diag_bc_audit(nla, skb);
1174                if (err) {
1175                        kfree(cb_data);
1176                        return err;
1177                }
1178        }
1179
1180        nla = cb_data->inet_diag_nla_bpf_stgs;
1181        if (nla) {
1182                struct bpf_sk_storage_diag *bpf_stg_diag;
1183
1184                bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
1185                if (IS_ERR(bpf_stg_diag)) {
1186                        kfree(cb_data);
1187                        return PTR_ERR(bpf_stg_diag);
1188                }
1189                cb_data->bpf_stg_diag = bpf_stg_diag;
1190        }
1191
1192        cb->data = cb_data;
1193        return 0;
1194}
1195
1196static int inet_diag_dump_start(struct netlink_callback *cb)
1197{
1198        return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
1199}
1200
1201static int inet_diag_dump_start_compat(struct netlink_callback *cb)
1202{
1203        return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
1204}
1205
1206static int inet_diag_dump_done(struct netlink_callback *cb)
1207{
1208        struct inet_diag_dump_data *cb_data = cb->data;
1209
1210        bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
1211        kfree(cb->data);
1212
1213        return 0;
1214}
1215
1216static int inet_diag_type2proto(int type)
1217{
1218        switch (type) {
1219        case TCPDIAG_GETSOCK:
1220                return IPPROTO_TCP;
1221        case DCCPDIAG_GETSOCK:
1222                return IPPROTO_DCCP;
1223        default:
1224                return 0;
1225        }
1226}
1227
1228static int inet_diag_dump_compat(struct sk_buff *skb,
1229                                 struct netlink_callback *cb)
1230{
1231        struct inet_diag_req *rc = nlmsg_data(cb->nlh);
1232        struct inet_diag_req_v2 req;
1233
1234        req.sdiag_family = AF_UNSPEC; /* compatibility */
1235        req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
1236        req.idiag_ext = rc->idiag_ext;
1237        req.idiag_states = rc->idiag_states;
1238        req.id = rc->id;
1239
1240        return __inet_diag_dump(skb, cb, &req);
1241}
1242
1243static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
1244                                      const struct nlmsghdr *nlh)
1245{
1246        struct inet_diag_req *rc = nlmsg_data(nlh);
1247        struct inet_diag_req_v2 req;
1248
1249        req.sdiag_family = rc->idiag_family;
1250        req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type);
1251        req.idiag_ext = rc->idiag_ext;
1252        req.idiag_states = rc->idiag_states;
1253        req.id = rc->id;
1254
1255        return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh,
1256                                   sizeof(struct inet_diag_req), &req);
1257}
1258
1259static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
1260{
1261        int hdrlen = sizeof(struct inet_diag_req);
1262        struct net *net = sock_net(skb->sk);
1263
1264        if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
1265            nlmsg_len(nlh) < hdrlen)
1266                return -EINVAL;
1267
1268        if (nlh->nlmsg_flags & NLM_F_DUMP) {
1269                struct netlink_dump_control c = {
1270                        .start = inet_diag_dump_start_compat,
1271                        .done = inet_diag_dump_done,
1272                        .dump = inet_diag_dump_compat,
1273                };
1274                return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
1275        }
1276
1277        return inet_diag_get_exact_compat(skb, nlh);
1278}
1279
1280static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
1281{
1282        int hdrlen = sizeof(struct inet_diag_req_v2);
1283        struct net *net = sock_net(skb->sk);
1284
1285        if (nlmsg_len(h) < hdrlen)
1286                return -EINVAL;
1287
1288        if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
1289            h->nlmsg_flags & NLM_F_DUMP) {
1290                struct netlink_dump_control c = {
1291                        .start = inet_diag_dump_start,
1292                        .done = inet_diag_dump_done,
1293                        .dump = inet_diag_dump,
1294                };
1295                return netlink_dump_start(net->diag_nlsk, skb, h, &c);
1296        }
1297
1298        return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen,
1299                                   nlmsg_data(h));
1300}
1301
1302static
1303int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk)
1304{
1305        const struct inet_diag_handler *handler;
1306        struct nlmsghdr *nlh;
1307        struct nlattr *attr;
1308        struct inet_diag_msg *r;
1309        void *info = NULL;
1310        int err = 0;
1311
1312        nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0);
1313        if (!nlh)
1314                return -ENOMEM;
1315
1316        r = nlmsg_data(nlh);
1317        memset(r, 0, sizeof(*r));
1318        inet_diag_msg_common_fill(r, sk);
1319        if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM)
1320                r->id.idiag_sport = inet_sk(sk)->inet_sport;
1321        r->idiag_state = sk->sk_state;
1322
1323        if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) {
1324                nlmsg_cancel(skb, nlh);
1325                return err;
1326        }
1327
1328        handler = inet_diag_lock_handler(sk->sk_protocol);
1329        if (IS_ERR(handler)) {
1330                inet_diag_unlock_handler(handler);
1331                nlmsg_cancel(skb, nlh);
1332                return PTR_ERR(handler);
1333        }
1334
1335        attr = handler->idiag_info_size
1336                ? nla_reserve_64bit(skb, INET_DIAG_INFO,
1337                                    handler->idiag_info_size,
1338                                    INET_DIAG_PAD)
1339                : NULL;
1340        if (attr)
1341                info = nla_data(attr);
1342
1343        handler->idiag_get_info(sk, r, info);
1344        inet_diag_unlock_handler(handler);
1345
1346        nlmsg_end(skb, nlh);
1347        return 0;
1348}
1349
1350static const struct sock_diag_handler inet_diag_handler = {
1351        .family = AF_INET,
1352        .dump = inet_diag_handler_cmd,
1353        .get_info = inet_diag_handler_get_info,
1354        .destroy = inet_diag_handler_cmd,
1355};
1356
1357static const struct sock_diag_handler inet6_diag_handler = {
1358        .family = AF_INET6,
1359        .dump = inet_diag_handler_cmd,
1360        .get_info = inet_diag_handler_get_info,
1361        .destroy = inet_diag_handler_cmd,
1362};
1363
1364int inet_diag_register(const struct inet_diag_handler *h)
1365{
1366        const __u16 type = h->idiag_type;
1367        int err = -EINVAL;
1368
1369        if (type >= IPPROTO_MAX)
1370                goto out;
1371
1372        mutex_lock(&inet_diag_table_mutex);
1373        err = -EEXIST;
1374        if (!inet_diag_table[type]) {
1375                inet_diag_table[type] = h;
1376                err = 0;
1377        }
1378        mutex_unlock(&inet_diag_table_mutex);
1379out:
1380        return err;
1381}
1382EXPORT_SYMBOL_GPL(inet_diag_register);
1383
1384void inet_diag_unregister(const struct inet_diag_handler *h)
1385{
1386        const __u16 type = h->idiag_type;
1387
1388        if (type >= IPPROTO_MAX)
1389                return;
1390
1391        mutex_lock(&inet_diag_table_mutex);
1392        inet_diag_table[type] = NULL;
1393        mutex_unlock(&inet_diag_table_mutex);
1394}
1395EXPORT_SYMBOL_GPL(inet_diag_unregister);
1396
1397static int __init inet_diag_init(void)
1398{
1399        const int inet_diag_table_size = (IPPROTO_MAX *
1400                                          sizeof(struct inet_diag_handler *));
1401        int err = -ENOMEM;
1402
1403        inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL);
1404        if (!inet_diag_table)
1405                goto out;
1406
1407        err = sock_diag_register(&inet_diag_handler);
1408        if (err)
1409                goto out_free_nl;
1410
1411        err = sock_diag_register(&inet6_diag_handler);
1412        if (err)
1413                goto out_free_inet;
1414
1415        sock_diag_register_inet_compat(inet_diag_rcv_msg_compat);
1416out:
1417        return err;
1418
1419out_free_inet:
1420        sock_diag_unregister(&inet_diag_handler);
1421out_free_nl:
1422        kfree(inet_diag_table);
1423        goto out;
1424}
1425
/*
 * Module teardown: unregister the AF_INET6 and AF_INET sock_diag handlers
 * and the legacy compat hook, then free the protocol handler table.
 */
1426static void __exit inet_diag_exit(void)
1427{
1428        sock_diag_unregister(&inet6_diag_handler);
1429        sock_diag_unregister(&inet_diag_handler);
1430        sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat);
            /*
             * Freed last.  The callbacks unregistered above reach this table
             * through inet_diag_lock_handler().
             */
1431        kfree(inet_diag_table);
1432}
1433
1434module_init(inet_diag_init);
1435module_exit(inet_diag_exit);
1436MODULE_LICENSE("GPL");
    /*
     * Module aliases for NETLINK_SOCK_DIAG with family 2 (AF_INET) and
     * 10 (AF_INET6).  NOTE(review): presumably these let sock_diag load this
     * module on demand by family number; confirm against
     * sock_load_diag_module().
     */
1437MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */);
1438MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* AF_INET6 */);
1439