linux/net/ipv4/tcp_bpf.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
   3
   4#include <linux/skmsg.h>
   5#include <linux/filter.h>
   6#include <linux/bpf.h>
   7#include <linux/init.h>
   8#include <linux/wait.h>
   9
  10#include <net/inet_common.h>
  11#include <net/tls.h>
  12
  13static bool tcp_bpf_stream_read(const struct sock *sk)
  14{
  15        struct sk_psock *psock;
  16        bool empty = true;
  17
  18        rcu_read_lock();
  19        psock = sk_psock(sk);
  20        if (likely(psock))
  21                empty = list_empty(&psock->ingress_msg);
  22        rcu_read_unlock();
  23        return !empty;
  24}
  25
  26static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
  27                             int flags, long timeo, int *err)
  28{
  29        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  30        int ret = 0;
  31
  32        if (!timeo)
  33                return ret;
  34
  35        add_wait_queue(sk_sleep(sk), &wait);
  36        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  37        ret = sk_wait_event(sk, &timeo,
  38                            !list_empty(&psock->ingress_msg) ||
  39                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
  40        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  41        remove_wait_queue(sk_sleep(sk), &wait);
  42        return ret;
  43}
  44
  45int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
  46                      struct msghdr *msg, int len, int flags)
  47{
  48        struct iov_iter *iter = &msg->msg_iter;
  49        int peek = flags & MSG_PEEK;
  50        int i, ret, copied = 0;
  51        struct sk_msg *msg_rx;
  52
  53        msg_rx = list_first_entry_or_null(&psock->ingress_msg,
  54                                          struct sk_msg, list);
  55
  56        while (copied != len) {
  57                struct scatterlist *sge;
  58
  59                if (unlikely(!msg_rx))
  60                        break;
  61
  62                i = msg_rx->sg.start;
  63                do {
  64                        struct page *page;
  65                        int copy;
  66
  67                        sge = sk_msg_elem(msg_rx, i);
  68                        copy = sge->length;
  69                        page = sg_page(sge);
  70                        if (copied + copy > len)
  71                                copy = len - copied;
  72                        ret = copy_page_to_iter(page, sge->offset, copy, iter);
  73                        if (ret != copy) {
  74                                msg_rx->sg.start = i;
  75                                return -EFAULT;
  76                        }
  77
  78                        copied += copy;
  79                        if (likely(!peek)) {
  80                                sge->offset += copy;
  81                                sge->length -= copy;
  82                                sk_mem_uncharge(sk, copy);
  83                                msg_rx->sg.size -= copy;
  84
  85                                if (!sge->length) {
  86                                        sk_msg_iter_var_next(i);
  87                                        if (!msg_rx->skb)
  88                                                put_page(page);
  89                                }
  90                        } else {
  91                                sk_msg_iter_var_next(i);
  92                        }
  93
  94                        if (copied == len)
  95                                break;
  96                } while (i != msg_rx->sg.end);
  97
  98                if (unlikely(peek)) {
  99                        msg_rx = list_next_entry(msg_rx, list);
 100                        continue;
 101                }
 102
 103                msg_rx->sg.start = i;
 104                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
 105                        list_del(&msg_rx->list);
 106                        if (msg_rx->skb)
 107                                consume_skb(msg_rx->skb);
 108                        kfree(msg_rx);
 109                }
 110                msg_rx = list_first_entry_or_null(&psock->ingress_msg,
 111                                                  struct sk_msg, list);
 112        }
 113
 114        return copied;
 115}
 116EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
 117
 118int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 119                    int nonblock, int flags, int *addr_len)
 120{
 121        struct sk_psock *psock;
 122        int copied, ret;
 123
 124        if (unlikely(flags & MSG_ERRQUEUE))
 125                return inet_recv_error(sk, msg, len, addr_len);
 126        if (!skb_queue_empty(&sk->sk_receive_queue))
 127                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 128
 129        psock = sk_psock_get(sk);
 130        if (unlikely(!psock))
 131                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 132        lock_sock(sk);
 133msg_bytes_ready:
 134        copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
 135        if (!copied) {
 136                int data, err = 0;
 137                long timeo;
 138
 139                timeo = sock_rcvtimeo(sk, nonblock);
 140                data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
 141                if (data) {
 142                        if (skb_queue_empty(&sk->sk_receive_queue))
 143                                goto msg_bytes_ready;
 144                        release_sock(sk);
 145                        sk_psock_put(sk, psock);
 146                        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 147                }
 148                if (err) {
 149                        ret = err;
 150                        goto out;
 151                }
 152                copied = -EAGAIN;
 153        }
 154        ret = copied;
 155out:
 156        release_sock(sk);
 157        sk_psock_put(sk, psock);
 158        return ret;
 159}
 160
 161static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
 162                           struct sk_msg *msg, u32 apply_bytes, int flags)
 163{
 164        bool apply = apply_bytes;
 165        struct scatterlist *sge;
 166        u32 size, copied = 0;
 167        struct sk_msg *tmp;
 168        int i, ret = 0;
 169
 170        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
 171        if (unlikely(!tmp))
 172                return -ENOMEM;
 173
 174        lock_sock(sk);
 175        tmp->sg.start = msg->sg.start;
 176        i = msg->sg.start;
 177        do {
 178                sge = sk_msg_elem(msg, i);
 179                size = (apply && apply_bytes < sge->length) ?
 180                        apply_bytes : sge->length;
 181                if (!sk_wmem_schedule(sk, size)) {
 182                        if (!copied)
 183                                ret = -ENOMEM;
 184                        break;
 185                }
 186
 187                sk_mem_charge(sk, size);
 188                sk_msg_xfer(tmp, msg, i, size);
 189                copied += size;
 190                if (sge->length)
 191                        get_page(sk_msg_page(tmp, i));
 192                sk_msg_iter_var_next(i);
 193                tmp->sg.end = i;
 194                if (apply) {
 195                        apply_bytes -= size;
 196                        if (!apply_bytes)
 197                                break;
 198                }
 199        } while (i != msg->sg.end);
 200
 201        if (!ret) {
 202                msg->sg.start = i;
 203                msg->sg.size -= apply_bytes;
 204                sk_psock_queue_msg(psock, tmp);
 205                sk_psock_data_ready(sk, psock);
 206        } else {
 207                sk_msg_free(sk, tmp);
 208                kfree(tmp);
 209        }
 210
 211        release_sock(sk);
 212        return ret;
 213}
 214
 215static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
 216                        int flags, bool uncharge)
 217{
 218        bool apply = apply_bytes;
 219        struct scatterlist *sge;
 220        struct page *page;
 221        int size, ret = 0;
 222        u32 off;
 223
 224        while (1) {
 225                bool has_tx_ulp;
 226
 227                sge = sk_msg_elem(msg, msg->sg.start);
 228                size = (apply && apply_bytes < sge->length) ?
 229                        apply_bytes : sge->length;
 230                off  = sge->offset;
 231                page = sg_page(sge);
 232
 233                tcp_rate_check_app_limited(sk);
 234retry:
 235                has_tx_ulp = tls_sw_has_ctx_tx(sk);
 236                if (has_tx_ulp) {
 237                        flags |= MSG_SENDPAGE_NOPOLICY;
 238                        ret = kernel_sendpage_locked(sk,
 239                                                     page, off, size, flags);
 240                } else {
 241                        ret = do_tcp_sendpages(sk, page, off, size, flags);
 242                }
 243
 244                if (ret <= 0)
 245                        return ret;
 246                if (apply)
 247                        apply_bytes -= ret;
 248                msg->sg.size -= ret;
 249                sge->offset += ret;
 250                sge->length -= ret;
 251                if (uncharge)
 252                        sk_mem_uncharge(sk, ret);
 253                if (ret != size) {
 254                        size -= ret;
 255                        off  += ret;
 256                        goto retry;
 257                }
 258                if (!sge->length) {
 259                        put_page(page);
 260                        sk_msg_iter_next(msg, start);
 261                        sg_init_table(sge, 1);
 262                        if (msg->sg.start == msg->sg.end)
 263                                break;
 264                }
 265                if (apply && !apply_bytes)
 266                        break;
 267        }
 268
 269        return 0;
 270}
 271
 272static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
 273                               u32 apply_bytes, int flags, bool uncharge)
 274{
 275        int ret;
 276
 277        lock_sock(sk);
 278        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
 279        release_sock(sk);
 280        return ret;
 281}
 282
 283int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
 284                          u32 bytes, int flags)
 285{
 286        bool ingress = sk_msg_to_ingress(msg);
 287        struct sk_psock *psock = sk_psock_get(sk);
 288        int ret;
 289
 290        if (unlikely(!psock)) {
 291                sk_msg_free(sk, msg);
 292                return 0;
 293        }
 294        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
 295                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
 296        sk_psock_put(sk, psock);
 297        return ret;
 298}
 299EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
 300
 301static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 302                                struct sk_msg *msg, int *copied, int flags)
 303{
 304        bool cork = false, enospc = msg->sg.start == msg->sg.end;
 305        struct sock *sk_redir;
 306        u32 tosend, delta = 0;
 307        int ret;
 308
 309more_data:
 310        if (psock->eval == __SK_NONE) {
 311                /* Track delta in msg size to add/subtract it on SK_DROP from
 312                 * returned to user copied size. This ensures user doesn't
 313                 * get a positive return code with msg_cut_data and SK_DROP
 314                 * verdict.
 315                 */
 316                delta = msg->sg.size;
 317                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 318                if (msg->sg.size < delta)
 319                        delta -= msg->sg.size;
 320                else
 321                        delta = 0;
 322        }
 323
 324        if (msg->cork_bytes &&
 325            msg->cork_bytes > msg->sg.size && !enospc) {
 326                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
 327                if (!psock->cork) {
 328                        psock->cork = kzalloc(sizeof(*psock->cork),
 329                                              GFP_ATOMIC | __GFP_NOWARN);
 330                        if (!psock->cork)
 331                                return -ENOMEM;
 332                }
 333                memcpy(psock->cork, msg, sizeof(*msg));
 334                return 0;
 335        }
 336
 337        tosend = msg->sg.size;
 338        if (psock->apply_bytes && psock->apply_bytes < tosend)
 339                tosend = psock->apply_bytes;
 340
 341        switch (psock->eval) {
 342        case __SK_PASS:
 343                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
 344                if (unlikely(ret)) {
 345                        *copied -= sk_msg_free(sk, msg);
 346                        break;
 347                }
 348                sk_msg_apply_bytes(psock, tosend);
 349                break;
 350        case __SK_REDIRECT:
 351                sk_redir = psock->sk_redir;
 352                sk_msg_apply_bytes(psock, tosend);
 353                if (psock->cork) {
 354                        cork = true;
 355                        psock->cork = NULL;
 356                }
 357                sk_msg_return(sk, msg, tosend);
 358                release_sock(sk);
 359                ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
 360                lock_sock(sk);
 361                if (unlikely(ret < 0)) {
 362                        int free = sk_msg_free_nocharge(sk, msg);
 363
 364                        if (!cork)
 365                                *copied -= free;
 366                }
 367                if (cork) {
 368                        sk_msg_free(sk, msg);
 369                        kfree(msg);
 370                        msg = NULL;
 371                        ret = 0;
 372                }
 373                break;
 374        case __SK_DROP:
 375        default:
 376                sk_msg_free_partial(sk, msg, tosend);
 377                sk_msg_apply_bytes(psock, tosend);
 378                *copied -= (tosend + delta);
 379                return -EACCES;
 380        }
 381
 382        if (likely(!ret)) {
 383                if (!psock->apply_bytes) {
 384                        psock->eval =  __SK_NONE;
 385                        if (psock->sk_redir) {
 386                                sock_put(psock->sk_redir);
 387                                psock->sk_redir = NULL;
 388                        }
 389                }
 390                if (msg &&
 391                    msg->sg.data[msg->sg.start].page_link &&
 392                    msg->sg.data[msg->sg.start].length)
 393                        goto more_data;
 394        }
 395        return ret;
 396}
 397
 398static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 399{
 400        struct sk_msg tmp, *msg_tx = NULL;
 401        int copied = 0, err = 0;
 402        struct sk_psock *psock;
 403        long timeo;
 404        int flags;
 405
 406        /* Don't let internal do_tcp_sendpages() flags through */
 407        flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
 408        flags |= MSG_NO_SHARED_FRAGS;
 409
 410        psock = sk_psock_get(sk);
 411        if (unlikely(!psock))
 412                return tcp_sendmsg(sk, msg, size);
 413
 414        lock_sock(sk);
 415        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 416        while (msg_data_left(msg)) {
 417                bool enospc = false;
 418                u32 copy, osize;
 419
 420                if (sk->sk_err) {
 421                        err = -sk->sk_err;
 422                        goto out_err;
 423                }
 424
 425                copy = msg_data_left(msg);
 426                if (!sk_stream_memory_free(sk))
 427                        goto wait_for_sndbuf;
 428                if (psock->cork) {
 429                        msg_tx = psock->cork;
 430                } else {
 431                        msg_tx = &tmp;
 432                        sk_msg_init(msg_tx);
 433                }
 434
 435                osize = msg_tx->sg.size;
 436                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
 437                if (err) {
 438                        if (err != -ENOSPC)
 439                                goto wait_for_memory;
 440                        enospc = true;
 441                        copy = msg_tx->sg.size - osize;
 442                }
 443
 444                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
 445                                               copy);
 446                if (err < 0) {
 447                        sk_msg_trim(sk, msg_tx, osize);
 448                        goto out_err;
 449                }
 450
 451                copied += copy;
 452                if (psock->cork_bytes) {
 453                        if (size > psock->cork_bytes)
 454                                psock->cork_bytes = 0;
 455                        else
 456                                psock->cork_bytes -= size;
 457                        if (psock->cork_bytes && !enospc)
 458                                goto out_err;
 459                        /* All cork bytes are accounted, rerun the prog. */
 460                        psock->eval = __SK_NONE;
 461                        psock->cork_bytes = 0;
 462                }
 463
 464                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
 465                if (unlikely(err < 0))
 466                        goto out_err;
 467                continue;
 468wait_for_sndbuf:
 469                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 470wait_for_memory:
 471                err = sk_stream_wait_memory(sk, &timeo);
 472                if (err) {
 473                        if (msg_tx && msg_tx != psock->cork)
 474                                sk_msg_free(sk, msg_tx);
 475                        goto out_err;
 476                }
 477        }
 478out_err:
 479        if (err < 0)
 480                err = sk_stream_error(sk, msg->msg_flags, err);
 481        release_sock(sk);
 482        sk_psock_put(sk, psock);
 483        return copied ? copied : err;
 484}
 485
 486static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
 487                            size_t size, int flags)
 488{
 489        struct sk_msg tmp, *msg = NULL;
 490        int err = 0, copied = 0;
 491        struct sk_psock *psock;
 492        bool enospc = false;
 493
 494        psock = sk_psock_get(sk);
 495        if (unlikely(!psock))
 496                return tcp_sendpage(sk, page, offset, size, flags);
 497
 498        lock_sock(sk);
 499        if (psock->cork) {
 500                msg = psock->cork;
 501        } else {
 502                msg = &tmp;
 503                sk_msg_init(msg);
 504        }
 505
 506        /* Catch case where ring is full and sendpage is stalled. */
 507        if (unlikely(sk_msg_full(msg)))
 508                goto out_err;
 509
 510        sk_msg_page_add(msg, page, size, offset);
 511        sk_mem_charge(sk, size);
 512        copied = size;
 513        if (sk_msg_full(msg))
 514                enospc = true;
 515        if (psock->cork_bytes) {
 516                if (size > psock->cork_bytes)
 517                        psock->cork_bytes = 0;
 518                else
 519                        psock->cork_bytes -= size;
 520                if (psock->cork_bytes && !enospc)
 521                        goto out_err;
 522                /* All cork bytes are accounted, rerun the prog. */
 523                psock->eval = __SK_NONE;
 524                psock->cork_bytes = 0;
 525        }
 526
 527        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
 528out_err:
 529        release_sock(sk);
 530        sk_psock_put(sk, psock);
 531        return copied ? copied : err;
 532}
 533
 /* Pop every link off @psock, unlinking @sk from it and freeing the link. */
 static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
 {
         struct sk_psock_link *link;

         for (link = sk_psock_link_pop(psock); link;
              link = sk_psock_link_pop(psock)) {
                 sk_psock_unlink(sk, link);
                 sk_psock_free_link(link);
         }
 }
 543
 544static void tcp_bpf_unhash(struct sock *sk)
 545{
 546        void (*saved_unhash)(struct sock *sk);
 547        struct sk_psock *psock;
 548
 549        rcu_read_lock();
 550        psock = sk_psock(sk);
 551        if (unlikely(!psock)) {
 552                rcu_read_unlock();
 553                if (sk->sk_prot->unhash)
 554                        sk->sk_prot->unhash(sk);
 555                return;
 556        }
 557
 558        saved_unhash = psock->saved_unhash;
 559        tcp_bpf_remove(sk, psock);
 560        rcu_read_unlock();
 561        saved_unhash(sk);
 562}
 563
 564static void tcp_bpf_close(struct sock *sk, long timeout)
 565{
 566        void (*saved_close)(struct sock *sk, long timeout);
 567        struct sk_psock *psock;
 568
 569        lock_sock(sk);
 570        rcu_read_lock();
 571        psock = sk_psock(sk);
 572        if (unlikely(!psock)) {
 573                rcu_read_unlock();
 574                release_sock(sk);
 575                return sk->sk_prot->close(sk, timeout);
 576        }
 577
 578        saved_close = psock->saved_close;
 579        tcp_bpf_remove(sk, psock);
 580        rcu_read_unlock();
 581        release_sock(sk);
 582        saved_close(sk, timeout);
 583}
 584
 585enum {
 586        TCP_BPF_IPV4,
 587        TCP_BPF_IPV6,
 588        TCP_BPF_NUM_PROTS,
 589};
 590
 591enum {
 592        TCP_BPF_BASE,
 593        TCP_BPF_TX,
 594        TCP_BPF_NUM_CFGS,
 595};
 596
 597static struct proto *tcpv6_prot_saved __read_mostly;
 598static DEFINE_SPINLOCK(tcpv6_prot_lock);
 599static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
 600
 601static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 602                                   struct proto *base)
 603{
 604        prot[TCP_BPF_BASE]                      = *base;
 605        prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
 606        prot[TCP_BPF_BASE].close                = tcp_bpf_close;
 607        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
 608        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;
 609
 610        prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
 611        prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
 612        prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
 613}
 614
 615static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
 616{
 617        if (sk->sk_family == AF_INET6 &&
 618            unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
 619                spin_lock_bh(&tcpv6_prot_lock);
 620                if (likely(ops != tcpv6_prot_saved)) {
 621                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
 622                        smp_store_release(&tcpv6_prot_saved, ops);
 623                }
 624                spin_unlock_bh(&tcpv6_prot_lock);
 625        }
 626}
 627
 628static int __init tcp_bpf_v4_build_proto(void)
 629{
 630        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
 631        return 0;
 632}
 633core_initcall(tcp_bpf_v4_build_proto);
 634
 635static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
 636{
 637        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 638        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 639
 640        sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
 641}
 642
 643static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
 644{
 645        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 646        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 647
 648        /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
 649         * or added requiring sk_prot hook updates. We keep original saved
 650         * hooks in this case.
 651         */
 652        sk->sk_prot = &tcp_bpf_prots[family][config];
 653}
 654
 655static int tcp_bpf_assert_proto_ops(struct proto *ops)
 656{
 657        /* In order to avoid retpoline, we make assumptions when we call
 658         * into ops if e.g. a psock is not present. Make sure they are
 659         * indeed valid assumptions.
 660         */
 661        return ops->recvmsg  == tcp_recvmsg &&
 662               ops->sendmsg  == tcp_sendmsg &&
 663               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 664}
 665
 /* Re-select the proto table for @sk based on its psock's current programs.
  *
  * sock_owned_by_me() is asserted on entry. The psock is looked up inside
  * an RCU read-side section and passed on without a NULL check.
  */
 void tcp_bpf_reinit(struct sock *sk)
 {
         sock_owned_by_me(sk);

         rcu_read_lock();
         tcp_bpf_reinit_sk_prot(sk, sk_psock(sk));
         rcu_read_unlock();
 }
 677
 678int tcp_bpf_init(struct sock *sk)
 679{
 680        struct proto *ops = READ_ONCE(sk->sk_prot);
 681        struct sk_psock *psock;
 682
 683        sock_owned_by_me(sk);
 684
 685        rcu_read_lock();
 686        psock = sk_psock(sk);
 687        if (unlikely(!psock || psock->sk_proto ||
 688                     tcp_bpf_assert_proto_ops(ops))) {
 689                rcu_read_unlock();
 690                return -EINVAL;
 691        }
 692        tcp_bpf_check_v6_needs_rebuild(sk, ops);
 693        tcp_bpf_update_sk_prot(sk, psock);
 694        rcu_read_unlock();
 695        return 0;
 696}
 697