linux/net/core/sock_map.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

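/* A sockmap is a flat array of struct sock pointers. The array is charged
 * against the map memory limit before the backing area is allocated on the
 * requested NUMA node.
 */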
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

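/* Syscall-path map updates run with the socket locked, preemption disabled
 * and the RCU read side held for the duration of the update.
 */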
static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	sock_owned_by_me(sk);

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

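/* Return the socket's psock with a reference taken, NULL if none is
 * attached, or -EBUSY if the psock belongs to a different user, i.e.
 * sk_prot->close was not replaced by sock_map_close.
 */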
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

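/* Attach the map's programs to sk: take references on the parser/verdict
 * and msg programs, create or reuse the socket's psock, switch the socket
 * to the BPF proto ops, and start the strparser when both skb programs
 * are present.
 */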
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	bool skb_progs;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

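/* As sock_map_link(), but for sockets that cannot be used as redirect
 * targets: only a psock and the BPF proto ops are set up, no programs
 * are attached.
 */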
static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock)
			return -ENOMEM;
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk);

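/* Insert sk at index idx honouring the BPF_NOEXIST/BPF_EXIST flags, and
 * record a psock link so the slot is cleared when the socket goes away.
 */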
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

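/* Hash-table counterpart of sock_map_update_common(): link the socket,
 * then insert a new element under the bucket lock, replacing any existing
 * element with the same key.
 */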
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}
	err = bpf_map_charge_init(&htab->map.memory, cost);
	if (err)
		goto free_htab;

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		bpf_map_charge_finish(&htab->map.memory);
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

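/* Attach or replace a program in one of the map's program slots; the slot
 * is selected by the attach type. When @old is given, the update succeeds
 * only if @old is the program currently attached.
 */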
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

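/* sk_prot->unhash replacement for sockets owned by a psock: drop all map
 * links before calling the socket's original unhash operation.
 */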
void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

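/* sk_prot->close replacement: under the socket lock, remove the socket
 * from every map it is linked into, then call the original close.
 */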
void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}