linux/net/core/skmsg.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

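/* Check whether the last used sg element may be coalesced with new data.
 * The sg ring can wrap, so both the linear (end > start) and the wrapped
 * (end < start) layouts are handled; elem_first_coalesce marks the first
 * element the caller allows new data to be merged into.
 */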
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
        if (msg->sg.end > msg->sg.start &&
            elem_first_coalesce < msg->sg.end)
                return true;

        if (msg->sg.end < msg->sg.start &&
            (elem_first_coalesce > msg->sg.start ||
             elem_first_coalesce < msg->sg.end))
                return true;

        return false;
}

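/* Grow @msg to @len total bytes by allocating from the socket's page frag.
 * New space is coalesced into the last sg element when permitted, otherwise
 * a fresh element is added. Returns 0 on success, -ENOMEM if page or wmem
 * accounting fails, or -ENOSPC if the sg ring is full.
 *
 * Rough sender-side sketch (illustrative only, error handling omitted):
 *
 *	sk_msg_init(&msg);
 *	sk_msg_alloc(sk, &msg, copy, 0);
 *	sk_msg_memcopy_from_iter(sk, &iter, &msg, copy);
 *	...
 *	sk_msg_free(sk, &msg);
 */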
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
                 int elem_first_coalesce)
{
        struct page_frag *pfrag = sk_page_frag(sk);
        int ret = 0;

        len -= msg->sg.size;
        while (len > 0) {
                struct scatterlist *sge;
                u32 orig_offset;
                int use, i;

                if (!sk_page_frag_refill(sk, pfrag))
                        return -ENOMEM;

                orig_offset = pfrag->offset;
                use = min_t(int, len, pfrag->size - orig_offset);
                if (!sk_wmem_schedule(sk, use))
                        return -ENOMEM;

                i = msg->sg.end;
                sk_msg_iter_var_prev(i);
                sge = &msg->sg.data[i];

                if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
                    sg_page(sge) == pfrag->page &&
                    sge->offset + sge->length == orig_offset) {
                        sge->length += use;
                } else {
                        if (sk_msg_full(msg)) {
                                ret = -ENOSPC;
                                break;
                        }

                        sge = &msg->sg.data[msg->sg.end];
                        sg_unmark_end(sge);
                        sg_set_page(sge, pfrag->page, use, orig_offset);
                        get_page(pfrag->page);
                        sk_msg_iter_next(msg, end);
                }

                sk_mem_charge(sk, use);
                msg->sg.size += use;
                pfrag->offset += use;
                len -= use;
        }

        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

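/* Copy a byte range from @src into @dst without copying the data itself:
 * the underlying pages are shared, and the destination's sg entries and the
 * socket's memory charge are updated to cover the cloned range. Adjacent
 * ranges on the same page are merged into the destination's last element.
 * Returns -ENOSPC if @src runs out of data or @dst runs out of sg slots.
 */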
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
                 u32 off, u32 len)
{
        int i = src->sg.start;
        struct scatterlist *sge = sk_msg_elem(src, i);
        struct scatterlist *sgd = NULL;
        u32 sge_len, sge_off;

        while (off) {
                if (sge->length > off)
                        break;
                off -= sge->length;
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && off)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        while (len) {
                sge_len = sge->length - off;
                if (sge_len > len)
                        sge_len = len;

                if (dst->sg.end)
                        sgd = sk_msg_elem(dst, dst->sg.end - 1);

                if (sgd &&
                    (sg_page(sge) == sg_page(sgd)) &&
                    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
                        sgd->length += sge_len;
                        dst->sg.size += sge_len;
                } else if (!sk_msg_full(dst)) {
                        sge_off = sge->offset + off;
                        sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
                } else {
                        return -ENOSPC;
                }

                off = 0;
                len -= sge_len;
                sk_mem_charge(sk, sge_len);
                sk_msg_iter_var_next(i);
                if (i == src->sg.end && len)
                        return -ENOSPC;
                sge = sk_msg_elem(src, i);
        }

        return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

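/* Uncharge @bytes of memory accounting from the front of @msg, zeroing out
 * the sg entries that are fully consumed and advancing sg.start past them.
 */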
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = sk_msg_elem(msg, i);

                if (bytes < sge->length) {
                        sge->length -= bytes;
                        sge->offset += bytes;
                        sk_mem_uncharge(sk, bytes);
                        break;
                }

                sk_mem_uncharge(sk, sge->length);
                bytes -= sge->length;
                sge->length = 0;
                sge->offset = 0;
                sk_msg_iter_var_next(i);
        } while (bytes && i != msg->sg.end);
        msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

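/* Uncharge up to @bytes of memory accounting from @msg without modifying
 * the sg entries themselves; walks every element from sg.start to sg.end.
 */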
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
        int i = msg->sg.start;

        do {
                struct scatterlist *sge = &msg->sg.data[i];
                int uncharge = (bytes < sge->length) ? bytes : sge->length;

                sk_mem_uncharge(sk, uncharge);
                bytes -= uncharge;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

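/* Release a single sg element: uncharge its memory (optionally), drop the
 * page reference unless the data is backed by an skb, and clear the entry.
 * Returns the number of bytes the element held.
 */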
static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
                            bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        u32 len = sge->length;

        if (charge)
                sk_mem_uncharge(sk, len);
        if (!msg->skb)
                put_page(sg_page(sge));
        memset(sge, 0, sizeof(*sge));
        return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
                         bool charge)
{
        struct scatterlist *sge = sk_msg_elem(msg, i);
        int freed = 0;

        while (msg->sg.size) {
                msg->sg.size -= sge->length;
                freed += sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, msg->sg.size);
                sge = sk_msg_elem(msg, i);
        }
        if (msg->skb)
                consume_skb(msg->skb);
        sk_msg_init(msg);
        return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
        return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes, bool charge)
{
        struct scatterlist *sge;
        u32 i = msg->sg.start;

        while (bytes) {
                sge = sk_msg_elem(msg, i);
                if (!sge->length)
                        break;
                if (bytes < sge->length) {
                        if (charge)
                                sk_mem_uncharge(sk, bytes);
                        sge->length -= bytes;
                        sge->offset += bytes;
                        msg->sg.size -= bytes;
                        break;
                }

                msg->sg.size -= sge->length;
                bytes -= sge->length;
                sk_msg_free_elem(sk, msg, i, charge);
                sk_msg_iter_var_next(i);
                sk_msg_check_to_free(msg, i, bytes);
        }
        msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
                                  u32 bytes)
{
        __sk_msg_free_partial(sk, msg, bytes, false);
}

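/* Trim @msg down to @len bytes by releasing sg elements from the tail.
 * A partially trimmed element keeps its data but is shortened, and the
 * freed bytes are uncharged from the socket.
 */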
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
        int trim = msg->sg.size - len;
        u32 i = msg->sg.end;

        if (trim <= 0) {
                WARN_ON(trim < 0);
                return;
        }

        sk_msg_iter_var_prev(i);
        msg->sg.size = len;
        while (msg->sg.data[i].length &&
               trim >= msg->sg.data[i].length) {
                trim -= msg->sg.data[i].length;
                sk_msg_free_elem(sk, msg, i, true);
                sk_msg_iter_var_prev(i);
                if (!trim)
                        goto out;
        }

        msg->sg.data[i].length -= trim;
        sk_mem_uncharge(sk, trim);
out:
        /* If we trim data before the curr pointer, update copybreak and curr
         * so that any future copy operations start at the new copy location.
         * However, trimmed data that has not yet been used in a copy op
         * does not require an update.
         */
        if (msg->sg.curr >= i) {
                msg->sg.curr = i;
                msg->sg.copybreak = msg->sg.data[i].length;
        }
        sk_msg_iter_var_next(i);
        msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

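/* Pin user pages from @from and attach them to @msg as zerocopy sg entries,
 * charging the socket for each added chunk. On failure the iterator is
 * reverted and the caller is left to trim @msg if it also needs cleanup.
 */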
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
                              struct sk_msg *msg, u32 bytes)
{
        int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
        const int to_max_pages = MAX_MSG_FRAGS;
        struct page *pages[MAX_MSG_FRAGS];
        ssize_t orig, copied, use, offset;

        orig = msg->sg.size;
        while (bytes > 0) {
                i = 0;
                maxpages = to_max_pages - num_elems;
                if (maxpages == 0) {
                        ret = -EFAULT;
                        goto out;
                }

                copied = iov_iter_get_pages(from, pages, bytes, maxpages,
                                            &offset);
                if (copied <= 0) {
                        ret = -EFAULT;
                        goto out;
                }

                iov_iter_advance(from, copied);
                bytes -= copied;
                msg->sg.size += copied;

                while (copied) {
                        use = min_t(int, copied, PAGE_SIZE - offset);
                        sg_set_page(&msg->sg.data[msg->sg.end],
                                    pages[i], use, offset);
                        sg_unmark_end(&msg->sg.data[msg->sg.end]);
                        sk_mem_charge(sk, use);

                        offset = 0;
                        copied -= use;
                        sk_msg_iter_next(msg, end);
                        num_elems++;
                        i++;
                }
                /* When zerocopy is mixed with sk_msg_*copy* operations we
                 * may have a copybreak set; in this case clear it and prefer
                 * the zerocopy remainder when possible.
                 */
                msg->sg.copybreak = 0;
                msg->sg.curr = msg->sg.end;
        }
out:
        /* Revert iov_iter updates; msg will need to use 'trim' later if it
         * also needs to be cleared.
         */
        if (ret)
                iov_iter_revert(from, msg->sg.size - orig);
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

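/* Copy @bytes from @from into the already-allocated sg elements of @msg,
 * starting at sg.curr/copybreak and advancing them as data is written.
 * Returns -ENOSPC if the message runs out of room and -EFAULT if the copy
 * faults.
 */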
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
                             struct sk_msg *msg, u32 bytes)
{
        int ret = -ENOSPC, i = msg->sg.curr;
        struct scatterlist *sge;
        u32 copy, buf_size;
        void *to;

        do {
                sge = sk_msg_elem(msg, i);
                /* This is possible if a trim operation shrunk the buffer */
                if (msg->sg.copybreak >= sge->length) {
                        msg->sg.copybreak = 0;
                        sk_msg_iter_var_next(i);
                        if (i == msg->sg.end)
                                break;
                        sge = sk_msg_elem(msg, i);
                }

                buf_size = sge->length - msg->sg.copybreak;
                copy = (buf_size > bytes) ? bytes : buf_size;
                to = sg_virt(sge) + msg->sg.copybreak;
                msg->sg.copybreak += copy;
                if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
                        ret = copy_from_iter_nocache(to, copy, from);
                else
                        ret = copy_from_iter(to, copy, from);
                if (ret != copy) {
                        ret = -EFAULT;
                        goto out;
                }
                bytes -= copy;
                if (!bytes)
                        break;
                msg->sg.copybreak = 0;
                sk_msg_iter_var_next(i);
        } while (i != msg->sg.end);
out:
        msg->sg.curr = i;
        return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

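/* Wrap an ingress @skb in a freshly allocated sk_msg, charge its length to
 * the socket's receive memory, queue it on the psock ingress list and wake
 * any reader via the data_ready callback. Returns the number of bytes
 * queued or a negative error.
 */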
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
        struct sock *sk = psock->sk;
        int copied = 0, num_sge;
        struct sk_msg *msg;

        msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
        if (unlikely(!msg))
                return -EAGAIN;
        if (!sk_rmem_schedule(sk, skb, skb->len)) {
                kfree(msg);
                return -EAGAIN;
        }

        sk_msg_init(msg);
        num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
        if (unlikely(num_sge < 0)) {
                kfree(msg);
                return num_sge;
        }

        sk_mem_charge(sk, skb->len);
        copied = skb->len;
        msg->sg.start = 0;
        msg->sg.size = copied;
        msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
        msg->skb = skb;

        sk_psock_queue_msg(psock, msg);
        sk_psock_data_ready(sk, psock);
        return copied;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
                               u32 off, u32 len, bool ingress)
{
        if (ingress)
                return sk_psock_skb_ingress(psock, skb);
        else
                return skb_send_sock_locked(psock->sk, skb, off, len);
}

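/* Work callback that drains the psock ingress_skb queue. Each skb is either
 * pushed into the local receive path (ingress) or transmitted on the socket.
 * -EAGAIN saves progress in work_state so the next run can resume, while
 * hard errors report the error on the socket and disable TX.
 */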
static void sk_psock_backlog(struct work_struct *work)
{
        struct sk_psock *psock = container_of(work, struct sk_psock, work);
        struct sk_psock_work_state *state = &psock->work_state;
        struct sk_buff *skb;
        bool ingress;
        u32 len, off;
        int ret;

        /* Lock sock to avoid losing sk_socket during loop. */
        lock_sock(psock->sk);
        if (state->skb) {
                skb = state->skb;
                len = state->len;
                off = state->off;
                state->skb = NULL;
                goto start;
        }

        while ((skb = skb_dequeue(&psock->ingress_skb))) {
                len = skb->len;
                off = 0;
start:
                ingress = tcp_skb_bpf_ingress(skb);
                do {
                        ret = -EIO;
                        if (likely(psock->sk->sk_socket))
                                ret = sk_psock_handle_skb(psock, skb, off,
                                                          len, ingress);
                        if (ret <= 0) {
                                if (ret == -EAGAIN) {
                                        state->skb = skb;
                                        state->len = len;
                                        state->off = off;
                                        goto end;
                                }
                                /* Hard errors break pipe and stop xmit. */
                                sk_psock_report_error(psock, ret ? -ret : EPIPE);
                                sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
                                kfree_skb(skb);
                                goto end;
                        }
                        off += ret;
                        len -= ret;
                } while (len);

                if (!ingress)
                        kfree_skb(skb);
        }
end:
        release_sock(psock->sk);
}

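/* Allocate and initialise a psock for @sk on the given NUMA node, attach it
 * as the socket's RCU-protected user data and take a reference on the
 * socket. Returns NULL on allocation failure.
 */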
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
        struct sk_psock *psock = kzalloc_node(sizeof(*psock),
                                              GFP_ATOMIC | __GFP_NOWARN,
                                              node);
        if (!psock)
                return NULL;

        psock->sk = sk;
        psock->eval = __SK_NONE;

        INIT_LIST_HEAD(&psock->link);
        spin_lock_init(&psock->link_lock);

        INIT_WORK(&psock->work, sk_psock_backlog);
        INIT_LIST_HEAD(&psock->ingress_msg);
        skb_queue_head_init(&psock->ingress_skb);

        sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
        refcount_set(&psock->refcnt, 1);

        rcu_assign_sk_user_data(sk, psock);
        sock_hold(sk);

        return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
        struct sk_psock_link *link;

        spin_lock_bh(&psock->link_lock);
        link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
                                        list);
        if (link)
                list_del(&link->list);
        spin_unlock_bh(&psock->link_lock);
        return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
        struct sk_msg *msg, *tmp;

        list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
                list_del(&msg->list);
                sk_msg_free(psock->sk, msg);
                kfree(msg);
        }
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
        __skb_queue_purge(&psock->ingress_skb);
        __sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
        struct sk_psock_link *link, *tmp;

        list_for_each_entry_safe(link, tmp, &psock->link, list) {
                list_del(&link->list);
                sk_psock_free_link(link);
        }
}

static void sk_psock_destroy_deferred(struct work_struct *gc)
{
        struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

        /* No sk_callback_lock since already detached. */

        /* Parser has been stopped */
        if (psock->progs.skb_parser)
                strp_done(&psock->parser.strp);

        cancel_work_sync(&psock->work);

        psock_progs_drop(&psock->progs);

        sk_psock_link_destroy(psock);
        sk_psock_cork_free(psock);
        sk_psock_zap_ingress(psock);

        if (psock->sk_redir)
                sock_put(psock->sk_redir);
        sock_put(psock->sk);
        kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
        struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

        INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
        schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
        rcu_assign_sk_user_data(sk, NULL);
        sk_psock_cork_free(psock);
        sk_psock_zap_ingress(psock);
        sk_psock_restore_proto(sk, psock);

        write_lock_bh(&sk->sk_callback_lock);
        if (psock->progs.skb_parser)
                sk_psock_stop_strp(sk, psock);
        write_unlock_bh(&sk->sk_callback_lock);
        sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

        call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
        switch (verdict) {
        case SK_PASS:
                return redir ? __SK_REDIRECT : __SK_PASS;
        case SK_DROP:
        default:
                break;
        }

        return __SK_DROP;
}

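/* Run the attached msg_parser BPF program on @msg and map its verdict to
 * __SK_PASS/__SK_REDIRECT/__SK_DROP; with no program attached the message
 * passes. For a redirect verdict the target socket is stashed in
 * psock->sk_redir (with a reference held), replacing any previous target.
 */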
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
                         struct sk_msg *msg)
{
        struct bpf_prog *prog;
        int ret;

        preempt_disable();
        rcu_read_lock();
        prog = READ_ONCE(psock->progs.msg_parser);
        if (unlikely(!prog)) {
                ret = __SK_PASS;
                goto out;
        }

        sk_msg_compute_data_pointers(msg);
        msg->sk = sk;
        ret = BPF_PROG_RUN(prog, msg);
        ret = sk_psock_map_verd(ret, msg->sk_redir);
        psock->apply_bytes = msg->apply_bytes;
        if (ret == __SK_REDIRECT) {
                if (psock->sk_redir)
                        sock_put(psock->sk_redir);
                psock->sk_redir = msg->sk_redir;
                if (!psock->sk_redir) {
                        ret = __SK_DROP;
                        goto out;
                }
                sock_hold(psock->sk_redir);
        }
out:
        rcu_read_unlock();
        preempt_enable();
        return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
                            struct sk_buff *skb)
{
        int ret;

        skb->sk = psock->sk;
        bpf_compute_data_end_sk_skb(skb);
        preempt_disable();
        ret = BPF_PROG_RUN(prog, skb);
        preempt_enable();
        /* strparser clones the skb before handing it to an upper layer,
         * meaning skb_orphan has been called. We NULL sk on the way out
         * to ensure we don't trigger a BUG_ON() in skb/sk operations
         * later and because we are not charging the memory of this skb
         * to any socket yet.
         */
        skb->sk = NULL;
        return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
        struct sk_psock_parser *parser;

        parser = container_of(strp, struct sk_psock_parser, strp);
        return container_of(parser, struct sk_psock, parser);
}

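/* Apply a verdict to @skb: __SK_PASS queues it for local ingress when
 * receive space allows, __SK_REDIRECT hands it to the target socket's psock
 * for ingress or egress, and anything else (or a failed check) frees the
 * skb.
 */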
static void sk_psock_verdict_apply(struct sk_psock *psock,
                                   struct sk_buff *skb, int verdict)
{
        struct sk_psock *psock_other;
        struct sock *sk_other;
        bool ingress;

        switch (verdict) {
        case __SK_PASS:
                sk_other = psock->sk;
                if (sock_flag(sk_other, SOCK_DEAD) ||
                    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
                        goto out_free;
                }
                if (atomic_read(&sk_other->sk_rmem_alloc) <=
                    sk_other->sk_rcvbuf) {
                        struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

                        tcp->bpf.flags |= BPF_F_INGRESS;
                        skb_queue_tail(&psock->ingress_skb, skb);
                        schedule_work(&psock->work);
                        break;
                }
                goto out_free;
        case __SK_REDIRECT:
                sk_other = tcp_skb_bpf_redirect_fetch(skb);
                if (unlikely(!sk_other))
                        goto out_free;
                psock_other = sk_psock(sk_other);
                if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
                    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
                        goto out_free;
                ingress = tcp_skb_bpf_ingress(skb);
                if ((!ingress && sock_writeable(sk_other)) ||
                    (ingress &&
                     atomic_read(&sk_other->sk_rmem_alloc) <=
                     sk_other->sk_rcvbuf)) {
                        if (!ingress)
                                skb_set_owner_w(skb, sk_other);
                        skb_queue_tail(&psock_other->ingress_skb, skb);
                        schedule_work(&psock_other->work);
                        break;
                }
                /* fall-through */
        case __SK_DROP:
                /* fall-through */
        default:
out_free:
                kfree_skb(skb);
        }
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = sk_psock_from_strp(strp);
        struct bpf_prog *prog;
        int ret = __SK_DROP;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_verdict);
        if (likely(prog)) {
                skb_orphan(skb);
                tcp_skb_bpf_redirect_clear(skb);
                ret = sk_psock_bpf_run(psock, prog, skb);
                ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
        }
        rcu_read_unlock();
        sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
        return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
        struct sk_psock *psock = sk_psock_from_strp(strp);
        struct bpf_prog *prog;
        int ret = skb->len;

        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_parser);
        if (likely(prog))
                ret = sk_psock_bpf_run(psock, prog, skb);
        rcu_read_unlock();
        return ret;
}

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock)) {
                write_lock_bh(&sk->sk_callback_lock);
                strp_data_ready(&psock->parser.strp);
                write_unlock_bh(&sk->sk_callback_lock);
        }
        rcu_read_unlock();
}

static void sk_psock_write_space(struct sock *sk)
{
        struct sk_psock *psock;
        void (*write_space)(struct sock *sk);

        rcu_read_lock();
        psock = sk_psock(sk);
        if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
                schedule_work(&psock->work);
        write_space = psock->saved_write_space;
        rcu_read_unlock();
        write_space(sk);
}

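/* strparser glue: sk_psock_init_strp() registers the parser callbacks,
 * sk_psock_start_strp() swaps in the psock data_ready/write_space handlers
 * (saving the original data_ready), and sk_psock_stop_strp() restores
 * sk_data_ready and stops the parser. The expected ordering, as used by
 * callers elsewhere in the tree (illustrative), is init -> start -> stop,
 * with strp_done() issued later from sk_psock_destroy_deferred().
 */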
int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
        static const struct strp_callbacks cb = {
                .rcv_msg        = sk_psock_strp_read,
                .read_sock_done = sk_psock_strp_read_done,
                .parse_msg      = sk_psock_strp_parse,
        };

        psock->parser.enabled = false;
        return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_parser *parser = &psock->parser;

        if (parser->enabled)
                return;

        parser->saved_data_ready = sk->sk_data_ready;
        sk->sk_data_ready = sk_psock_strp_data_ready;
        sk->sk_write_space = sk_psock_write_space;
        parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_parser *parser = &psock->parser;

        if (!parser->enabled)
                return;

        sk->sk_data_ready = parser->saved_data_ready;
        parser->saved_data_ready = NULL;
        strp_stop(&parser->strp);
        parser->enabled = false;
}