linux/kernel/bpf/sockmap.c
   1/* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
   2 *
   3 * This program is free software; you can redistribute it and/or
   4 * modify it under the terms of version 2 of the GNU General Public
   5 * License as published by the Free Software Foundation.
   6 *
   7 * This program is distributed in the hope that it will be useful, but
   8 * WITHOUT ANY WARRANTY; without even the implied warranty of
   9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  10 * General Public License for more details.
  11 */
  12
   13/* A BPF sock_map is used to store sock objects. This is primarily used
  14 * for doing socket redirect with BPF helper routines.
  15 *
   16 * A sock map may have BPF programs attached to it; currently a program
  17 * used to parse packets and a program to provide a verdict and redirect
  18 * decision on the packet are supported. Any programs attached to a sock
  19 * map are inherited by sock objects when they are added to the map. If
   20 * no BPF programs are attached, the sock object may only be used for sock
  21 * redirect.
  22 *
  23 * A sock object may be in multiple maps, but can only inherit a single
  24 * parse or verdict program. If adding a sock object to a map would result
   25 * in having multiple parsing programs, the update will return an EBUSY error.
  26 *
   27 * For reference, this map is similar to the devmap used in the XDP
   28 * context; reviewing these together may be useful. For an example, please
   29 * review ./samples/bpf/sockmap/.
  30 */
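
     /* For illustration only (a rough sketch, not code from this file): a
      * user space loader typically wires a sockmap up along these lines,
      * using libbpf-style wrappers; the fds, key and map size below are
      * placeholders:
      *
      *         int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP,
      *                                     sizeof(int), sizeof(int),
      *                                     1024, 0);
      *
      *         bpf_prog_attach(parse_fd, map_fd, BPF_SK_SKB_STREAM_PARSER, 0);
      *         bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
      *         bpf_prog_attach(msg_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
      *         bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY);
      *
      * Any attached programs are then inherited by each sock added to the
      * map, as described above.
      */
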
  31#include <linux/bpf.h>
  32#include <net/sock.h>
  33#include <linux/filter.h>
  34#include <linux/errno.h>
  35#include <linux/file.h>
  36#include <linux/kernel.h>
  37#include <linux/net.h>
  38#include <linux/skbuff.h>
  39#include <linux/workqueue.h>
  40#include <linux/list.h>
  41#include <linux/mm.h>
  42#include <net/strparser.h>
  43#include <net/tcp.h>
  44#include <linux/ptr_ring.h>
  45#include <net/inet_common.h>
  46#include <linux/sched/signal.h>
  47
  48#define SOCK_CREATE_FLAG_MASK \
  49        (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
  50
  51struct bpf_sock_progs {
  52        struct bpf_prog *bpf_tx_msg;
  53        struct bpf_prog *bpf_parse;
  54        struct bpf_prog *bpf_verdict;
  55};
  56
  57struct bpf_stab {
  58        struct bpf_map map;
  59        struct sock **sock_map;
  60        struct bpf_sock_progs progs;
  61        raw_spinlock_t lock;
  62};
  63
  64struct bucket {
  65        struct hlist_head head;
  66        raw_spinlock_t lock;
  67};
  68
  69struct bpf_htab {
  70        struct bpf_map map;
  71        struct bucket *buckets;
  72        atomic_t count;
  73        u32 n_buckets;
  74        u32 elem_size;
  75        struct bpf_sock_progs progs;
  76        struct rcu_head rcu;
  77};
  78
  79struct htab_elem {
  80        struct rcu_head rcu;
  81        struct hlist_node hash_node;
  82        u32 hash;
  83        struct sock *sk;
  84        char key[0];
  85};
  86
  87enum smap_psock_state {
  88        SMAP_TX_RUNNING,
  89};
  90
  91struct smap_psock_map_entry {
  92        struct list_head list;
  93        struct bpf_map *map;
  94        struct sock **entry;
  95        struct htab_elem __rcu *hash_link;
  96};
  97
  98struct smap_psock {
  99        struct rcu_head rcu;
 100        refcount_t refcnt;
 101
 102        /* datapath variables */
 103        struct sk_buff_head rxqueue;
 104        bool strp_enabled;
 105
 106        /* datapath error path cache across tx work invocations */
 107        int save_rem;
 108        int save_off;
 109        struct sk_buff *save_skb;
 110
 111        /* datapath variables for tx_msg ULP */
 112        struct sock *sk_redir;
 113        int apply_bytes;
 114        int cork_bytes;
 115        int sg_size;
 116        int eval;
 117        struct sk_msg_buff *cork;
 118        struct list_head ingress;
 119
 120        struct strparser strp;
 121        struct bpf_prog *bpf_tx_msg;
 122        struct bpf_prog *bpf_parse;
 123        struct bpf_prog *bpf_verdict;
 124        struct list_head maps;
 125        spinlock_t maps_lock;
 126
  127        /* Back reference used when sock callbacks trigger sockmap operations */
 128        struct sock *sock;
 129        unsigned long state;
 130
 131        struct work_struct tx_work;
 132        struct work_struct gc_work;
 133
 134        struct proto *sk_proto;
 135        void (*save_unhash)(struct sock *sk);
 136        void (*save_close)(struct sock *sk, long timeout);
 137        void (*save_data_ready)(struct sock *sk);
 138        void (*save_write_space)(struct sock *sk);
 139};
 140
 141static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
 142static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 143                           int nonblock, int flags, int *addr_len);
 144static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 145static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
 146                            int offset, size_t size, int flags);
 147static void bpf_tcp_unhash(struct sock *sk);
 148static void bpf_tcp_close(struct sock *sk, long timeout);
 149
 150static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
 151{
 152        return rcu_dereference_sk_user_data(sk);
 153}
 154
 155static bool bpf_tcp_stream_read(const struct sock *sk)
 156{
 157        struct smap_psock *psock;
 158        bool empty = true;
 159
 160        rcu_read_lock();
 161        psock = smap_psock_sk(sk);
 162        if (unlikely(!psock))
 163                goto out;
 164        empty = list_empty(&psock->ingress);
 165out:
 166        rcu_read_unlock();
 167        return !empty;
 168}
 169
 170enum {
 171        SOCKMAP_IPV4,
 172        SOCKMAP_IPV6,
 173        SOCKMAP_NUM_PROTS,
 174};
 175
 176enum {
 177        SOCKMAP_BASE,
 178        SOCKMAP_TX,
 179        SOCKMAP_NUM_CONFIGS,
 180};
 181
 182static struct proto *saved_tcpv6_prot __read_mostly;
 183static DEFINE_SPINLOCK(tcpv6_prot_lock);
 184static struct proto bpf_tcp_prots[SOCKMAP_NUM_PROTS][SOCKMAP_NUM_CONFIGS];
 185static void build_protos(struct proto prot[SOCKMAP_NUM_CONFIGS],
 186                         struct proto *base)
 187{
 188        prot[SOCKMAP_BASE]                      = *base;
 189        prot[SOCKMAP_BASE].unhash               = bpf_tcp_unhash;
 190        prot[SOCKMAP_BASE].close                = bpf_tcp_close;
 191        prot[SOCKMAP_BASE].recvmsg              = bpf_tcp_recvmsg;
 192        prot[SOCKMAP_BASE].stream_memory_read   = bpf_tcp_stream_read;
 193
 194        prot[SOCKMAP_TX]                        = prot[SOCKMAP_BASE];
 195        prot[SOCKMAP_TX].sendmsg                = bpf_tcp_sendmsg;
 196        prot[SOCKMAP_TX].sendpage               = bpf_tcp_sendpage;
 197}
 198
 199static void update_sk_prot(struct sock *sk, struct smap_psock *psock)
 200{
 201        int family = sk->sk_family == AF_INET6 ? SOCKMAP_IPV6 : SOCKMAP_IPV4;
 202        int conf = psock->bpf_tx_msg ? SOCKMAP_TX : SOCKMAP_BASE;
 203
 204        sk->sk_prot = &bpf_tcp_prots[family][conf];
 205}
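
     /* For example, an AF_INET sock whose psock has a bpf_tx_msg program
      * ends up with bpf_tcp_prots[SOCKMAP_IPV4][SOCKMAP_TX], i.e. the base
      * overrides (unhash/close/recvmsg/stream_memory_read) plus the
      * sendmsg/sendpage hooks; without a tx_msg program it only gets the
      * SOCKMAP_BASE variant.
      */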
 206
 207static int bpf_tcp_init(struct sock *sk)
 208{
 209        struct smap_psock *psock;
 210
 211        rcu_read_lock();
 212        psock = smap_psock_sk(sk);
 213        if (unlikely(!psock)) {
 214                rcu_read_unlock();
 215                return -EINVAL;
 216        }
 217
 218        if (unlikely(psock->sk_proto)) {
 219                rcu_read_unlock();
 220                return -EBUSY;
 221        }
 222
 223        psock->save_unhash = sk->sk_prot->unhash;
 224        psock->save_close = sk->sk_prot->close;
 225        psock->sk_proto = sk->sk_prot;
 226
 227        /* Build IPv6 sockmap whenever the address of tcpv6_prot changes */
 228        if (sk->sk_family == AF_INET6 &&
 229            unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
 230                spin_lock_bh(&tcpv6_prot_lock);
 231                if (likely(sk->sk_prot != saved_tcpv6_prot)) {
 232                        build_protos(bpf_tcp_prots[SOCKMAP_IPV6], sk->sk_prot);
 233                        smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
 234                }
 235                spin_unlock_bh(&tcpv6_prot_lock);
 236        }
 237        update_sk_prot(sk, psock);
 238        rcu_read_unlock();
 239        return 0;
 240}
 241
 242static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
 243static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge);
 244
 245static void bpf_tcp_release(struct sock *sk)
 246{
 247        struct smap_psock *psock;
 248
 249        rcu_read_lock();
 250        psock = smap_psock_sk(sk);
 251        if (unlikely(!psock))
 252                goto out;
 253
 254        if (psock->cork) {
 255                free_start_sg(psock->sock, psock->cork, true);
 256                kfree(psock->cork);
 257                psock->cork = NULL;
 258        }
 259
 260        if (psock->sk_proto) {
 261                sk->sk_prot = psock->sk_proto;
 262                psock->sk_proto = NULL;
 263        }
 264out:
 265        rcu_read_unlock();
 266}
 267
 268static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
 269                                         u32 hash, void *key, u32 key_size)
 270{
 271        struct htab_elem *l;
 272
 273        hlist_for_each_entry_rcu(l, head, hash_node) {
 274                if (l->hash == hash && !memcmp(&l->key, key, key_size))
 275                        return l;
 276        }
 277
 278        return NULL;
 279}
 280
 281static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
 282{
 283        return &htab->buckets[hash & (htab->n_buckets - 1)];
 284}
 285
 286static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
 287{
 288        return &__select_bucket(htab, hash)->head;
 289}
 290
 291static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
 292{
 293        atomic_dec(&htab->count);
 294        kfree_rcu(l, rcu);
 295}
 296
 297static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
 298                                                  struct smap_psock *psock)
 299{
 300        struct smap_psock_map_entry *e;
 301
 302        spin_lock_bh(&psock->maps_lock);
 303        e = list_first_entry_or_null(&psock->maps,
 304                                     struct smap_psock_map_entry,
 305                                     list);
 306        if (e)
 307                list_del(&e->list);
 308        spin_unlock_bh(&psock->maps_lock);
 309        return e;
 310}
 311
 312static void bpf_tcp_remove(struct sock *sk, struct smap_psock *psock)
 313{
 314        struct smap_psock_map_entry *e;
 315        struct sk_msg_buff *md, *mtmp;
 316        struct sock *osk;
 317
 318        if (psock->cork) {
 319                free_start_sg(psock->sock, psock->cork, true);
 320                kfree(psock->cork);
 321                psock->cork = NULL;
 322        }
 323
 324        list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
 325                list_del(&md->list);
 326                free_start_sg(psock->sock, md, true);
 327                kfree(md);
 328        }
 329
 330        e = psock_map_pop(sk, psock);
 331        while (e) {
 332                if (e->entry) {
 333                        struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
 334
 335                        raw_spin_lock_bh(&stab->lock);
 336                        osk = *e->entry;
 337                        if (osk == sk) {
 338                                *e->entry = NULL;
 339                                smap_release_sock(psock, sk);
 340                        }
 341                        raw_spin_unlock_bh(&stab->lock);
 342                } else {
 343                        struct htab_elem *link = rcu_dereference(e->hash_link);
 344                        struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
 345                        struct hlist_head *head;
 346                        struct htab_elem *l;
 347                        struct bucket *b;
 348
 349                        b = __select_bucket(htab, link->hash);
 350                        head = &b->head;
 351                        raw_spin_lock_bh(&b->lock);
 352                        l = lookup_elem_raw(head,
 353                                            link->hash, link->key,
 354                                            htab->map.key_size);
  355                        /* If another thread deleted this object, skip deletion.
 356                         * The refcnt on psock may or may not be zero.
 357                         */
 358                        if (l && l == link) {
 359                                hlist_del_rcu(&link->hash_node);
 360                                smap_release_sock(psock, link->sk);
 361                                free_htab_elem(htab, link);
 362                        }
 363                        raw_spin_unlock_bh(&b->lock);
 364                }
 365                kfree(e);
 366                e = psock_map_pop(sk, psock);
 367        }
 368}
 369
 370static void bpf_tcp_unhash(struct sock *sk)
 371{
 372        void (*unhash_fun)(struct sock *sk);
 373        struct smap_psock *psock;
 374
 375        rcu_read_lock();
 376        psock = smap_psock_sk(sk);
 377        if (unlikely(!psock)) {
 378                rcu_read_unlock();
 379                if (sk->sk_prot->unhash)
 380                        sk->sk_prot->unhash(sk);
 381                return;
 382        }
 383        unhash_fun = psock->save_unhash;
 384        bpf_tcp_remove(sk, psock);
 385        rcu_read_unlock();
 386        unhash_fun(sk);
 387}
 388
 389static void bpf_tcp_close(struct sock *sk, long timeout)
 390{
 391        void (*close_fun)(struct sock *sk, long timeout);
 392        struct smap_psock *psock;
 393
 394        lock_sock(sk);
 395        rcu_read_lock();
 396        psock = smap_psock_sk(sk);
 397        if (unlikely(!psock)) {
 398                rcu_read_unlock();
 399                release_sock(sk);
 400                return sk->sk_prot->close(sk, timeout);
 401        }
 402        close_fun = psock->save_close;
 403        bpf_tcp_remove(sk, psock);
 404        rcu_read_unlock();
 405        release_sock(sk);
 406        close_fun(sk, timeout);
 407}
 408
 409enum __sk_action {
 410        __SK_DROP = 0,
 411        __SK_PASS,
 412        __SK_REDIRECT,
 413        __SK_NONE,
 414};
 415
 416static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
 417        .name           = "bpf_tcp",
 418        .uid            = TCP_ULP_BPF,
 419        .user_visible   = false,
 420        .owner          = NULL,
 421        .init           = bpf_tcp_init,
 422        .release        = bpf_tcp_release,
 423};
 424
 425static int memcopy_from_iter(struct sock *sk,
 426                             struct sk_msg_buff *md,
 427                             struct iov_iter *from, int bytes)
 428{
 429        struct scatterlist *sg = md->sg_data;
 430        int i = md->sg_curr, rc = -ENOSPC;
 431
 432        do {
 433                int copy;
 434                char *to;
 435
 436                if (md->sg_copybreak >= sg[i].length) {
 437                        md->sg_copybreak = 0;
 438
 439                        if (++i == MAX_SKB_FRAGS)
 440                                i = 0;
 441
 442                        if (i == md->sg_end)
 443                                break;
 444                }
 445
 446                copy = sg[i].length - md->sg_copybreak;
 447                to = sg_virt(&sg[i]) + md->sg_copybreak;
 448                md->sg_copybreak += copy;
 449
 450                if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
 451                        rc = copy_from_iter_nocache(to, copy, from);
 452                else
 453                        rc = copy_from_iter(to, copy, from);
 454
 455                if (rc != copy) {
 456                        rc = -EFAULT;
 457                        goto out;
 458                }
 459
 460                bytes -= copy;
 461                if (!bytes)
 462                        break;
 463
 464                md->sg_copybreak = 0;
 465                if (++i == MAX_SKB_FRAGS)
 466                        i = 0;
 467        } while (i != md->sg_end);
 468out:
 469        md->sg_curr = i;
 470        return rc;
 471}
 472
 473static int bpf_tcp_push(struct sock *sk, int apply_bytes,
 474                        struct sk_msg_buff *md,
 475                        int flags, bool uncharge)
 476{
 477        bool apply = apply_bytes;
 478        struct scatterlist *sg;
 479        int offset, ret = 0;
 480        struct page *p;
 481        size_t size;
 482
 483        while (1) {
 484                sg = md->sg_data + md->sg_start;
 485                size = (apply && apply_bytes < sg->length) ?
 486                        apply_bytes : sg->length;
 487                offset = sg->offset;
 488
 489                tcp_rate_check_app_limited(sk);
 490                p = sg_page(sg);
 491retry:
 492                ret = do_tcp_sendpages(sk, p, offset, size, flags);
 493                if (ret != size) {
 494                        if (ret > 0) {
 495                                if (apply)
 496                                        apply_bytes -= ret;
 497
 498                                sg->offset += ret;
 499                                sg->length -= ret;
 500                                size -= ret;
 501                                offset += ret;
 502                                if (uncharge)
 503                                        sk_mem_uncharge(sk, ret);
 504                                goto retry;
 505                        }
 506
 507                        return ret;
 508                }
 509
 510                if (apply)
 511                        apply_bytes -= ret;
 512                sg->offset += ret;
 513                sg->length -= ret;
 514                if (uncharge)
 515                        sk_mem_uncharge(sk, ret);
 516
 517                if (!sg->length) {
 518                        put_page(p);
 519                        md->sg_start++;
 520                        if (md->sg_start == MAX_SKB_FRAGS)
 521                                md->sg_start = 0;
 522                        sg_init_table(sg, 1);
 523
 524                        if (md->sg_start == md->sg_end)
 525                                break;
 526                }
 527
 528                if (apply && !apply_bytes)
 529                        break;
 530        }
 531        return 0;
 532}
 533
 534static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
 535{
 536        struct scatterlist *sg = md->sg_data + md->sg_start;
 537
 538        if (md->sg_copy[md->sg_start]) {
 539                md->data = md->data_end = 0;
 540        } else {
 541                md->data = sg_virt(sg);
 542                md->data_end = md->data + sg->length;
 543        }
 544}
 545
 546static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
 547{
 548        struct scatterlist *sg = md->sg_data;
 549        int i = md->sg_start;
 550
 551        do {
 552                int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
 553
 554                sk_mem_uncharge(sk, uncharge);
 555                bytes -= uncharge;
 556                if (!bytes)
 557                        break;
 558                i++;
 559                if (i == MAX_SKB_FRAGS)
 560                        i = 0;
 561        } while (i != md->sg_end);
 562}
 563
 564static void free_bytes_sg(struct sock *sk, int bytes,
 565                          struct sk_msg_buff *md, bool charge)
 566{
 567        struct scatterlist *sg = md->sg_data;
 568        int i = md->sg_start, free;
 569
 570        while (bytes && sg[i].length) {
 571                free = sg[i].length;
 572                if (bytes < free) {
 573                        sg[i].length -= bytes;
 574                        sg[i].offset += bytes;
 575                        if (charge)
 576                                sk_mem_uncharge(sk, bytes);
 577                        break;
 578                }
 579
 580                if (charge)
 581                        sk_mem_uncharge(sk, sg[i].length);
 582                put_page(sg_page(&sg[i]));
 583                bytes -= sg[i].length;
 584                sg[i].length = 0;
 585                sg[i].page_link = 0;
 586                sg[i].offset = 0;
 587                i++;
 588
 589                if (i == MAX_SKB_FRAGS)
 590                        i = 0;
 591        }
 592        md->sg_start = i;
 593}
 594
 595static int free_sg(struct sock *sk, int start,
 596                   struct sk_msg_buff *md, bool charge)
 597{
 598        struct scatterlist *sg = md->sg_data;
 599        int i = start, free = 0;
 600
 601        while (sg[i].length) {
 602                free += sg[i].length;
 603                if (charge)
 604                        sk_mem_uncharge(sk, sg[i].length);
 605                if (!md->skb)
 606                        put_page(sg_page(&sg[i]));
 607                sg[i].length = 0;
 608                sg[i].page_link = 0;
 609                sg[i].offset = 0;
 610                i++;
 611
 612                if (i == MAX_SKB_FRAGS)
 613                        i = 0;
 614        }
 615        if (md->skb)
 616                consume_skb(md->skb);
 617
 618        return free;
 619}
 620
 621static int free_start_sg(struct sock *sk, struct sk_msg_buff *md, bool charge)
 622{
 623        int free = free_sg(sk, md->sg_start, md, charge);
 624
 625        md->sg_start = md->sg_end;
 626        return free;
 627}
 628
 629static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
 630{
 631        return free_sg(sk, md->sg_curr, md, true);
 632}
 633
 634static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
 635{
 636        return ((_rc == SK_PASS) ?
 637               (md->sk_redir ? __SK_REDIRECT : __SK_PASS) :
 638               __SK_DROP);
 639}
 640
 641static unsigned int smap_do_tx_msg(struct sock *sk,
 642                                   struct smap_psock *psock,
 643                                   struct sk_msg_buff *md)
 644{
 645        struct bpf_prog *prog;
 646        unsigned int rc, _rc;
 647
 648        preempt_disable();
 649        rcu_read_lock();
 650
 651        /* If the policy was removed mid-send then default to 'accept' */
 652        prog = READ_ONCE(psock->bpf_tx_msg);
 653        if (unlikely(!prog)) {
 654                _rc = SK_PASS;
 655                goto verdict;
 656        }
 657
 658        bpf_compute_data_pointers_sg(md);
 659        md->sk = sk;
 660        rc = (*prog->bpf_func)(md, prog->insnsi);
 661        psock->apply_bytes = md->apply_bytes;
 662
 663        /* Moving return codes from UAPI namespace into internal namespace */
 664        _rc = bpf_map_msg_verdict(rc, md);
 665
  666        /* The psock has a refcount on the sock but not on the map, and because
  667         * we need to drop the rcu read lock here it's possible the map could be
 668         * removed between here and when we need it to execute the sock
 669         * redirect. So do the map lookup now for future use.
 670         */
 671        if (_rc == __SK_REDIRECT) {
 672                if (psock->sk_redir)
 673                        sock_put(psock->sk_redir);
 674                psock->sk_redir = do_msg_redirect_map(md);
 675                if (!psock->sk_redir) {
 676                        _rc = __SK_DROP;
 677                        goto verdict;
 678                }
 679                sock_hold(psock->sk_redir);
 680        }
 681verdict:
 682        rcu_read_unlock();
 683        preempt_enable();
 684
 685        return _rc;
 686}
 687
 688static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
 689                           struct smap_psock *psock,
 690                           struct sk_msg_buff *md, int flags)
 691{
 692        bool apply = apply_bytes;
 693        size_t size, copied = 0;
 694        struct sk_msg_buff *r;
 695        int err = 0, i;
 696
 697        r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
 698        if (unlikely(!r))
 699                return -ENOMEM;
 700
 701        lock_sock(sk);
 702        r->sg_start = md->sg_start;
 703        i = md->sg_start;
 704
 705        do {
 706                size = (apply && apply_bytes < md->sg_data[i].length) ?
 707                        apply_bytes : md->sg_data[i].length;
 708
 709                if (!sk_wmem_schedule(sk, size)) {
 710                        if (!copied)
 711                                err = -ENOMEM;
 712                        break;
 713                }
 714
 715                sk_mem_charge(sk, size);
 716                r->sg_data[i] = md->sg_data[i];
 717                r->sg_data[i].length = size;
 718                md->sg_data[i].length -= size;
 719                md->sg_data[i].offset += size;
 720                copied += size;
 721
 722                if (md->sg_data[i].length) {
 723                        get_page(sg_page(&r->sg_data[i]));
 724                        r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
 725                } else {
 726                        i++;
 727                        if (i == MAX_SKB_FRAGS)
 728                                i = 0;
 729                        r->sg_end = i;
 730                }
 731
 732                if (apply) {
 733                        apply_bytes -= size;
 734                        if (!apply_bytes)
 735                                break;
 736                }
 737        } while (i != md->sg_end);
 738
 739        md->sg_start = i;
 740
 741        if (!err) {
 742                list_add_tail(&r->list, &psock->ingress);
 743                sk->sk_data_ready(sk);
 744        } else {
 745                free_start_sg(sk, r, true);
 746                kfree(r);
 747        }
 748
 749        release_sock(sk);
 750        return err;
 751}
 752
 753static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
 754                                       struct sk_msg_buff *md,
 755                                       int flags)
 756{
 757        bool ingress = !!(md->flags & BPF_F_INGRESS);
 758        struct smap_psock *psock;
 759        int err = 0;
 760
 761        rcu_read_lock();
 762        psock = smap_psock_sk(sk);
 763        if (unlikely(!psock))
 764                goto out_rcu;
 765
 766        if (!refcount_inc_not_zero(&psock->refcnt))
 767                goto out_rcu;
 768
 769        rcu_read_unlock();
 770
 771        if (ingress) {
 772                err = bpf_tcp_ingress(sk, send, psock, md, flags);
 773        } else {
 774                lock_sock(sk);
 775                err = bpf_tcp_push(sk, send, md, flags, false);
 776                release_sock(sk);
 777        }
 778        smap_release_sock(psock, sk);
 779        return err;
 780out_rcu:
 781        rcu_read_unlock();
 782        return 0;
 783}
 784
 785static inline void bpf_md_init(struct smap_psock *psock)
 786{
 787        if (!psock->apply_bytes) {
 788                psock->eval =  __SK_NONE;
 789                if (psock->sk_redir) {
 790                        sock_put(psock->sk_redir);
 791                        psock->sk_redir = NULL;
 792                }
 793        }
 794}
 795
 796static void apply_bytes_dec(struct smap_psock *psock, int i)
 797{
 798        if (psock->apply_bytes) {
 799                if (psock->apply_bytes < i)
 800                        psock->apply_bytes = 0;
 801                else
 802                        psock->apply_bytes -= i;
 803        }
 804}
 805
 806static int bpf_exec_tx_verdict(struct smap_psock *psock,
 807                               struct sk_msg_buff *m,
 808                               struct sock *sk,
 809                               int *copied, int flags)
 810{
 811        bool cork = false, enospc = (m->sg_start == m->sg_end);
 812        struct sock *redir;
 813        int err = 0;
 814        int send;
 815
 816more_data:
 817        if (psock->eval == __SK_NONE)
 818                psock->eval = smap_do_tx_msg(sk, psock, m);
 819
 820        if (m->cork_bytes &&
 821            m->cork_bytes > psock->sg_size && !enospc) {
 822                psock->cork_bytes = m->cork_bytes - psock->sg_size;
 823                if (!psock->cork) {
 824                        psock->cork = kcalloc(1,
 825                                        sizeof(struct sk_msg_buff),
 826                                        GFP_ATOMIC | __GFP_NOWARN);
 827
 828                        if (!psock->cork) {
 829                                err = -ENOMEM;
 830                                goto out_err;
 831                        }
 832                }
 833                memcpy(psock->cork, m, sizeof(*m));
 834                goto out_err;
 835        }
 836
 837        send = psock->sg_size;
 838        if (psock->apply_bytes && psock->apply_bytes < send)
 839                send = psock->apply_bytes;
 840
 841        switch (psock->eval) {
 842        case __SK_PASS:
 843                err = bpf_tcp_push(sk, send, m, flags, true);
 844                if (unlikely(err)) {
 845                        *copied -= free_start_sg(sk, m, true);
 846                        break;
 847                }
 848
 849                apply_bytes_dec(psock, send);
 850                psock->sg_size -= send;
 851                break;
 852        case __SK_REDIRECT:
 853                redir = psock->sk_redir;
 854                apply_bytes_dec(psock, send);
 855
 856                if (psock->cork) {
 857                        cork = true;
 858                        psock->cork = NULL;
 859                }
 860
 861                return_mem_sg(sk, send, m);
 862                release_sock(sk);
 863
 864                err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
 865                lock_sock(sk);
 866
 867                if (unlikely(err < 0)) {
 868                        int free = free_start_sg(sk, m, false);
 869
 870                        psock->sg_size = 0;
 871                        if (!cork)
 872                                *copied -= free;
 873                } else {
 874                        psock->sg_size -= send;
 875                }
 876
 877                if (cork) {
 878                        free_start_sg(sk, m, true);
 879                        psock->sg_size = 0;
 880                        kfree(m);
 881                        m = NULL;
 882                        err = 0;
 883                }
 884                break;
 885        case __SK_DROP:
 886        default:
 887                free_bytes_sg(sk, send, m, true);
 888                apply_bytes_dec(psock, send);
 889                *copied -= send;
 890                psock->sg_size -= send;
 891                err = -EACCES;
 892                break;
 893        }
 894
 895        if (likely(!err)) {
 896                bpf_md_init(psock);
 897                if (m &&
 898                    m->sg_data[m->sg_start].page_link &&
 899                    m->sg_data[m->sg_start].length)
 900                        goto more_data;
 901        }
 902
 903out_err:
 904        return err;
 905}
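
     /* Worked example (hypothetical numbers): if the msg program called
      * bpf_msg_apply_bytes(msg, 1000) and a sendmsg() has queued
      * psock->sg_size == 1500 bytes, send is clamped to 1000 and the
      * verdict applies to those bytes only; apply_bytes_dec() then drops
      * apply_bytes to 0 so bpf_md_init() resets psock->eval to __SK_NONE,
      * and the remaining 500 bytes loop back through more_data to get a
      * fresh verdict from smap_do_tx_msg().
      */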
 906
 907static int bpf_wait_data(struct sock *sk,
 908                         struct smap_psock *psk, int flags,
 909                         long timeo, int *err)
 910{
 911        int rc;
 912
 913        DEFINE_WAIT_FUNC(wait, woken_wake_function);
 914
 915        add_wait_queue(sk_sleep(sk), &wait);
 916        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 917        rc = sk_wait_event(sk, &timeo,
 918                           !list_empty(&psk->ingress) ||
 919                           !skb_queue_empty(&sk->sk_receive_queue),
 920                           &wait);
 921        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 922        remove_wait_queue(sk_sleep(sk), &wait);
 923
 924        return rc;
 925}
 926
 927static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 928                           int nonblock, int flags, int *addr_len)
 929{
 930        struct iov_iter *iter = &msg->msg_iter;
 931        struct smap_psock *psock;
 932        int copied = 0;
 933
 934        if (unlikely(flags & MSG_ERRQUEUE))
 935                return inet_recv_error(sk, msg, len, addr_len);
 936        if (!skb_queue_empty(&sk->sk_receive_queue))
 937                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 938
 939        rcu_read_lock();
 940        psock = smap_psock_sk(sk);
 941        if (unlikely(!psock))
 942                goto out;
 943
 944        if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
 945                goto out;
 946        rcu_read_unlock();
 947
 948        lock_sock(sk);
 949bytes_ready:
 950        while (copied != len) {
 951                struct scatterlist *sg;
 952                struct sk_msg_buff *md;
 953                int i;
 954
 955                md = list_first_entry_or_null(&psock->ingress,
 956                                              struct sk_msg_buff, list);
 957                if (unlikely(!md))
 958                        break;
 959                i = md->sg_start;
 960                do {
 961                        struct page *page;
 962                        int n, copy;
 963
 964                        sg = &md->sg_data[i];
 965                        copy = sg->length;
 966                        page = sg_page(sg);
 967
 968                        if (copied + copy > len)
 969                                copy = len - copied;
 970
 971                        n = copy_page_to_iter(page, sg->offset, copy, iter);
 972                        if (n != copy) {
 973                                md->sg_start = i;
 974                                release_sock(sk);
 975                                smap_release_sock(psock, sk);
 976                                return -EFAULT;
 977                        }
 978
 979                        copied += copy;
 980                        sg->offset += copy;
 981                        sg->length -= copy;
 982                        sk_mem_uncharge(sk, copy);
 983
 984                        if (!sg->length) {
 985                                i++;
 986                                if (i == MAX_SKB_FRAGS)
 987                                        i = 0;
 988                                if (!md->skb)
 989                                        put_page(page);
 990                        }
 991                        if (copied == len)
 992                                break;
 993                } while (i != md->sg_end);
 994                md->sg_start = i;
 995
 996                if (!sg->length && md->sg_start == md->sg_end) {
 997                        list_del(&md->list);
 998                        if (md->skb)
 999                                consume_skb(md->skb);
1000                        kfree(md);
1001                }
1002        }
1003
1004        if (!copied) {
1005                long timeo;
1006                int data;
1007                int err = 0;
1008
1009                timeo = sock_rcvtimeo(sk, nonblock);
1010                data = bpf_wait_data(sk, psock, flags, timeo, &err);
1011
1012                if (data) {
1013                        if (!skb_queue_empty(&sk->sk_receive_queue)) {
1014                                release_sock(sk);
1015                                smap_release_sock(psock, sk);
1016                                copied = tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1017                                return copied;
1018                        }
1019                        goto bytes_ready;
1020                }
1021
1022                if (err)
1023                        copied = err;
1024        }
1025
1026        release_sock(sk);
1027        smap_release_sock(psock, sk);
1028        return copied;
1029out:
1030        rcu_read_unlock();
1031        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
1032}
1033
1034
1035static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1036{
1037        int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
1038        struct sk_msg_buff md = {0};
1039        unsigned int sg_copy = 0;
1040        struct smap_psock *psock;
1041        int copied = 0, err = 0;
1042        struct scatterlist *sg;
1043        long timeo;
1044
 1045        /* It's possible a sock event or user removed the psock _but_ the ops
 1046         * have not been reprogrammed yet, so we get here. In this case fall back
 1047         * to tcp_sendmsg. Note this only works because we _only_ ever allow a
 1048         * single ULP; there is no hierarchy here.
1049         */
1050        rcu_read_lock();
1051        psock = smap_psock_sk(sk);
1052        if (unlikely(!psock)) {
1053                rcu_read_unlock();
1054                return tcp_sendmsg(sk, msg, size);
1055        }
1056
 1057        /* Increment the psock refcnt to ensure it's not released while sending a
 1058         * message. Required because sk lookup and bpf programs are used in
 1059         * separate rcu critical sections. It's OK if we lose the map entry
1060         * but we can't lose the sock reference.
1061         */
1062        if (!refcount_inc_not_zero(&psock->refcnt)) {
1063                rcu_read_unlock();
1064                return tcp_sendmsg(sk, msg, size);
1065        }
1066
1067        sg = md.sg_data;
1068        sg_init_marker(sg, MAX_SKB_FRAGS);
1069        rcu_read_unlock();
1070
1071        lock_sock(sk);
1072        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1073
1074        while (msg_data_left(msg)) {
1075                struct sk_msg_buff *m = NULL;
1076                bool enospc = false;
1077                int copy;
1078
1079                if (sk->sk_err) {
1080                        err = -sk->sk_err;
1081                        goto out_err;
1082                }
1083
1084                copy = msg_data_left(msg);
1085                if (!sk_stream_memory_free(sk))
1086                        goto wait_for_sndbuf;
1087
1088                m = psock->cork_bytes ? psock->cork : &md;
1089                m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
1090                err = sk_alloc_sg(sk, copy, m->sg_data,
1091                                  m->sg_start, &m->sg_end, &sg_copy,
1092                                  m->sg_end - 1);
1093                if (err) {
1094                        if (err != -ENOSPC)
1095                                goto wait_for_memory;
1096                        enospc = true;
1097                        copy = sg_copy;
1098                }
1099
1100                err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
1101                if (err < 0) {
1102                        free_curr_sg(sk, m);
1103                        goto out_err;
1104                }
1105
1106                psock->sg_size += copy;
1107                copied += copy;
1108                sg_copy = 0;
1109
 1110                /* When bytes are being corked, skip running the BPF program and
 1111                 * applying the verdict unless there is no more buffer space. In
 1112                 * the ENOSPC case simply run the BPF program with the currently
 1113                 * accumulated data. We don't have much choice at this point:
 1114                 * we could try extending the page frags or chaining complex
 1115                 * frags, but even in these cases _eventually_ we will hit an
1116                 * OOM scenario. More complex recovery schemes may be
1117                 * implemented in the future, but BPF programs must handle
1118                 * the case where apply_cork requests are not honored. The
1119                 * canonical method to verify this is to check data length.
1120                 */
1121                if (psock->cork_bytes) {
1122                        if (copy > psock->cork_bytes)
1123                                psock->cork_bytes = 0;
1124                        else
1125                                psock->cork_bytes -= copy;
1126
1127                        if (psock->cork_bytes && !enospc)
1128                                goto out_cork;
1129
 1130                        /* All cork bytes accounted for, re-run the filter */
1131                        psock->eval = __SK_NONE;
1132                        psock->cork_bytes = 0;
1133                }
1134
1135                err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1136                if (unlikely(err < 0))
1137                        goto out_err;
1138                continue;
1139wait_for_sndbuf:
1140                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1141wait_for_memory:
1142                err = sk_stream_wait_memory(sk, &timeo);
1143                if (err) {
1144                        if (m && m != psock->cork)
1145                                free_start_sg(sk, m, true);
1146                        goto out_err;
1147                }
1148        }
1149out_err:
1150        if (err < 0)
1151                err = sk_stream_error(sk, msg->msg_flags, err);
1152out_cork:
1153        release_sock(sk);
1154        smap_release_sock(psock, sk);
1155        return copied ? copied : err;
1156}
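
     /* A minimal SK_MSG program driving the paths above might look like
      * the sketch below; "sock_map" and the key are placeholders:
      *
      *         SEC("sk_msg")
      *         int tx_prog(struct sk_msg_md *msg)
      *         {
      *                 __u32 key = 0;
      *
      *                 bpf_msg_apply_bytes(msg, 1000);
      *                 bpf_msg_cork_bytes(msg, 512);
      *                 return bpf_msg_redirect_map(msg, &sock_map, key,
      *                                             BPF_F_INGRESS);
      *         }
      *
      * bpf_msg_redirect_map() sets msg->sk_redir, which smap_do_tx_msg()
      * translates into __SK_REDIRECT; BPF_F_INGRESS then selects the
      * bpf_tcp_ingress() path in bpf_tcp_sendmsg_do_redirect().
      */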
1157
1158static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
1159                            int offset, size_t size, int flags)
1160{
1161        struct sk_msg_buff md = {0}, *m = NULL;
1162        int err = 0, copied = 0;
1163        struct smap_psock *psock;
1164        struct scatterlist *sg;
1165        bool enospc = false;
1166
1167        rcu_read_lock();
1168        psock = smap_psock_sk(sk);
1169        if (unlikely(!psock))
1170                goto accept;
1171
1172        if (!refcount_inc_not_zero(&psock->refcnt))
1173                goto accept;
1174        rcu_read_unlock();
1175
1176        lock_sock(sk);
1177
1178        if (psock->cork_bytes) {
1179                m = psock->cork;
1180                sg = &m->sg_data[m->sg_end];
1181        } else {
1182                m = &md;
1183                sg = m->sg_data;
1184                sg_init_marker(sg, MAX_SKB_FRAGS);
1185        }
1186
1187        /* Catch case where ring is full and sendpage is stalled. */
1188        if (unlikely(m->sg_end == m->sg_start &&
1189            m->sg_data[m->sg_end].length))
1190                goto out_err;
1191
1192        psock->sg_size += size;
1193        sg_set_page(sg, page, size, offset);
1194        get_page(page);
1195        m->sg_copy[m->sg_end] = true;
1196        sk_mem_charge(sk, size);
1197        m->sg_end++;
1198        copied = size;
1199
1200        if (m->sg_end == MAX_SKB_FRAGS)
1201                m->sg_end = 0;
1202
1203        if (m->sg_end == m->sg_start)
1204                enospc = true;
1205
1206        if (psock->cork_bytes) {
1207                if (size > psock->cork_bytes)
1208                        psock->cork_bytes = 0;
1209                else
1210                        psock->cork_bytes -= size;
1211
1212                if (psock->cork_bytes && !enospc)
1213                        goto out_err;
1214
 1215                /* All cork bytes accounted for, re-run the filter */
1216                psock->eval = __SK_NONE;
1217                psock->cork_bytes = 0;
1218        }
1219
1220        err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
1221out_err:
1222        release_sock(sk);
1223        smap_release_sock(psock, sk);
1224        return copied ? copied : err;
1225accept:
1226        rcu_read_unlock();
1227        return tcp_sendpage(sk, page, offset, size, flags);
1228}
1229
1230static void bpf_tcp_msg_add(struct smap_psock *psock,
1231                            struct sock *sk,
1232                            struct bpf_prog *tx_msg)
1233{
1234        struct bpf_prog *orig_tx_msg;
1235
1236        orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
1237        if (orig_tx_msg)
1238                bpf_prog_put(orig_tx_msg);
1239}
1240
1241static int bpf_tcp_ulp_register(void)
1242{
1243        build_protos(bpf_tcp_prots[SOCKMAP_IPV4], &tcp_prot);
1244        /* Once BPF TX ULP is registered it is never unregistered. It
1245         * will be in the ULP list for the lifetime of the system. Doing
1246         * duplicate registers is not a problem.
1247         */
1248        return tcp_register_ulp(&bpf_tcp_ulp_ops);
1249}
1250
1251static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
1252{
1253        struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
1254        int rc;
1255
1256        if (unlikely(!prog))
1257                return __SK_DROP;
1258
1259        skb_orphan(skb);
1260        /* We need to ensure that BPF metadata for maps is also cleared
 1261         * when we orphan the skb so that we don't end up referencing
 1262         * a stale map.
1263         */
1264        TCP_SKB_CB(skb)->bpf.sk_redir = NULL;
1265        skb->sk = psock->sock;
1266        bpf_compute_data_end_sk_skb(skb);
1267        preempt_disable();
1268        rc = (*prog->bpf_func)(skb, prog->insnsi);
1269        preempt_enable();
1270        skb->sk = NULL;
1271
1272        /* Moving return codes from UAPI namespace into internal namespace */
1273        return rc == SK_PASS ?
1274                (TCP_SKB_CB(skb)->bpf.sk_redir ? __SK_REDIRECT : __SK_PASS) :
1275                __SK_DROP;
1276}
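
     /* The verdict program consumed here is an SK_SKB program roughly
      * along these lines (a sketch; "sock_map" and the key are
      * placeholders):
      *
      *         SEC("sk_skb/stream_verdict")
      *         int verdict_prog(struct __sk_buff *skb)
      *         {
      *                 __u32 key = 0;
      *
      *                 return bpf_sk_redirect_map(skb, &sock_map, key, 0);
      *         }
      *
      * bpf_sk_redirect_map() fills TCP_SKB_CB(skb)->bpf.sk_redir, so an
      * SK_PASS return maps to __SK_REDIRECT above; SK_DROP (or a missing
      * map entry) leads to the skb being freed in smap_do_verdict().
      */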
1277
1278static int smap_do_ingress(struct smap_psock *psock, struct sk_buff *skb)
1279{
1280        struct sock *sk = psock->sock;
1281        int copied = 0, num_sg;
1282        struct sk_msg_buff *r;
1283
1284        r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_ATOMIC);
1285        if (unlikely(!r))
1286                return -EAGAIN;
1287
1288        if (!sk_rmem_schedule(sk, skb, skb->len)) {
1289                kfree(r);
1290                return -EAGAIN;
1291        }
1292
1293        sg_init_table(r->sg_data, MAX_SKB_FRAGS);
1294        num_sg = skb_to_sgvec(skb, r->sg_data, 0, skb->len);
1295        if (unlikely(num_sg < 0)) {
1296                kfree(r);
1297                return num_sg;
1298        }
1299        sk_mem_charge(sk, skb->len);
1300        copied = skb->len;
1301        r->sg_start = 0;
1302        r->sg_end = num_sg == MAX_SKB_FRAGS ? 0 : num_sg;
1303        r->skb = skb;
1304        list_add_tail(&r->list, &psock->ingress);
1305        sk->sk_data_ready(sk);
1306        return copied;
1307}
1308
1309static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
1310{
1311        struct smap_psock *peer;
1312        struct sock *sk;
1313        __u32 in;
1314        int rc;
1315
1316        rc = smap_verdict_func(psock, skb);
1317        switch (rc) {
1318        case __SK_REDIRECT:
1319                sk = do_sk_redirect_map(skb);
1320                if (!sk) {
1321                        kfree_skb(skb);
1322                        break;
1323                }
1324
1325                peer = smap_psock_sk(sk);
1326                in = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1327
1328                if (unlikely(!peer || sock_flag(sk, SOCK_DEAD) ||
1329                             !test_bit(SMAP_TX_RUNNING, &peer->state))) {
1330                        kfree_skb(skb);
1331                        break;
1332                }
1333
1334                if (!in && sock_writeable(sk)) {
1335                        skb_set_owner_w(skb, sk);
1336                        skb_queue_tail(&peer->rxqueue, skb);
1337                        schedule_work(&peer->tx_work);
1338                        break;
1339                } else if (in &&
1340                           atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) {
1341                        skb_queue_tail(&peer->rxqueue, skb);
1342                        schedule_work(&peer->tx_work);
1343                        break;
1344                }
1345        /* Fall through and free skb otherwise */
1346        case __SK_DROP:
1347        default:
1348                kfree_skb(skb);
1349        }
1350}
1351
1352static void smap_report_sk_error(struct smap_psock *psock, int err)
1353{
1354        struct sock *sk = psock->sock;
1355
1356        sk->sk_err = err;
1357        sk->sk_error_report(sk);
1358}
1359
1360static void smap_read_sock_strparser(struct strparser *strp,
1361                                     struct sk_buff *skb)
1362{
1363        struct smap_psock *psock;
1364
1365        rcu_read_lock();
1366        psock = container_of(strp, struct smap_psock, strp);
1367        smap_do_verdict(psock, skb);
1368        rcu_read_unlock();
1369}
1370
1371/* Called with lock held on socket */
1372static void smap_data_ready(struct sock *sk)
1373{
1374        struct smap_psock *psock;
1375
1376        rcu_read_lock();
1377        psock = smap_psock_sk(sk);
1378        if (likely(psock)) {
1379                write_lock_bh(&sk->sk_callback_lock);
1380                strp_data_ready(&psock->strp);
1381                write_unlock_bh(&sk->sk_callback_lock);
1382        }
1383        rcu_read_unlock();
1384}
1385
1386static void smap_tx_work(struct work_struct *w)
1387{
1388        struct smap_psock *psock;
1389        struct sk_buff *skb;
1390        int rem, off, n;
1391
1392        psock = container_of(w, struct smap_psock, tx_work);
1393
 1394        /* lock sock to avoid losing sk_socket at some point during the loop */
1395        lock_sock(psock->sock);
1396        if (psock->save_skb) {
1397                skb = psock->save_skb;
1398                rem = psock->save_rem;
1399                off = psock->save_off;
1400                psock->save_skb = NULL;
1401                goto start;
1402        }
1403
1404        while ((skb = skb_dequeue(&psock->rxqueue))) {
1405                __u32 flags;
1406
1407                rem = skb->len;
1408                off = 0;
1409start:
1410                flags = (TCP_SKB_CB(skb)->bpf.flags) & BPF_F_INGRESS;
1411                do {
1412                        if (likely(psock->sock->sk_socket)) {
1413                                if (flags)
1414                                        n = smap_do_ingress(psock, skb);
1415                                else
1416                                        n = skb_send_sock_locked(psock->sock,
1417                                                                 skb, off, rem);
1418                        } else {
1419                                n = -EINVAL;
1420                        }
1421
1422                        if (n <= 0) {
1423                                if (n == -EAGAIN) {
1424                                        /* Retry when space is available */
1425                                        psock->save_skb = skb;
1426                                        psock->save_rem = rem;
1427                                        psock->save_off = off;
1428                                        goto out;
1429                                }
1430                                /* Hard errors break pipe and stop xmit */
1431                                smap_report_sk_error(psock, n ? -n : EPIPE);
1432                                clear_bit(SMAP_TX_RUNNING, &psock->state);
1433                                kfree_skb(skb);
1434                                goto out;
1435                        }
1436                        rem -= n;
1437                        off += n;
1438                } while (rem);
1439
1440                if (!flags)
1441                        kfree_skb(skb);
1442        }
1443out:
1444        release_sock(psock->sock);
1445}
1446
1447static void smap_write_space(struct sock *sk)
1448{
1449        struct smap_psock *psock;
1450        void (*write_space)(struct sock *sk);
1451
1452        rcu_read_lock();
1453        psock = smap_psock_sk(sk);
1454        if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
1455                schedule_work(&psock->tx_work);
1456        write_space = psock->save_write_space;
1457        rcu_read_unlock();
1458        write_space(sk);
1459}
1460
1461static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
1462{
1463        if (!psock->strp_enabled)
1464                return;
1465        sk->sk_data_ready = psock->save_data_ready;
1466        sk->sk_write_space = psock->save_write_space;
1467        psock->save_data_ready = NULL;
1468        psock->save_write_space = NULL;
1469        strp_stop(&psock->strp);
1470        psock->strp_enabled = false;
1471}
1472
1473static void smap_destroy_psock(struct rcu_head *rcu)
1474{
1475        struct smap_psock *psock = container_of(rcu,
1476                                                  struct smap_psock, rcu);
1477
1478        /* Now that a grace period has passed there is no longer
 1479         * any reference to this sock in the sockmap, so we can
 1480         * destroy the psock, strparser, and bpf programs. But,
 1481         * because we use workqueue sync operations we cannot
 1482         * do it in rcu context.
1483         */
1484        schedule_work(&psock->gc_work);
1485}
1486
1487static bool psock_is_smap_sk(struct sock *sk)
1488{
1489        return inet_csk(sk)->icsk_ulp_ops == &bpf_tcp_ulp_ops;
1490}
1491
1492static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1493{
1494        if (refcount_dec_and_test(&psock->refcnt)) {
1495                if (psock_is_smap_sk(sock))
1496                        tcp_cleanup_ulp(sock);
1497                write_lock_bh(&sock->sk_callback_lock);
1498                smap_stop_sock(psock, sock);
1499                write_unlock_bh(&sock->sk_callback_lock);
1500                clear_bit(SMAP_TX_RUNNING, &psock->state);
1501                rcu_assign_sk_user_data(sock, NULL);
1502                call_rcu_sched(&psock->rcu, smap_destroy_psock);
1503        }
1504}
1505
1506static int smap_parse_func_strparser(struct strparser *strp,
1507                                       struct sk_buff *skb)
1508{
1509        struct smap_psock *psock;
1510        struct bpf_prog *prog;
1511        int rc;
1512
1513        rcu_read_lock();
1514        psock = container_of(strp, struct smap_psock, strp);
1515        prog = READ_ONCE(psock->bpf_parse);
1516
1517        if (unlikely(!prog)) {
1518                rcu_read_unlock();
1519                return skb->len;
1520        }
1521
 1522        /* Attach the socket for the bpf program to use if needed. We can do
 1523         * this because strparser clones the skb before handing it to an upper
1524         * layer, meaning skb_orphan has been called. We NULL sk on the
1525         * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1526         * later and because we are not charging the memory of this skb to
1527         * any socket yet.
1528         */
1529        skb->sk = psock->sock;
1530        bpf_compute_data_end_sk_skb(skb);
1531        rc = (*prog->bpf_func)(skb, prog->insnsi);
1532        skb->sk = NULL;
1533        rcu_read_unlock();
1534        return rc;
1535}
1536
1537static int smap_read_sock_done(struct strparser *strp, int err)
1538{
1539        return err;
1540}
1541
1542static int smap_init_sock(struct smap_psock *psock,
1543                          struct sock *sk)
1544{
1545        static const struct strp_callbacks cb = {
1546                .rcv_msg = smap_read_sock_strparser,
1547                .parse_msg = smap_parse_func_strparser,
1548                .read_sock_done = smap_read_sock_done,
1549        };
1550
1551        return strp_init(&psock->strp, sk, &cb);
1552}
1553
1554static void smap_init_progs(struct smap_psock *psock,
1555                            struct bpf_prog *verdict,
1556                            struct bpf_prog *parse)
1557{
1558        struct bpf_prog *orig_parse, *orig_verdict;
1559
1560        orig_parse = xchg(&psock->bpf_parse, parse);
1561        orig_verdict = xchg(&psock->bpf_verdict, verdict);
1562
1563        if (orig_verdict)
1564                bpf_prog_put(orig_verdict);
1565        if (orig_parse)
1566                bpf_prog_put(orig_parse);
1567}
1568
1569static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1570{
1571        if (sk->sk_data_ready == smap_data_ready)
1572                return;
1573        psock->save_data_ready = sk->sk_data_ready;
1574        psock->save_write_space = sk->sk_write_space;
1575        sk->sk_data_ready = smap_data_ready;
1576        sk->sk_write_space = smap_write_space;
1577        psock->strp_enabled = true;
1578}
1579
1580static void sock_map_remove_complete(struct bpf_stab *stab)
1581{
1582        bpf_map_area_free(stab->sock_map);
1583        kfree(stab);
1584}
1585
1586static void smap_gc_work(struct work_struct *w)
1587{
1588        struct smap_psock_map_entry *e, *tmp;
1589        struct sk_msg_buff *md, *mtmp;
1590        struct smap_psock *psock;
1591
1592        psock = container_of(w, struct smap_psock, gc_work);
1593
1594        /* no callback lock needed because we already detached sockmap ops */
1595        if (psock->strp_enabled)
1596                strp_done(&psock->strp);
1597
1598        cancel_work_sync(&psock->tx_work);
1599        __skb_queue_purge(&psock->rxqueue);
1600
1601        /* At this point all strparser and xmit work must be complete */
1602        if (psock->bpf_parse)
1603                bpf_prog_put(psock->bpf_parse);
1604        if (psock->bpf_verdict)
1605                bpf_prog_put(psock->bpf_verdict);
1606        if (psock->bpf_tx_msg)
1607                bpf_prog_put(psock->bpf_tx_msg);
1608
1609        if (psock->cork) {
1610                free_start_sg(psock->sock, psock->cork, true);
1611                kfree(psock->cork);
1612        }
1613
1614        list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
1615                list_del(&md->list);
1616                free_start_sg(psock->sock, md, true);
1617                kfree(md);
1618        }
1619
1620        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1621                list_del(&e->list);
1622                kfree(e);
1623        }
1624
1625        if (psock->sk_redir)
1626                sock_put(psock->sk_redir);
1627
1628        sock_put(psock->sock);
1629        kfree(psock);
1630}
1631
1632static struct smap_psock *smap_init_psock(struct sock *sock, int node)
1633{
1634        struct smap_psock *psock;
1635
1636        psock = kzalloc_node(sizeof(struct smap_psock),
1637                             GFP_ATOMIC | __GFP_NOWARN,
1638                             node);
1639        if (!psock)
1640                return ERR_PTR(-ENOMEM);
1641
1642        psock->eval = __SK_NONE;
1643        psock->sock = sock;
1644        skb_queue_head_init(&psock->rxqueue);
1645        INIT_WORK(&psock->tx_work, smap_tx_work);
1646        INIT_WORK(&psock->gc_work, smap_gc_work);
1647        INIT_LIST_HEAD(&psock->maps);
1648        INIT_LIST_HEAD(&psock->ingress);
1649        refcount_set(&psock->refcnt, 1);
1650        spin_lock_init(&psock->maps_lock);
1651
1652        rcu_assign_sk_user_data(sock, psock);
1653        sock_hold(sock);
1654        return psock;
1655}
1656
1657static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1658{
1659        struct bpf_stab *stab;
1660        u64 cost;
1661        int err;
1662
1663        if (!capable(CAP_NET_ADMIN))
1664                return ERR_PTR(-EPERM);
1665
1666        /* check sanity of attributes */
1667        if (attr->max_entries == 0 || attr->key_size != 4 ||
1668            attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1669                return ERR_PTR(-EINVAL);
1670
1671        err = bpf_tcp_ulp_register();
1672        if (err && err != -EEXIST)
1673                return ERR_PTR(err);
1674
1675        stab = kzalloc(sizeof(*stab), GFP_USER);
1676        if (!stab)
1677                return ERR_PTR(-ENOMEM);
1678
1679        bpf_map_init_from_attr(&stab->map, attr);
1680        raw_spin_lock_init(&stab->lock);
1681
1682        /* make sure page count doesn't overflow */
1683        cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1684        err = -EINVAL;
1685        if (cost >= U32_MAX - PAGE_SIZE)
1686                goto free_stab;
1687
1688        stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1689
1690        /* if map size is larger than memlock limit, reject it early */
1691        err = bpf_map_precharge_memlock(stab->map.pages);
1692        if (err)
1693                goto free_stab;
1694
1695        err = -ENOMEM;
1696        stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1697                                            sizeof(struct sock *),
1698                                            stab->map.numa_node);
1699        if (!stab->sock_map)
1700                goto free_stab;
1701
1702        return &stab->map;
1703free_stab:
1704        kfree(stab);
1705        return ERR_PTR(err);
1706}
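
/* Example: a user space sketch of creating this map type with the
 * bpf_create_map() wrapper from tools/lib/bpf (map size and fd handling are
 * illustrative). The checks above require key_size == 4 and value_size == 4,
 * i.e. a u32 index keyed to a socket fd:
 *
 *	#include <stdio.h>
 *	#include <bpf/bpf.h>
 *
 *	static int create_sockmap(void)
 *	{
 *		int map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP,
 *					    sizeof(int), sizeof(int), 1024, 0);
 *		if (map_fd < 0)
 *			perror("bpf_create_map");
 *		return map_fd;
 *	}
 */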
1707
1708static void smap_list_map_remove(struct smap_psock *psock,
1709                                 struct sock **entry)
1710{
1711        struct smap_psock_map_entry *e, *tmp;
1712
1713        spin_lock_bh(&psock->maps_lock);
1714        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1715                if (e->entry == entry) {
1716                        list_del(&e->list);
1717                        kfree(e);
1718                }
1719        }
1720        spin_unlock_bh(&psock->maps_lock);
1721}
1722
1723static void smap_list_hash_remove(struct smap_psock *psock,
1724                                  struct htab_elem *hash_link)
1725{
1726        struct smap_psock_map_entry *e, *tmp;
1727
1728        spin_lock_bh(&psock->maps_lock);
1729        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1730                struct htab_elem *c = rcu_dereference(e->hash_link);
1731
1732                if (c == hash_link) {
1733                        list_del(&e->list);
1734                        kfree(e);
1735                }
1736        }
1737        spin_unlock_bh(&psock->maps_lock);
1738}
1739
1740static void sock_map_free(struct bpf_map *map)
1741{
1742        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1743        int i;
1744
1745        synchronize_rcu();
1746
1747        /* At this point no update, lookup or delete operations can happen.
1748         * However, be aware we can still get socket state event updates
1749         * and data ready callbacks that reference the psock from sk_user_data.
1750         * Also, psock worker threads are still in-flight. So smap_release_sock
1751         * will only free the psock after cancel_sync on the worker threads
1752         * and a grace period has expired, ensuring the psock is really safe to remove.
1753         */
1754        rcu_read_lock();
1755        raw_spin_lock_bh(&stab->lock);
1756        for (i = 0; i < stab->map.max_entries; i++) {
1757                struct smap_psock *psock;
1758                struct sock *sock;
1759
1760                sock = stab->sock_map[i];
1761                if (!sock)
1762                        continue;
1763                stab->sock_map[i] = NULL;
1764                psock = smap_psock_sk(sock);
1765                /* This check handles a racing sock event that can get the
1766                 * sk_callback_lock before this case but after the xchg happens,
1767                 * causing the refcnt to hit zero and the sock user data (psock)
1768                 * to be NULL and queued for garbage collection.
1769                 */
1770                if (likely(psock)) {
1771                        smap_list_map_remove(psock, &stab->sock_map[i]);
1772                        smap_release_sock(psock, sock);
1773                }
1774        }
1775        raw_spin_unlock_bh(&stab->lock);
1776        rcu_read_unlock();
1777
1778        sock_map_remove_complete(stab);
1779}
1780
1781static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1782{
1783        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1784        u32 i = key ? *(u32 *)key : U32_MAX;
1785        u32 *next = (u32 *)next_key;
1786
1787        if (i >= stab->map.max_entries) {
1788                *next = 0;
1789                return 0;
1790        }
1791
1792        if (i == stab->map.max_entries - 1)
1793                return -ENOENT;
1794
1795        *next = i + 1;
1796        return 0;
1797}
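
/* Example: a user space sketch of walking the map with the
 * bpf_map_get_next_key() wrapper from tools/lib/bpf; map_fd is assumed to be
 * an open sockmap fd. A NULL key (or an out-of-range key, per the check
 * above) restarts the walk at index 0:
 *
 *	__u32 key, next_key;
 *	int err = bpf_map_get_next_key(map_fd, NULL, &next_key);
 *
 *	while (!err) {
 *		key = next_key;
 *		... inspect or delete slot 'key' here ...
 *		err = bpf_map_get_next_key(map_fd, &key, &next_key);
 *	}
 */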
1798
1799struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1800{
1801        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1802
1803        if (key >= map->max_entries)
1804                return NULL;
1805
1806        return READ_ONCE(stab->sock_map[key]);
1807}
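
/* Example: this lookup backs the sockmap redirect helpers on the datapath.
 * A sketch of a stream verdict program redirecting every skb to the socket
 * stored at index 0 (bpf_sk_redirect_map() is the real helper; the map
 * definition style follows the samples' bpf_helpers.h):
 *
 *	struct bpf_map_def SEC("maps") sock_map = {
 *		.type		= BPF_MAP_TYPE_SOCKMAP,
 *		.key_size	= sizeof(int),
 *		.value_size	= sizeof(int),
 *		.max_entries	= 20,
 *	};
 *
 *	SEC("sk_skb")
 *	int verdict_prog(struct __sk_buff *skb)
 *	{
 *		return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
 *	}
 */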
1808
1809static int sock_map_delete_elem(struct bpf_map *map, void *key)
1810{
1811        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1812        struct smap_psock *psock;
1813        int k = *(u32 *)key;
1814        struct sock *sock;
1815
1816        if (k >= map->max_entries)
1817                return -EINVAL;
1818
1819        raw_spin_lock_bh(&stab->lock);
1820        sock = stab->sock_map[k];
1821        stab->sock_map[k] = NULL;
1822        raw_spin_unlock_bh(&stab->lock);
1823        if (!sock)
1824                return -EINVAL;
1825
1826        psock = smap_psock_sk(sock);
1827        if (!psock)
1828                return 0;
1829        if (psock->bpf_parse) {
1830                write_lock_bh(&sock->sk_callback_lock);
1831                smap_stop_sock(psock, sock);
1832                write_unlock_bh(&sock->sk_callback_lock);
1833        }
1834        smap_list_map_remove(psock, &stab->sock_map[k]);
1835        smap_release_sock(psock, sock);
1836        return 0;
1837}
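
/* Example: a user space sketch of clearing a slot with the
 * bpf_map_delete_elem() wrapper from tools/lib/bpf; the socket itself stays
 * open, only its map (and psock) reference is dropped:
 *
 *	__u32 key = 5;
 *
 *	if (bpf_map_delete_elem(map_fd, &key))
 *		perror("bpf_map_delete_elem");
 */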
1838
1839/* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1840 * done inside rcu critical sections. This ensures on updates that the psock
1841 * will not be released via smap_release_sock() until concurrent updates/deletes
1842 * complete. All operations on sock_map use cmpxchg and xchg operations
1843 * to ensure we do not get stale references. Any reads into the
1844 * map must be done with READ_ONCE() because of this.
1845 *
1846 * A psock is destroyed via call_rcu and after any worker threads are cancelled
1847 * and synced, so we are certain all references from the update/lookup/delete
1848 * operations as well as references in the data path are no longer in use.
1849 *
1850 * A psock may exist in multiple maps, but only a single set of parse/verdict
1851 * programs may be inherited from the maps it belongs to. A reference count
1852 * is kept with the total number of references to the psock from all maps. The
1853 * psock will not be released until this reaches zero. The psock and sock
1854 * user data use the sk_callback_lock to protect critical data structures
1855 * from concurrent access. This prevents two updates from modifying the
1856 * user data in sock concurrently; the lock is required anyway for modifying
1857 * callbacks, so we simply increase its scope slightly.
1858 *
1859 * Rules to follow:
1860 *  - psock must always be read inside RCU critical section
1861 *  - sk_user_data must only be modified inside sk_callback_lock and read
1862 *    inside RCU critical section.
1863 *  - psock->maps list must only be read & modified under psock->maps_lock
1864 *  - sock_map must use READ_ONCE and (cmp)xchg operations
1865 *  - BPF verdict/parse programs must use READ_ONCE and xchg operations
1866 */
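
/* Example: a sketch of the read side following the rules above, using
 * helpers already defined in this file (error handling elided):
 *
 *	rcu_read_lock();
 *	sock = READ_ONCE(stab->sock_map[i]);
 *	if (sock) {
 *		psock = smap_psock_sk(sock);
 *		if (psock)
 *			... use psock only inside this RCU section ...
 *	}
 *	rcu_read_unlock();
 */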
1867
1868static int __sock_map_ctx_update_elem(struct bpf_map *map,
1869                                      struct bpf_sock_progs *progs,
1870                                      struct sock *sock,
1871                                      void *key)
1872{
1873        struct bpf_prog *verdict, *parse, *tx_msg;
1874        struct smap_psock *psock;
1875        bool new = false;
1876        int err = 0;
1877
1878        /* 1. If the sock map has BPF programs, those will be inherited by the
1879         * sock being added. If the sock is already attached to BPF programs,
1880         * this results in an error.
1881         */
1882        verdict = READ_ONCE(progs->bpf_verdict);
1883        parse = READ_ONCE(progs->bpf_parse);
1884        tx_msg = READ_ONCE(progs->bpf_tx_msg);
1885
1886        if (parse && verdict) {
1887                /* bpf prog refcnt may be zero if a concurrent attach operation
1888                 * removes the program after the above READ_ONCE() but before
1889                 * we increment the refcnt. If this is the case, abort with an
1890                 * error.
1891                 */
1892                verdict = bpf_prog_inc_not_zero(verdict);
1893                if (IS_ERR(verdict))
1894                        return PTR_ERR(verdict);
1895
1896                parse = bpf_prog_inc_not_zero(parse);
1897                if (IS_ERR(parse)) {
1898                        bpf_prog_put(verdict);
1899                        return PTR_ERR(parse);
1900                }
1901        }
1902
1903        if (tx_msg) {
1904                tx_msg = bpf_prog_inc_not_zero(tx_msg);
1905                if (IS_ERR(tx_msg)) {
1906                        if (parse && verdict) {
1907                                bpf_prog_put(parse);
1908                                bpf_prog_put(verdict);
1909                        }
1910                        return PTR_ERR(tx_msg);
1911                }
1912        }
1913
1914        psock = smap_psock_sk(sock);
1915
1916        /* 2. Do not allow inheriting programs if psock exists and has
1917         * already inherited programs. This would create confusion on
1918         * which parser/verdict program is running. If no psock exists
1919         * create one. Inside sk_callback_lock to ensure concurrent create
1920         * doesn't update user data.
1921         */
1922        if (psock) {
1923                if (!psock_is_smap_sk(sock)) {
1924                        err = -EBUSY;
1925                        goto out_progs;
1926                }
1927                if (READ_ONCE(psock->bpf_parse) && parse) {
1928                        err = -EBUSY;
1929                        goto out_progs;
1930                }
1931                if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1932                        err = -EBUSY;
1933                        goto out_progs;
1934                }
1935                if (!refcount_inc_not_zero(&psock->refcnt)) {
1936                        err = -EAGAIN;
1937                        goto out_progs;
1938                }
1939        } else {
1940                psock = smap_init_psock(sock, map->numa_node);
1941                if (IS_ERR(psock)) {
1942                        err = PTR_ERR(psock);
1943                        goto out_progs;
1944                }
1945
1946                set_bit(SMAP_TX_RUNNING, &psock->state);
1947                new = true;
1948        }
1949
1950        /* 3. At this point we have a reference to a valid psock that is
1951         * running. Attach any BPF programs needed.
1952         */
1953        if (tx_msg)
1954                bpf_tcp_msg_add(psock, sock, tx_msg);
1955        if (new) {
1956                err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1957                if (err)
1958                        goto out_free;
1959        }
1960
1961        if (parse && verdict && !psock->strp_enabled) {
1962                err = smap_init_sock(psock, sock);
1963                if (err)
1964                        goto out_free;
1965                smap_init_progs(psock, verdict, parse);
1966                write_lock_bh(&sock->sk_callback_lock);
1967                smap_start_sock(psock, sock);
1968                write_unlock_bh(&sock->sk_callback_lock);
1969        }
1970
1971        return err;
1972out_free:
1973        smap_release_sock(psock, sock);
1974out_progs:
1975        if (parse && verdict) {
1976                bpf_prog_put(parse);
1977                bpf_prog_put(verdict);
1978        }
1979        if (tx_msg)
1980                bpf_prog_put(tx_msg);
1981        return err;
1982}
1983
1984static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1985                                    struct bpf_map *map,
1986                                    void *key, u64 flags)
1987{
1988        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1989        struct bpf_sock_progs *progs = &stab->progs;
1990        struct sock *osock, *sock = skops->sk;
1991        struct smap_psock_map_entry *e;
1992        struct smap_psock *psock;
1993        u32 i = *(u32 *)key;
1994        int err;
1995
1996        if (unlikely(flags > BPF_EXIST))
1997                return -EINVAL;
1998        if (unlikely(i >= stab->map.max_entries))
1999                return -E2BIG;
2000
2001        e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2002        if (!e)
2003                return -ENOMEM;
2004
2005        err = __sock_map_ctx_update_elem(map, progs, sock, key);
2006        if (err)
2007                goto out;
2008
2009        /* psock guaranteed to be present. */
2010        psock = smap_psock_sk(sock);
2011        raw_spin_lock_bh(&stab->lock);
2012        osock = stab->sock_map[i];
2013        if (osock && flags == BPF_NOEXIST) {
2014                err = -EEXIST;
2015                goto out_unlock;
2016        }
2017        if (!osock && flags == BPF_EXIST) {
2018                err = -ENOENT;
2019                goto out_unlock;
2020        }
2021
2022        e->entry = &stab->sock_map[i];
2023        e->map = map;
2024        spin_lock_bh(&psock->maps_lock);
2025        list_add_tail(&e->list, &psock->maps);
2026        spin_unlock_bh(&psock->maps_lock);
2027
2028        stab->sock_map[i] = sock;
2029        if (osock) {
2030                psock = smap_psock_sk(osock);
2031                smap_list_map_remove(psock, &stab->sock_map[i]);
2032                smap_release_sock(psock, osock);
2033        }
2034        raw_spin_unlock_bh(&stab->lock);
2035        return 0;
2036out_unlock:
2037        smap_release_sock(psock, sock);
2038        raw_spin_unlock_bh(&stab->lock);
2039out:
2040        kfree(e);
2041        return err;
2042}
2043
2044int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
2045{
2046        struct bpf_sock_progs *progs;
2047        struct bpf_prog *orig;
2048
2049        if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2050                struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2051
2052                progs = &stab->progs;
2053        } else if (map->map_type == BPF_MAP_TYPE_SOCKHASH) {
2054                struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2055
2056                progs = &htab->progs;
2057        } else {
2058                return -EINVAL;
2059        }
2060
2061        switch (type) {
2062        case BPF_SK_MSG_VERDICT:
2063                orig = xchg(&progs->bpf_tx_msg, prog);
2064                break;
2065        case BPF_SK_SKB_STREAM_PARSER:
2066                orig = xchg(&progs->bpf_parse, prog);
2067                break;
2068        case BPF_SK_SKB_STREAM_VERDICT:
2069                orig = xchg(&progs->bpf_verdict, prog);
2070                break;
2071        default:
2072                return -EOPNOTSUPP;
2073        }
2074
2075        if (orig)
2076                bpf_prog_put(orig);
2077
2078        return 0;
2079}
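
/* Example: a user space sketch of how these attach types are reached through
 * the bpf_prog_attach() wrapper from tools/lib/bpf, with the map fd as the
 * attach target (the program fds are assumed to be loaded already):
 *
 *	bpf_prog_attach(parse_fd, map_fd, BPF_SK_SKB_STREAM_PARSER, 0);
 *	bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *	bpf_prog_attach(msg_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
 *
 * Each call returns 0 on success.
 */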
2080
2081int sockmap_get_from_fd(const union bpf_attr *attr, int type,
2082                        struct bpf_prog *prog)
2083{
2084        int ufd = attr->target_fd;
2085        struct bpf_map *map;
2086        struct fd f;
2087        int err;
2088
2089        f = fdget(ufd);
2090        map = __bpf_map_get(f);
2091        if (IS_ERR(map))
2092                return PTR_ERR(map);
2093
2094        err = sock_map_prog(map, prog, attr->attach_type);
2095        fdput(f);
2096        return err;
2097}
2098
2099static void *sock_map_lookup(struct bpf_map *map, void *key)
2100{
2101        return NULL;
2102}
2103
2104static int sock_map_update_elem(struct bpf_map *map,
2105                                void *key, void *value, u64 flags)
2106{
2107        struct bpf_sock_ops_kern skops;
2108        u32 fd = *(u32 *)value;
2109        struct socket *socket;
2110        int err;
2111
2112        socket = sockfd_lookup(fd, &err);
2113        if (!socket)
2114                return err;
2115
2116        skops.sk = socket->sk;
2117        if (!skops.sk) {
2118                fput(socket->file);
2119                return -EINVAL;
2120        }
2121
2122        /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2123         * state.
2124         */
2125        if (skops.sk->sk_type != SOCK_STREAM ||
2126            skops.sk->sk_protocol != IPPROTO_TCP ||
2127            skops.sk->sk_state != TCP_ESTABLISHED) {
2128                fput(socket->file);
2129                return -EOPNOTSUPP;
2130        }
2131
2132        lock_sock(skops.sk);
2133        preempt_disable();
2134        rcu_read_lock();
2135        err = sock_map_ctx_update_elem(&skops, map, key, flags);
2136        rcu_read_unlock();
2137        preempt_enable();
2138        release_sock(skops.sk);
2139        fput(socket->file);
2140        return err;
2141}
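
/* Example: a user space sketch of inserting a socket via the
 * bpf_map_update_elem() wrapper from tools/lib/bpf. The value is the socket's
 * fd and, per the checks above, the socket must be an ESTABLISHED TCP socket;
 * connected_fd below is assumed to come from socket()/connect():
 *
 *	__u32 key = 0;
 *	__u32 sock_fd = connected_fd;
 *
 *	if (bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY))
 *		perror("bpf_map_update_elem");
 */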
2142
2143static void sock_map_release(struct bpf_map *map)
2144{
2145        struct bpf_sock_progs *progs;
2146        struct bpf_prog *orig;
2147
2148        if (map->map_type == BPF_MAP_TYPE_SOCKMAP) {
2149                struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
2150
2151                progs = &stab->progs;
2152        } else {
2153                struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2154
2155                progs = &htab->progs;
2156        }
2157
2158        orig = xchg(&progs->bpf_parse, NULL);
2159        if (orig)
2160                bpf_prog_put(orig);
2161        orig = xchg(&progs->bpf_verdict, NULL);
2162        if (orig)
2163                bpf_prog_put(orig);
2164
2165        orig = xchg(&progs->bpf_tx_msg, NULL);
2166        if (orig)
2167                bpf_prog_put(orig);
2168}
2169
2170static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
2171{
2172        struct bpf_htab *htab;
2173        int i, err;
2174        u64 cost;
2175
2176        if (!capable(CAP_NET_ADMIN))
2177                return ERR_PTR(-EPERM);
2178
2179        /* check sanity of attributes */
2180        if (attr->max_entries == 0 ||
2181            attr->key_size == 0 ||
2182            attr->value_size != 4 ||
2183            attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
2184                return ERR_PTR(-EINVAL);
2185
2186        if (attr->key_size > MAX_BPF_STACK)
2187                /* eBPF programs initialize keys on stack, so they cannot be
2188                 * larger than max stack size
2189                 */
2190                return ERR_PTR(-E2BIG);
2191
2192        err = bpf_tcp_ulp_register();
2193        if (err && err != -EEXIST)
2194                return ERR_PTR(err);
2195
2196        htab = kzalloc(sizeof(*htab), GFP_USER);
2197        if (!htab)
2198                return ERR_PTR(-ENOMEM);
2199
2200        bpf_map_init_from_attr(&htab->map, attr);
2201
2202        htab->n_buckets = roundup_pow_of_two(htab->map.max_entries);
2203        htab->elem_size = sizeof(struct htab_elem) +
2204                          round_up(htab->map.key_size, 8);
2205        err = -EINVAL;
2206        if (htab->n_buckets == 0 ||
2207            htab->n_buckets > U32_MAX / sizeof(struct bucket))
2208                goto free_htab;
2209
2210        cost = (u64) htab->n_buckets * sizeof(struct bucket) +
2211               (u64) htab->elem_size * htab->map.max_entries;
2212
2213        if (cost >= U32_MAX - PAGE_SIZE)
2214                goto free_htab;
2215
2216        htab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
2217        err = bpf_map_precharge_memlock(htab->map.pages);
2218        if (err)
2219                goto free_htab;
2220
2221        err = -ENOMEM;
2222        htab->buckets = bpf_map_area_alloc(
2223                                htab->n_buckets * sizeof(struct bucket),
2224                                htab->map.numa_node);
2225        if (!htab->buckets)
2226                goto free_htab;
2227
2228        for (i = 0; i < htab->n_buckets; i++) {
2229                INIT_HLIST_HEAD(&htab->buckets[i].head);
2230                raw_spin_lock_init(&htab->buckets[i].lock);
2231        }
2232
2233        return &htab->map;
2234free_htab:
2235        kfree(htab);
2236        return ERR_PTR(err);
2237}
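
/* Example: a user space sketch of creating a sockhash. Unlike a sockmap the
 * key size is caller chosen (up to MAX_BPF_STACK); the key layout below is
 * only illustrative. value_size must still be 4:
 *
 *	struct sock_key {
 *		__u32 sip4;
 *		__u32 dip4;
 *		__u32 sport;
 *		__u32 dport;
 *	};
 *
 *	int hash_fd = bpf_create_map(BPF_MAP_TYPE_SOCKHASH,
 *				     sizeof(struct sock_key), sizeof(int),
 *				     1024, 0);
 */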
2238
2239static void __bpf_htab_free(struct rcu_head *rcu)
2240{
2241        struct bpf_htab *htab;
2242
2243        htab = container_of(rcu, struct bpf_htab, rcu);
2244        bpf_map_area_free(htab->buckets);
2245        kfree(htab);
2246}
2247
2248static void sock_hash_free(struct bpf_map *map)
2249{
2250        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2251        int i;
2252
2253        synchronize_rcu();
2254
2255        /* At this point no update, lookup or delete operations can happen.
2256         * However, be aware we can still get socket state event updates
2257         * and data ready callbacks that reference the psock from sk_user_data.
2258         * Also, psock worker threads are still in-flight. So smap_release_sock
2259         * will only free the psock after cancel_sync on the worker threads
2260         * and a grace period has expired, ensuring the psock is really safe to remove.
2261         */
2262        rcu_read_lock();
2263        for (i = 0; i < htab->n_buckets; i++) {
2264                struct bucket *b = __select_bucket(htab, i);
2265                struct hlist_head *head;
2266                struct hlist_node *n;
2267                struct htab_elem *l;
2268
2269                raw_spin_lock_bh(&b->lock);
2270                head = &b->head;
2271                hlist_for_each_entry_safe(l, n, head, hash_node) {
2272                        struct sock *sock = l->sk;
2273                        struct smap_psock *psock;
2274
2275                        hlist_del_rcu(&l->hash_node);
2276                        psock = smap_psock_sk(sock);
2277                        /* This check handles a racing sock event that can get
2278                         * the sk_callback_lock before this case but after the xchg,
2279                         * causing the refcnt to hit zero and the sock user data
2280                         * (psock) to be NULL and queued for garbage collection.
2281                         */
2282                        if (likely(psock)) {
2283                                smap_list_hash_remove(psock, l);
2284                                smap_release_sock(psock, sock);
2285                        }
2286                        free_htab_elem(htab, l);
2287                }
2288                raw_spin_unlock_bh(&b->lock);
2289        }
2290        rcu_read_unlock();
2291        call_rcu(&htab->rcu, __bpf_htab_free);
2292}
2293
2294static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
2295                                              void *key, u32 key_size, u32 hash,
2296                                              struct sock *sk,
2297                                              struct htab_elem *old_elem)
2298{
2299        struct htab_elem *l_new;
2300
2301        if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
2302                if (!old_elem) {
2303                        atomic_dec(&htab->count);
2304                        return ERR_PTR(-E2BIG);
2305                }
2306        }
2307        l_new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
2308                             htab->map.numa_node);
2309        if (!l_new) {
2310                atomic_dec(&htab->count);
2311                return ERR_PTR(-ENOMEM);
2312        }
2313
2314        memcpy(l_new->key, key, key_size);
2315        l_new->sk = sk;
2316        l_new->hash = hash;
2317        return l_new;
2318}
2319
2320static inline u32 htab_map_hash(const void *key, u32 key_len)
2321{
2322        return jhash(key, key_len, 0);
2323}
2324
2325static int sock_hash_get_next_key(struct bpf_map *map,
2326                                  void *key, void *next_key)
2327{
2328        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2329        struct htab_elem *l, *next_l;
2330        struct hlist_head *h;
2331        u32 hash, key_size;
2332        int i = 0;
2333
2334        WARN_ON_ONCE(!rcu_read_lock_held());
2335
2336        key_size = map->key_size;
2337        if (!key)
2338                goto find_first_elem;
2339        hash = htab_map_hash(key, key_size);
2340        h = select_bucket(htab, hash);
2341
2342        l = lookup_elem_raw(h, hash, key, key_size);
2343        if (!l)
2344                goto find_first_elem;
2345        next_l = hlist_entry_safe(
2346                     rcu_dereference_raw(hlist_next_rcu(&l->hash_node)),
2347                     struct htab_elem, hash_node);
2348        if (next_l) {
2349                memcpy(next_key, next_l->key, key_size);
2350                return 0;
2351        }
2352
2353        /* no more elements in this hash list, go to the next bucket */
2354        i = hash & (htab->n_buckets - 1);
2355        i++;
2356
2357find_first_elem:
2358        /* iterate over buckets */
2359        for (; i < htab->n_buckets; i++) {
2360                h = select_bucket(htab, i);
2361
2362                /* pick first element in the bucket */
2363                next_l = hlist_entry_safe(
2364                                rcu_dereference_raw(hlist_first_rcu(h)),
2365                                struct htab_elem, hash_node);
2366                if (next_l) {
2367                        /* if it's not empty, just return it */
2368                        memcpy(next_key, next_l->key, key_size);
2369                        return 0;
2370                }
2371        }
2372
2373        /* iterated over all buckets and all elements */
2374        return -ENOENT;
2375}
2376
2377static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
2378                                     struct bpf_map *map,
2379                                     void *key, u64 map_flags)
2380{
2381        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2382        struct bpf_sock_progs *progs = &htab->progs;
2383        struct htab_elem *l_new = NULL, *l_old;
2384        struct smap_psock_map_entry *e = NULL;
2385        struct hlist_head *head;
2386        struct smap_psock *psock;
2387        u32 key_size, hash;
2388        struct sock *sock;
2389        struct bucket *b;
2390        int err;
2391
2392        sock = skops->sk;
2393
2394        if (sock->sk_type != SOCK_STREAM ||
2395            sock->sk_protocol != IPPROTO_TCP)
2396                return -EOPNOTSUPP;
2397
2398        if (unlikely(map_flags > BPF_EXIST))
2399                return -EINVAL;
2400
2401        e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
2402        if (!e)
2403                return -ENOMEM;
2404
2405        WARN_ON_ONCE(!rcu_read_lock_held());
2406        key_size = map->key_size;
2407        hash = htab_map_hash(key, key_size);
2408        b = __select_bucket(htab, hash);
2409        head = &b->head;
2410
2411        err = __sock_map_ctx_update_elem(map, progs, sock, key);
2412        if (err)
2413                goto err;
2414
2415        /* psock is valid here because otherwise the above *ctx_update_elem would
2416         * have returned an error. It is safe to skip the error check.
2417         */
2418        psock = smap_psock_sk(sock);
2419        raw_spin_lock_bh(&b->lock);
2420        l_old = lookup_elem_raw(head, hash, key, key_size);
2421        if (l_old && map_flags == BPF_NOEXIST) {
2422                err = -EEXIST;
2423                goto bucket_err;
2424        }
2425        if (!l_old && map_flags == BPF_EXIST) {
2426                err = -ENOENT;
2427                goto bucket_err;
2428        }
2429
2430        l_new = alloc_sock_hash_elem(htab, key, key_size, hash, sock, l_old);
2431        if (IS_ERR(l_new)) {
2432                err = PTR_ERR(l_new);
2433                goto bucket_err;
2434        }
2435
2436        rcu_assign_pointer(e->hash_link, l_new);
2437        e->map = map;
2438        spin_lock_bh(&psock->maps_lock);
2439        list_add_tail(&e->list, &psock->maps);
2440        spin_unlock_bh(&psock->maps_lock);
2441
2442        /* add the new element to the head of the list, so that a
2443         * concurrent search will find it before the old element
2444         */
2445        hlist_add_head_rcu(&l_new->hash_node, head);
2446        if (l_old) {
2447                psock = smap_psock_sk(l_old->sk);
2448
2449                hlist_del_rcu(&l_old->hash_node);
2450                smap_list_hash_remove(psock, l_old);
2451                smap_release_sock(psock, l_old->sk);
2452                free_htab_elem(htab, l_old);
2453        }
2454        raw_spin_unlock_bh(&b->lock);
2455        return 0;
2456bucket_err:
2457        smap_release_sock(psock, sock);
2458        raw_spin_unlock_bh(&b->lock);
2459err:
2460        kfree(e);
2461        return err;
2462}
2463
2464static int sock_hash_update_elem(struct bpf_map *map,
2465                                void *key, void *value, u64 flags)
2466{
2467        struct bpf_sock_ops_kern skops;
2468        u32 fd = *(u32 *)value;
2469        struct socket *socket;
2470        int err;
2471
2472        socket = sockfd_lookup(fd, &err);
2473        if (!socket)
2474                return err;
2475
2476        skops.sk = socket->sk;
2477        if (!skops.sk) {
2478                fput(socket->file);
2479                return -EINVAL;
2480        }
2481
2482        /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2483         * state.
2484         */
2485        if (skops.sk->sk_type != SOCK_STREAM ||
2486            skops.sk->sk_protocol != IPPROTO_TCP ||
2487            skops.sk->sk_state != TCP_ESTABLISHED) {
2488                fput(socket->file);
2489                return -EOPNOTSUPP;
2490        }
2491
2492        lock_sock(skops.sk);
2493        preempt_disable();
2494        rcu_read_lock();
2495        err = sock_hash_ctx_update_elem(&skops, map, key, flags);
2496        rcu_read_unlock();
2497        preempt_enable();
2498        release_sock(skops.sk);
2499        fput(socket->file);
2500        return err;
2501}
2502
2503static int sock_hash_delete_elem(struct bpf_map *map, void *key)
2504{
2505        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2506        struct hlist_head *head;
2507        struct bucket *b;
2508        struct htab_elem *l;
2509        u32 hash, key_size;
2510        int ret = -ENOENT;
2511
2512        key_size = map->key_size;
2513        hash = htab_map_hash(key, key_size);
2514        b = __select_bucket(htab, hash);
2515        head = &b->head;
2516
2517        raw_spin_lock_bh(&b->lock);
2518        l = lookup_elem_raw(head, hash, key, key_size);
2519        if (l) {
2520                struct sock *sock = l->sk;
2521                struct smap_psock *psock;
2522
2523                hlist_del_rcu(&l->hash_node);
2524                psock = smap_psock_sk(sock);
2525                /* This check handles a racing sock event that can get the
2526                 * sk_callback_lock before this case but after the xchg happens,
2527                 * causing the refcnt to hit zero and the sock user data (psock)
2528                 * to be NULL and queued for garbage collection.
2529                 */
2530                if (likely(psock)) {
2531                        smap_list_hash_remove(psock, l);
2532                        smap_release_sock(psock, sock);
2533                }
2534                free_htab_elem(htab, l);
2535                ret = 0;
2536        }
2537        raw_spin_unlock_bh(&b->lock);
2538        return ret;
2539}
2540
2541struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
2542{
2543        struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
2544        struct hlist_head *head;
2545        struct htab_elem *l;
2546        u32 key_size, hash;
2547        struct bucket *b;
2548        struct sock *sk;
2549
2550        key_size = map->key_size;
2551        hash = htab_map_hash(key, key_size);
2552        b = __select_bucket(htab, hash);
2553        head = &b->head;
2554
2555        l = lookup_elem_raw(head, hash, key, key_size);
2556        sk = l ? l->sk : NULL;
2557        return sk;
2558}
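
/* Example: this lookup backs the hash flavoured redirect helpers, which take
 * a pointer to the key instead of an index. A sketch of a verdict program,
 * assuming a sock_hash map and the sock_key layout are defined elsewhere:
 *
 *	SEC("sk_skb")
 *	int hash_verdict(struct __sk_buff *skb)
 *	{
 *		struct sock_key key = {};
 *
 *		... fill key from skb/sk fields ...
 *		return bpf_sk_redirect_hash(skb, &sock_hash, &key, 0);
 *	}
 */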
2559
2560const struct bpf_map_ops sock_map_ops = {
2561        .map_alloc = sock_map_alloc,
2562        .map_free = sock_map_free,
2563        .map_lookup_elem = sock_map_lookup,
2564        .map_get_next_key = sock_map_get_next_key,
2565        .map_update_elem = sock_map_update_elem,
2566        .map_delete_elem = sock_map_delete_elem,
2567        .map_release_uref = sock_map_release,
2568        .map_check_btf = map_check_no_btf,
2569};
2570
2571const struct bpf_map_ops sock_hash_ops = {
2572        .map_alloc = sock_hash_alloc,
2573        .map_free = sock_hash_free,
2574        .map_lookup_elem = sock_map_lookup,
2575        .map_get_next_key = sock_hash_get_next_key,
2576        .map_update_elem = sock_hash_update_elem,
2577        .map_delete_elem = sock_hash_delete_elem,
2578        .map_release_uref = sock_map_release,
2579        .map_check_btf = map_check_no_btf,
2580};
2581
2582static bool bpf_is_valid_sock_op(struct bpf_sock_ops_kern *ops)
2583{
2584        return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
2585               ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
2586}
2587BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
2588           struct bpf_map *, map, void *, key, u64, flags)
2589{
2590        WARN_ON_ONCE(!rcu_read_lock_held());
2591
2592        /* ULPs are currently supported only for TCP sockets in ESTABLISHED
2593         * state. This checks that the sock ops triggering the update is
2594         * one indicating we are (or will be soon) in an ESTABLISHED state.
2595         */
2596        if (!bpf_is_valid_sock_op(bpf_sock))
2597                return -EOPNOTSUPP;
2598        return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
2599}
2600
2601const struct bpf_func_proto bpf_sock_map_update_proto = {
2602        .func           = bpf_sock_map_update,
2603        .gpl_only       = false,
2604        .pkt_access     = true,
2605        .ret_type       = RET_INTEGER,
2606        .arg1_type      = ARG_PTR_TO_CTX,
2607        .arg2_type      = ARG_CONST_MAP_PTR,
2608        .arg3_type      = ARG_PTR_TO_MAP_KEY,
2609        .arg4_type      = ARG_ANYTHING,
2610};
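
/* Example: a sketch of a BPF_PROG_TYPE_SOCK_OPS program calling this helper
 * on the established callbacks checked above; the sock_map reference and the
 * fixed key are assumptions of the sketch:
 *
 *	SEC("sockops")
 *	int add_established(struct bpf_sock_ops *skops)
 *	{
 *		__u32 key = 0;
 *
 *		if (skops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
 *		    skops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB)
 *			bpf_sock_map_update(skops, &sock_map, &key, BPF_NOEXIST);
 *		return 0;
 *	}
 */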
2611
2612BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, bpf_sock,
2613           struct bpf_map *, map, void *, key, u64, flags)
2614{
2615        WARN_ON_ONCE(!rcu_read_lock_held());
2616
2617        if (!bpf_is_valid_sock_op(bpf_sock))
2618                return -EOPNOTSUPP;
2619        return sock_hash_ctx_update_elem(bpf_sock, map, key, flags);
2620}
2621
2622const struct bpf_func_proto bpf_sock_hash_update_proto = {
2623        .func           = bpf_sock_hash_update,
2624        .gpl_only       = false,
2625        .pkt_access     = true,
2626        .ret_type       = RET_INTEGER,
2627        .arg1_type      = ARG_PTR_TO_CTX,
2628        .arg2_type      = ARG_CONST_MAP_PTR,
2629        .arg3_type      = ARG_PTR_TO_MAP_KEY,
2630        .arg4_type      = ARG_ANYTHING,
2631};
2632