linux/net/core/sock_map.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}
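
/* Illustrative sketch (not part of this file): user space reaches the
 * attach path above via bpf(BPF_PROG_ATTACH) with the sockmap (or
 * sockhash) fd as attr->target_fd. A minimal libbpf-based example,
 * assuming prog_fd and map_fd were obtained elsewhere (hypothetical
 * names):
 *
 *	#include <bpf/bpf.h>
 *
 *	int attach_verdict(int prog_fd, int map_fd)
 *	{
 *		return bpf_prog_attach(prog_fd, map_fd,
 *				       BPF_SK_SKB_STREAM_VERDICT, 0);
 *	}
 */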

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

/* Attach the map's programs to sk: take references on the map's
 * msg/skb programs, create or reuse the socket's psock, install the
 * programs and start the strparser when both skb programs are set.
 */
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk)
			sock_map_unref(sk, psk);
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}
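
/* Illustrative sketch (not part of this file): given the get_next_key
 * semantics above, user space can walk every slot index by starting
 * from a NULL key; an out-of-range key restarts the walk at 0 and the
 * walk ends with -ENOENT after the last slot. A minimal libbpf-based
 * loop, assuming a hypothetical map_fd:
 *
 *	__u32 key, next;
 *	int err = bpf_map_get_next_key(map_fd, NULL, &next);
 *
 *	while (!err) {
 *		key = next;
 *		// inspect or delete slot "key" here
 *		err = bpf_map_get_next_key(map_fd, &key, &next);
 *	}
 */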

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
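
/* Illustrative sketch (not part of this file): from user space the
 * value written into a sockmap is a socket file descriptor, and the
 * socket must be an established TCP socket or the update fails with
 * -EOPNOTSUPP (see the checks above). A minimal sketch, assuming a
 * libbpf version that provides bpf_map_create() and a connected TCP
 * socket fd in sock_fd (hypothetical name):
 *
 *	#include <bpf/bpf.h>
 *
 *	int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, "sock_map",
 *				    sizeof(__u32), sizeof(__u32), 16, NULL);
 *	__u32 idx = 0, val = sock_fd;
 *
 *	if (map_fd >= 0)
 *		bpf_map_update_elem(map_fd, &idx, &val, BPF_ANY);
 */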

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
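
/* Illustrative sketch (not part of this file): bpf_sk_redirect_map()
 * is called from an SK_SKB verdict program attached to the map. A
 * minimal BPF-C sketch, assuming clang/libbpf BTF-style map
 * definitions and a hypothetical map named "sock_map":
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SOCKMAP);
 *		__uint(max_entries, 16);
 *		__type(key, __u32);
 *		__type(value, __u32);
 *	} sock_map SEC(".maps");
 *
 *	SEC("sk_skb/stream_verdict")
 *	int skb_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 key = 0;
 *
 *		return bpf_sk_redirect_map(skb, &sock_map, key, 0);
 *	}
 */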

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
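
/* Illustrative sketch (not part of this file): bpf_msg_redirect_map()
 * is the sendmsg-side counterpart, called from an SK_MSG program. A
 * minimal BPF-C sketch, reusing the hypothetical "sock_map" definition
 * from the sketch above:
 *
 *	SEC("sk_msg")
 *	int msg_redir(struct sk_msg_md *msg)
 *	{
 *		__u32 key = 0;
 *
 *		return bpf_msg_redirect_map(msg, &sock_map, key,
 *					    BPF_F_INGRESS);
 *	}
 */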

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
						       u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}
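
/* Illustrative sketch (not part of this file): unlike sockmap, a
 * sockhash key can be any fixed-size blob (nonzero, at most
 * MAX_BPF_STACK bytes), e.g. a connection 4-tuple. A minimal
 * user-space sketch, assuming libbpf's bpf_map_create() and a
 * hypothetical key layout:
 *
 *	struct sock_key {
 *		__u32 sip4;
 *		__u32 dip4;
 *		__u16 sport;
 *		__u16 dport;
 *	};
 *
 *	int fd = bpf_map_create(BPF_MAP_TYPE_SOCKHASH, "sock_hash",
 *				sizeof(struct sock_key), sizeof(__u32),
 *				1024, NULL);
 */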

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			sock_map_unref(elem->sk, elem);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};
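
/* Illustrative sketch (not part of this file): bpf_sock_hash_update()
 * lets a BPF_PROG_TYPE_SOCK_OPS program insert sockets as they become
 * established (the only ops accepted by sock_map_op_okay()). A minimal
 * BPF-C sketch, assuming a BTF-style SOCKHASH map named "sock_hash"
 * keyed by the hypothetical struct sock_key from the sketch above,
 * which in practice would be filled from the skops fields:
 *
 *	SEC("sockops")
 *	int sockops_prog(struct bpf_sock_ops *skops)
 *	{
 *		struct sock_key key = {};
 *
 *		switch (skops->op) {
 *		case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
 *		case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
 *			bpf_sock_hash_update(skops, &sock_hash, &key,
 *					     BPF_ANY);
 *			break;
 *		}
 *		return 0;
 *	}
 */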

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}