/* linux/net/ipv4/tcp_bpf.c */
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/list.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

  13static bool tcp_bpf_stream_read(const struct sock *sk)
  14{
  15        struct sk_psock *psock;
  16        bool empty = true;
  17
  18        rcu_read_lock();
  19        psock = sk_psock(sk);
  20        if (likely(psock))
  21                empty = list_empty(&psock->ingress_msg);
  22        rcu_read_unlock();
  23        return !empty;
  24}
  25
  26static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
  27                             int flags, long timeo, int *err)
  28{
  29        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  30        int ret = 0;
  31
  32        if (!timeo)
  33                return ret;
  34
  35        add_wait_queue(sk_sleep(sk), &wait);
  36        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  37        ret = sk_wait_event(sk, &timeo,
  38                            !list_empty(&psock->ingress_msg) ||
  39                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
  40        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
  41        remove_wait_queue(sk_sleep(sk), &wait);
  42        return ret;
  43}
  44
  45int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
  46                      struct msghdr *msg, int len, int flags)
  47{
  48        struct iov_iter *iter = &msg->msg_iter;
  49        int peek = flags & MSG_PEEK;
  50        int i, ret, copied = 0;
  51        struct sk_msg *msg_rx;
  52
  53        msg_rx = list_first_entry_or_null(&psock->ingress_msg,
  54                                          struct sk_msg, list);
  55
  56        while (copied != len) {
  57                struct scatterlist *sge;
  58
  59                if (unlikely(!msg_rx))
  60                        break;
  61
  62                i = msg_rx->sg.start;
  63                do {
  64                        struct page *page;
  65                        int copy;
  66
  67                        sge = sk_msg_elem(msg_rx, i);
  68                        copy = sge->length;
  69                        page = sg_page(sge);
  70                        if (copied + copy > len)
  71                                copy = len - copied;
  72                        ret = copy_page_to_iter(page, sge->offset, copy, iter);
  73                        if (ret != copy) {
  74                                msg_rx->sg.start = i;
  75                                return -EFAULT;
  76                        }
  77
  78                        copied += copy;
  79                        if (likely(!peek)) {
  80                                sge->offset += copy;
  81                                sge->length -= copy;
  82                                sk_mem_uncharge(sk, copy);
  83                                msg_rx->sg.size -= copy;
  84
  85                                if (!sge->length) {
  86                                        sk_msg_iter_var_next(i);
  87                                        if (!msg_rx->skb)
  88                                                put_page(page);
  89                                }
  90                        } else {
  91                                sk_msg_iter_var_next(i);
  92                        }
  93
  94                        if (copied == len)
  95                                break;
  96                } while (i != msg_rx->sg.end);
  97
  98                if (unlikely(peek)) {
  99                        msg_rx = list_next_entry(msg_rx, list);
 100                        continue;
 101                }
 102
 103                msg_rx->sg.start = i;
 104                if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
 105                        list_del(&msg_rx->list);
 106                        if (msg_rx->skb)
 107                                consume_skb(msg_rx->skb);
 108                        kfree(msg_rx);
 109                }
 110                msg_rx = list_first_entry_or_null(&psock->ingress_msg,
 111                                                  struct sk_msg, list);
 112        }
 113
 114        return copied;
 115}
 116EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
 117
 118int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 119                    int nonblock, int flags, int *addr_len)
 120{
 121        struct sk_psock *psock;
 122        int copied, ret;
 123
 124        psock = sk_psock_get(sk);
 125        if (unlikely(!psock))
 126                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 127        if (unlikely(flags & MSG_ERRQUEUE))
 128                return inet_recv_error(sk, msg, len, addr_len);
 129        if (!skb_queue_empty(&sk->sk_receive_queue) &&
 130            sk_psock_queue_empty(psock))
 131                return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 132        lock_sock(sk);
 133msg_bytes_ready:
 134        copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
 135        if (!copied) {
 136                int data, err = 0;
 137                long timeo;
 138
 139                timeo = sock_rcvtimeo(sk, nonblock);
 140                data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
 141                if (data) {
 142                        if (!sk_psock_queue_empty(psock))
 143                                goto msg_bytes_ready;
 144                        release_sock(sk);
 145                        sk_psock_put(sk, psock);
 146                        return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
 147                }
 148                if (err) {
 149                        ret = err;
 150                        goto out;
 151                }
 152                copied = -EAGAIN;
 153        }
 154        ret = copied;
 155out:
 156        release_sock(sk);
 157        sk_psock_put(sk, psock);
 158        return ret;
 159}
 160
 161static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
 162                           struct sk_msg *msg, u32 apply_bytes, int flags)
 163{
 164        bool apply = apply_bytes;
 165        struct scatterlist *sge;
 166        u32 size, copied = 0;
 167        struct sk_msg *tmp;
 168        int i, ret = 0;
 169
 170        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
 171        if (unlikely(!tmp))
 172                return -ENOMEM;
 173
 174        lock_sock(sk);
 175        tmp->sg.start = msg->sg.start;
 176        i = msg->sg.start;
 177        do {
 178                sge = sk_msg_elem(msg, i);
 179                size = (apply && apply_bytes < sge->length) ?
 180                        apply_bytes : sge->length;
 181                if (!sk_wmem_schedule(sk, size)) {
 182                        if (!copied)
 183                                ret = -ENOMEM;
 184                        break;
 185                }
 186
 187                sk_mem_charge(sk, size);
 188                sk_msg_xfer(tmp, msg, i, size);
 189                copied += size;
 190                if (sge->length)
 191                        get_page(sk_msg_page(tmp, i));
 192                sk_msg_iter_var_next(i);
 193                tmp->sg.end = i;
 194                if (apply) {
 195                        apply_bytes -= size;
 196                        if (!apply_bytes)
 197                                break;
 198                }
 199        } while (i != msg->sg.end);
 200
 201        if (!ret) {
 202                msg->sg.start = i;
 203                msg->sg.size -= apply_bytes;
 204                sk_psock_queue_msg(psock, tmp);
 205                sk_psock_data_ready(sk, psock);
 206        } else {
 207                sk_msg_free(sk, tmp);
 208                kfree(tmp);
 209        }
 210
 211        release_sock(sk);
 212        return ret;
 213}
 214
 215static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
 216                        int flags, bool uncharge)
 217{
 218        bool apply = apply_bytes;
 219        struct scatterlist *sge;
 220        struct page *page;
 221        int size, ret = 0;
 222        u32 off;
 223
 224        while (1) {
 225                bool has_tx_ulp;
 226
 227                sge = sk_msg_elem(msg, msg->sg.start);
 228                size = (apply && apply_bytes < sge->length) ?
 229                        apply_bytes : sge->length;
 230                off  = sge->offset;
 231                page = sg_page(sge);
 232
 233                tcp_rate_check_app_limited(sk);
 234retry:
 235                has_tx_ulp = tls_sw_has_ctx_tx(sk);
 236                if (has_tx_ulp) {
 237                        flags |= MSG_SENDPAGE_NOPOLICY;
 238                        ret = kernel_sendpage_locked(sk,
 239                                                     page, off, size, flags);
 240                } else {
 241                        ret = do_tcp_sendpages(sk, page, off, size, flags);
 242                }
 243
 244                if (ret <= 0)
 245                        return ret;
 246                if (apply)
 247                        apply_bytes -= ret;
 248                msg->sg.size -= ret;
 249                sge->offset += ret;
 250                sge->length -= ret;
 251                if (uncharge)
 252                        sk_mem_uncharge(sk, ret);
 253                if (ret != size) {
 254                        size -= ret;
 255                        off  += ret;
 256                        goto retry;
 257                }
 258                if (!sge->length) {
 259                        put_page(page);
 260                        sk_msg_iter_next(msg, start);
 261                        sg_init_table(sge, 1);
 262                        if (msg->sg.start == msg->sg.end)
 263                                break;
 264                }
 265                if (apply && !apply_bytes)
 266                        break;
 267        }
 268
 269        return 0;
 270}
 271
 272static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
 273                               u32 apply_bytes, int flags, bool uncharge)
 274{
 275        int ret;
 276
 277        lock_sock(sk);
 278        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
 279        release_sock(sk);
 280        return ret;
 281}
 282
 283int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
 284                          u32 bytes, int flags)
 285{
 286        bool ingress = sk_msg_to_ingress(msg);
 287        struct sk_psock *psock = sk_psock_get(sk);
 288        int ret;
 289
 290        if (unlikely(!psock)) {
 291                sk_msg_free(sk, msg);
 292                return 0;
 293        }
 294        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
 295                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
 296        sk_psock_put(sk, psock);
 297        return ret;
 298}
 299EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
 300
 301static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 302                                struct sk_msg *msg, int *copied, int flags)
 303{
 304        bool cork = false, enospc = sk_msg_full(msg);
 305        struct sock *sk_redir;
 306        u32 tosend, delta = 0;
 307        int ret;
 308
 309more_data:
 310        if (psock->eval == __SK_NONE) {
 311                /* Track delta in msg size to add/subtract it on SK_DROP from
 312                 * returned to user copied size. This ensures user doesn't
 313                 * get a positive return code with msg_cut_data and SK_DROP
 314                 * verdict.
 315                 */
 316                delta = msg->sg.size;
 317                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
 318                delta -= msg->sg.size;
 319        }
 320
 321        if (msg->cork_bytes &&
 322            msg->cork_bytes > msg->sg.size && !enospc) {
 323                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
 324                if (!psock->cork) {
 325                        psock->cork = kzalloc(sizeof(*psock->cork),
 326                                              GFP_ATOMIC | __GFP_NOWARN);
 327                        if (!psock->cork)
 328                                return -ENOMEM;
 329                }
 330                memcpy(psock->cork, msg, sizeof(*msg));
 331                return 0;
 332        }
 333
 334        tosend = msg->sg.size;
 335        if (psock->apply_bytes && psock->apply_bytes < tosend)
 336                tosend = psock->apply_bytes;
 337
 338        switch (psock->eval) {
 339        case __SK_PASS:
 340                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
 341                if (unlikely(ret)) {
 342                        *copied -= sk_msg_free(sk, msg);
 343                        break;
 344                }
 345                sk_msg_apply_bytes(psock, tosend);
 346                break;
 347        case __SK_REDIRECT:
 348                sk_redir = psock->sk_redir;
 349                sk_msg_apply_bytes(psock, tosend);
 350                if (psock->cork) {
 351                        cork = true;
 352                        psock->cork = NULL;
 353                }
 354                sk_msg_return(sk, msg, tosend);
 355                release_sock(sk);
 356                ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
 357                lock_sock(sk);
 358                if (unlikely(ret < 0)) {
 359                        int free = sk_msg_free_nocharge(sk, msg);
 360
 361                        if (!cork)
 362                                *copied -= free;
 363                }
 364                if (cork) {
 365                        sk_msg_free(sk, msg);
 366                        kfree(msg);
 367                        msg = NULL;
 368                        ret = 0;
 369                }
 370                break;
 371        case __SK_DROP:
 372        default:
 373                sk_msg_free_partial(sk, msg, tosend);
 374                sk_msg_apply_bytes(psock, tosend);
 375                *copied -= (tosend + delta);
 376                return -EACCES;
 377        }
 378
 379        if (likely(!ret)) {
 380                if (!psock->apply_bytes) {
 381                        psock->eval =  __SK_NONE;
 382                        if (psock->sk_redir) {
 383                                sock_put(psock->sk_redir);
 384                                psock->sk_redir = NULL;
 385                        }
 386                }
 387                if (msg &&
 388                    msg->sg.data[msg->sg.start].page_link &&
 389                    msg->sg.data[msg->sg.start].length)
 390                        goto more_data;
 391        }
 392        return ret;
 393}
 394
 395static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 396{
 397        struct sk_msg tmp, *msg_tx = NULL;
 398        int copied = 0, err = 0;
 399        struct sk_psock *psock;
 400        long timeo;
 401        int flags;
 402
 403        /* Don't let internal do_tcp_sendpages() flags through */
 404        flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
 405        flags |= MSG_NO_SHARED_FRAGS;
 406
 407        psock = sk_psock_get(sk);
 408        if (unlikely(!psock))
 409                return tcp_sendmsg(sk, msg, size);
 410
 411        lock_sock(sk);
 412        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 413        while (msg_data_left(msg)) {
 414                bool enospc = false;
 415                u32 copy, osize;
 416
 417                if (sk->sk_err) {
 418                        err = -sk->sk_err;
 419                        goto out_err;
 420                }
 421
 422                copy = msg_data_left(msg);
 423                if (!sk_stream_memory_free(sk))
 424                        goto wait_for_sndbuf;
 425                if (psock->cork) {
 426                        msg_tx = psock->cork;
 427                } else {
 428                        msg_tx = &tmp;
 429                        sk_msg_init(msg_tx);
 430                }
 431
 432                osize = msg_tx->sg.size;
 433                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
 434                if (err) {
 435                        if (err != -ENOSPC)
 436                                goto wait_for_memory;
 437                        enospc = true;
 438                        copy = msg_tx->sg.size - osize;
 439                }
 440
 441                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
 442                                               copy);
 443                if (err < 0) {
 444                        sk_msg_trim(sk, msg_tx, osize);
 445                        goto out_err;
 446                }
 447
 448                copied += copy;
 449                if (psock->cork_bytes) {
 450                        if (size > psock->cork_bytes)
 451                                psock->cork_bytes = 0;
 452                        else
 453                                psock->cork_bytes -= size;
 454                        if (psock->cork_bytes && !enospc)
 455                                goto out_err;
 456                        /* All cork bytes are accounted, rerun the prog. */
 457                        psock->eval = __SK_NONE;
 458                        psock->cork_bytes = 0;
 459                }
 460
 461                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
 462                if (unlikely(err < 0))
 463                        goto out_err;
 464                continue;
 465wait_for_sndbuf:
 466                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 467wait_for_memory:
 468                err = sk_stream_wait_memory(sk, &timeo);
 469                if (err) {
 470                        if (msg_tx && msg_tx != psock->cork)
 471                                sk_msg_free(sk, msg_tx);
 472                        goto out_err;
 473                }
 474        }
 475out_err:
 476        if (err < 0)
 477                err = sk_stream_error(sk, msg->msg_flags, err);
 478        release_sock(sk);
 479        sk_psock_put(sk, psock);
 480        return copied ? copied : err;
 481}
 482
 483static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
 484                            size_t size, int flags)
 485{
 486        struct sk_msg tmp, *msg = NULL;
 487        int err = 0, copied = 0;
 488        struct sk_psock *psock;
 489        bool enospc = false;
 490
 491        psock = sk_psock_get(sk);
 492        if (unlikely(!psock))
 493                return tcp_sendpage(sk, page, offset, size, flags);
 494
 495        lock_sock(sk);
 496        if (psock->cork) {
 497                msg = psock->cork;
 498        } else {
 499                msg = &tmp;
 500                sk_msg_init(msg);
 501        }
 502
 503        /* Catch case where ring is full and sendpage is stalled. */
 504        if (unlikely(sk_msg_full(msg)))
 505                goto out_err;
 506
 507        sk_msg_page_add(msg, page, size, offset);
 508        sk_mem_charge(sk, size);
 509        copied = size;
 510        if (sk_msg_full(msg))
 511                enospc = true;
 512        if (psock->cork_bytes) {
 513                if (size > psock->cork_bytes)
 514                        psock->cork_bytes = 0;
 515                else
 516                        psock->cork_bytes -= size;
 517                if (psock->cork_bytes && !enospc)
 518                        goto out_err;
 519                /* All cork bytes are accounted, rerun the prog. */
 520                psock->eval = __SK_NONE;
 521                psock->cork_bytes = 0;
 522        }
 523
 524        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
 525out_err:
 526        release_sock(sk);
 527        sk_psock_put(sk, psock);
 528        return copied ? copied : err;
 529}
 530
 531static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
 532{
 533        struct sk_psock_link *link;
 534
 535        while ((link = sk_psock_link_pop(psock))) {
 536                sk_psock_unlink(sk, link);
 537                sk_psock_free_link(link);
 538        }
 539}
 540
 541static void tcp_bpf_unhash(struct sock *sk)
 542{
 543        void (*saved_unhash)(struct sock *sk);
 544        struct sk_psock *psock;
 545
 546        rcu_read_lock();
 547        psock = sk_psock(sk);
 548        if (unlikely(!psock)) {
 549                rcu_read_unlock();
 550                if (sk->sk_prot->unhash)
 551                        sk->sk_prot->unhash(sk);
 552                return;
 553        }
 554
 555        saved_unhash = psock->saved_unhash;
 556        tcp_bpf_remove(sk, psock);
 557        rcu_read_unlock();
 558        saved_unhash(sk);
 559}
 560
 561static void tcp_bpf_close(struct sock *sk, long timeout)
 562{
 563        void (*saved_close)(struct sock *sk, long timeout);
 564        struct sk_psock *psock;
 565
 566        lock_sock(sk);
 567        rcu_read_lock();
 568        psock = sk_psock(sk);
 569        if (unlikely(!psock)) {
 570                rcu_read_unlock();
 571                release_sock(sk);
 572                return sk->sk_prot->close(sk, timeout);
 573        }
 574
 575        saved_close = psock->saved_close;
 576        tcp_bpf_remove(sk, psock);
 577        rcu_read_unlock();
 578        release_sock(sk);
 579        saved_close(sk, timeout);
 580}
 581
 582enum {
 583        TCP_BPF_IPV4,
 584        TCP_BPF_IPV6,
 585        TCP_BPF_NUM_PROTS,
 586};
 587
 588enum {
 589        TCP_BPF_BASE,
 590        TCP_BPF_TX,
 591        TCP_BPF_NUM_CFGS,
 592};
 593
 594static struct proto *tcpv6_prot_saved __read_mostly;
 595static DEFINE_SPINLOCK(tcpv6_prot_lock);
 596static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
 597
 598static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
 599                                   struct proto *base)
 600{
 601        prot[TCP_BPF_BASE]                      = *base;
 602        prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
 603        prot[TCP_BPF_BASE].close                = tcp_bpf_close;
 604        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
 605        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;
 606
 607        prot[TCP_BPF_TX]                        = prot[TCP_BPF_BASE];
 608        prot[TCP_BPF_TX].sendmsg                = tcp_bpf_sendmsg;
 609        prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
 610}
 611
 612static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
 613{
 614        if (sk->sk_family == AF_INET6 &&
 615            unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
 616                spin_lock_bh(&tcpv6_prot_lock);
 617                if (likely(ops != tcpv6_prot_saved)) {
 618                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
 619                        smp_store_release(&tcpv6_prot_saved, ops);
 620                }
 621                spin_unlock_bh(&tcpv6_prot_lock);
 622        }
 623}
 624
 625static int __init tcp_bpf_v4_build_proto(void)
 626{
 627        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
 628        return 0;
 629}
 630core_initcall(tcp_bpf_v4_build_proto);
 631
 632static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
 633{
 634        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 635        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 636
 637        sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
 638}
 639
 640static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
 641{
 642        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 643        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 644
 645        /* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
 646         * or added requiring sk_prot hook updates. We keep original saved
 647         * hooks in this case.
 648         */
 649        sk->sk_prot = &tcp_bpf_prots[family][config];
 650}
 651
 652static int tcp_bpf_assert_proto_ops(struct proto *ops)
 653{
 654        /* In order to avoid retpoline, we make assumptions when we call
 655         * into ops if e.g. a psock is not present. Make sure they are
 656         * indeed valid assumptions.
 657         */
 658        return ops->recvmsg  == tcp_recvmsg &&
 659               ops->sendmsg  == tcp_sendmsg &&
 660               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 661}
 662
 663void tcp_bpf_reinit(struct sock *sk)
 664{
 665        struct sk_psock *psock;
 666
 667        sock_owned_by_me(sk);
 668
 669        rcu_read_lock();
 670        psock = sk_psock(sk);
 671        tcp_bpf_reinit_sk_prot(sk, psock);
 672        rcu_read_unlock();
 673}
 674
 675int tcp_bpf_init(struct sock *sk)
 676{
 677        struct proto *ops = READ_ONCE(sk->sk_prot);
 678        struct sk_psock *psock;
 679
 680        sock_owned_by_me(sk);
 681
 682        rcu_read_lock();
 683        psock = sk_psock(sk);
 684        if (unlikely(!psock || psock->sk_proto ||
 685                     tcp_bpf_assert_proto_ops(ops))) {
 686                rcu_read_unlock();
 687                return -EINVAL;
 688        }
 689        tcp_bpf_check_v6_needs_rebuild(sk, ops);
 690        tcp_bpf_update_sk_prot(sk, psock);
 691        rcu_read_unlock();
 692        return 0;
 693}
 694