linux/net/kcm/kcmsock.c
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Kernel Connection Multiplexor
   4 *
   5 * Copyright (c) 2016 Tom Herbert <tom@herbertland.com>
   6 */
   7
   8#include <linux/bpf.h>
   9#include <linux/errno.h>
  10#include <linux/errqueue.h>
  11#include <linux/file.h>
  12#include <linux/in.h>
  13#include <linux/kernel.h>
  14#include <linux/module.h>
  15#include <linux/net.h>
  16#include <linux/netdevice.h>
  17#include <linux/poll.h>
  18#include <linux/rculist.h>
  19#include <linux/skbuff.h>
  20#include <linux/socket.h>
  21#include <linux/uaccess.h>
  22#include <linux/workqueue.h>
  23#include <linux/syscalls.h>
  24#include <linux/sched/signal.h>
  25
  26#include <net/kcm.h>
  27#include <net/netns/generic.h>
  28#include <net/sock.h>
  29#include <uapi/linux/kcm.h>
  30
  31unsigned int kcm_net_id;
  32
  33static struct kmem_cache *kcm_psockp __read_mostly;
  34static struct kmem_cache *kcm_muxp __read_mostly;
  35static struct workqueue_struct *kcm_wq;
  36
  37static inline struct kcm_sock *kcm_sk(const struct sock *sk)
  38{
  39        return (struct kcm_sock *)sk;
  40}
  41
  42static inline struct kcm_tx_msg *kcm_tx_msg(struct sk_buff *skb)
  43{
  44        return (struct kcm_tx_msg *)skb->cb;
  45}
  46
  47static void report_csk_error(struct sock *csk, int err)
  48{
  49        csk->sk_err = EPIPE;
  50        sk_error_report(csk);
  51}
  52
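    /* Tear down transmit on a psock after an unrecoverable send error: mark it
     * tx_stopped under the mux lock, take it off the available list (or, if it
     * is reserved and wakeup_kcm is set, kick the owning KCM's tx_work), and
     * report the error on the underlying TCP socket.
     */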
  53static void kcm_abort_tx_psock(struct kcm_psock *psock, int err,
  54                               bool wakeup_kcm)
  55{
  56        struct sock *csk = psock->sk;
  57        struct kcm_mux *mux = psock->mux;
  58
  59        /* Unrecoverable error in transmit */
  60
  61        spin_lock_bh(&mux->lock);
  62
  63        if (psock->tx_stopped) {
  64                spin_unlock_bh(&mux->lock);
  65                return;
  66        }
  67
  68        psock->tx_stopped = 1;
  69        KCM_STATS_INCR(psock->stats.tx_aborts);
  70
  71        if (!psock->tx_kcm) {
  72                /* Take off psocks_avail list */
  73                list_del(&psock->psock_avail_list);
  74        } else if (wakeup_kcm) {
  75                /* In this case psock is being aborted while outside of
  76                 * write_msgs and psock is reserved. Schedule tx_work
  77                 * to handle the failure there. Need to commit tx_stopped
  78                 * before queuing work.
  79                 */
  80                smp_mb();
  81
  82                queue_work(kcm_wq, &psock->tx_kcm->tx_work);
  83        }
  84
  85        spin_unlock_bh(&mux->lock);
  86
  87        /* Report error on lower socket */
  88        report_csk_error(csk, err);
  89}
  90
  91/* RX mux lock held. */
  92static void kcm_update_rx_mux_stats(struct kcm_mux *mux,
  93                                    struct kcm_psock *psock)
  94{
  95        STRP_STATS_ADD(mux->stats.rx_bytes,
  96                       psock->strp.stats.bytes -
  97                       psock->saved_rx_bytes);
  98        mux->stats.rx_msgs +=
  99                psock->strp.stats.msgs - psock->saved_rx_msgs;
 100        psock->saved_rx_msgs = psock->strp.stats.msgs;
 101        psock->saved_rx_bytes = psock->strp.stats.bytes;
 102}
 103
 104static void kcm_update_tx_mux_stats(struct kcm_mux *mux,
 105                                    struct kcm_psock *psock)
 106{
 107        KCM_STATS_ADD(mux->stats.tx_bytes,
 108                      psock->stats.tx_bytes - psock->saved_tx_bytes);
 109        mux->stats.tx_msgs +=
 110                psock->stats.tx_msgs - psock->saved_tx_msgs;
 111        psock->saved_tx_msgs = psock->stats.tx_msgs;
 112        psock->saved_tx_bytes = psock->stats.tx_bytes;
 113}
 114
 115static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
 116
  117/* KCM is ready to receive messages on its queue: either the KCM is new or
  118 * has become unblocked after being blocked on a full socket buffer. Queue any
  119 * pending ready messages on a psock. RX mux lock held.
 120 */
 121static void kcm_rcv_ready(struct kcm_sock *kcm)
 122{
 123        struct kcm_mux *mux = kcm->mux;
 124        struct kcm_psock *psock;
 125        struct sk_buff *skb;
 126
 127        if (unlikely(kcm->rx_wait || kcm->rx_psock || kcm->rx_disabled))
 128                return;
 129
 130        while (unlikely((skb = __skb_dequeue(&mux->rx_hold_queue)))) {
 131                if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
 132                        /* Assuming buffer limit has been reached */
 133                        skb_queue_head(&mux->rx_hold_queue, skb);
 134                        WARN_ON(!sk_rmem_alloc_get(&kcm->sk));
 135                        return;
 136                }
 137        }
 138
 139        while (!list_empty(&mux->psocks_ready)) {
 140                psock = list_first_entry(&mux->psocks_ready, struct kcm_psock,
 141                                         psock_ready_list);
 142
 143                if (kcm_queue_rcv_skb(&kcm->sk, psock->ready_rx_msg)) {
 144                        /* Assuming buffer limit has been reached */
 145                        WARN_ON(!sk_rmem_alloc_get(&kcm->sk));
 146                        return;
 147                }
 148
  149                /* Consumed the ready message on the psock. Unpause the strparser
  150                 * to get more messages.
 151                 */
 152                list_del(&psock->psock_ready_list);
 153                psock->ready_rx_msg = NULL;
 154                /* Commit clearing of ready_rx_msg for queuing work */
 155                smp_mb();
 156
 157                strp_unpause(&psock->strp);
 158                strp_check_rcv(&psock->strp);
 159        }
 160
 161        /* Buffer limit is okay now, add to ready list */
 162        list_add_tail(&kcm->wait_rx_list,
 163                      &kcm->mux->kcm_rx_waiters);
 164        kcm->rx_wait = true;
 165}
 166
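    /* skb destructor for messages queued on a KCM receive queue: uncharge the
     * receive memory and, once the queue has drained below sk_rcvlowat while
     * the KCM is neither waiting nor reserved, mark the KCM receive-ready
     * again.
     */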
 167static void kcm_rfree(struct sk_buff *skb)
 168{
 169        struct sock *sk = skb->sk;
 170        struct kcm_sock *kcm = kcm_sk(sk);
 171        struct kcm_mux *mux = kcm->mux;
 172        unsigned int len = skb->truesize;
 173
 174        sk_mem_uncharge(sk, len);
 175        atomic_sub(len, &sk->sk_rmem_alloc);
 176
 177        /* For reading rx_wait and rx_psock without holding lock */
 178        smp_mb__after_atomic();
 179
 180        if (!kcm->rx_wait && !kcm->rx_psock &&
 181            sk_rmem_alloc_get(sk) < sk->sk_rcvlowat) {
 182                spin_lock_bh(&mux->rx_lock);
 183                kcm_rcv_ready(kcm);
 184                spin_unlock_bh(&mux->rx_lock);
 185        }
 186}
 187
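    /* Charge an skb to a KCM socket's receive queue. Returns -ENOMEM or
     * -ENOBUFS when the receive buffer or memory accounting limits are hit;
     * otherwise the skb is queued and the reader is woken.
     */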
 188static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 189{
 190        struct sk_buff_head *list = &sk->sk_receive_queue;
 191
 192        if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf)
 193                return -ENOMEM;
 194
 195        if (!sk_rmem_schedule(sk, skb, skb->truesize))
 196                return -ENOBUFS;
 197
 198        skb->dev = NULL;
 199
 200        skb_orphan(skb);
 201        skb->sk = sk;
 202        skb->destructor = kcm_rfree;
 203        atomic_add(skb->truesize, &sk->sk_rmem_alloc);
 204        sk_mem_charge(sk, skb->truesize);
 205
 206        skb_queue_tail(list, skb);
 207
 208        if (!sock_flag(sk, SOCK_DEAD))
 209                sk->sk_data_ready(sk);
 210
 211        return 0;
 212}
 213
 214/* Requeue received messages for a kcm socket to other kcm sockets. This is
  215 * called when a kcm socket is receive disabled.
 216 * RX mux lock held.
 217 */
 218static void requeue_rx_msgs(struct kcm_mux *mux, struct sk_buff_head *head)
 219{
 220        struct sk_buff *skb;
 221        struct kcm_sock *kcm;
 222
 223        while ((skb = __skb_dequeue(head))) {
 224                /* Reset destructor to avoid calling kcm_rcv_ready */
 225                skb->destructor = sock_rfree;
 226                skb_orphan(skb);
 227try_again:
 228                if (list_empty(&mux->kcm_rx_waiters)) {
 229                        skb_queue_tail(&mux->rx_hold_queue, skb);
 230                        continue;
 231                }
 232
 233                kcm = list_first_entry(&mux->kcm_rx_waiters,
 234                                       struct kcm_sock, wait_rx_list);
 235
 236                if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
 237                        /* Should mean socket buffer full */
 238                        list_del(&kcm->wait_rx_list);
 239                        kcm->rx_wait = false;
 240
 241                        /* Commit rx_wait to read in kcm_free */
 242                        smp_wmb();
 243
 244                        goto try_again;
 245                }
 246        }
 247}
 248
 249/* Lower sock lock held */
 250static struct kcm_sock *reserve_rx_kcm(struct kcm_psock *psock,
 251                                       struct sk_buff *head)
 252{
 253        struct kcm_mux *mux = psock->mux;
 254        struct kcm_sock *kcm;
 255
 256        WARN_ON(psock->ready_rx_msg);
 257
 258        if (psock->rx_kcm)
 259                return psock->rx_kcm;
 260
 261        spin_lock_bh(&mux->rx_lock);
 262
 263        if (psock->rx_kcm) {
 264                spin_unlock_bh(&mux->rx_lock);
 265                return psock->rx_kcm;
 266        }
 267
 268        kcm_update_rx_mux_stats(mux, psock);
 269
 270        if (list_empty(&mux->kcm_rx_waiters)) {
 271                psock->ready_rx_msg = head;
 272                strp_pause(&psock->strp);
 273                list_add_tail(&psock->psock_ready_list,
 274                              &mux->psocks_ready);
 275                spin_unlock_bh(&mux->rx_lock);
 276                return NULL;
 277        }
 278
 279        kcm = list_first_entry(&mux->kcm_rx_waiters,
 280                               struct kcm_sock, wait_rx_list);
 281        list_del(&kcm->wait_rx_list);
 282        kcm->rx_wait = false;
 283
 284        psock->rx_kcm = kcm;
 285        kcm->rx_psock = psock;
 286
 287        spin_unlock_bh(&mux->rx_lock);
 288
 289        return kcm;
 290}
 291
 292static void kcm_done(struct kcm_sock *kcm);
 293
 294static void kcm_done_work(struct work_struct *w)
 295{
 296        kcm_done(container_of(w, struct kcm_sock, done_work));
 297}
 298
 299/* Lower sock held */
 300static void unreserve_rx_kcm(struct kcm_psock *psock,
 301                             bool rcv_ready)
 302{
 303        struct kcm_sock *kcm = psock->rx_kcm;
 304        struct kcm_mux *mux = psock->mux;
 305
 306        if (!kcm)
 307                return;
 308
 309        spin_lock_bh(&mux->rx_lock);
 310
 311        psock->rx_kcm = NULL;
 312        kcm->rx_psock = NULL;
 313
 314        /* Commit kcm->rx_psock before sk_rmem_alloc_get to sync with
 315         * kcm_rfree
 316         */
 317        smp_mb();
 318
 319        if (unlikely(kcm->done)) {
 320                spin_unlock_bh(&mux->rx_lock);
 321
  322                 /* Need to run kcm_done in a task since we need to acquire
 323                 * callback locks which may already be held here.
 324                 */
 325                INIT_WORK(&kcm->done_work, kcm_done_work);
 326                schedule_work(&kcm->done_work);
 327                return;
 328        }
 329
 330        if (unlikely(kcm->rx_disabled)) {
 331                requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue);
 332        } else if (rcv_ready || unlikely(!sk_rmem_alloc_get(&kcm->sk))) {
  333                /* Check for the degenerate race with rx_wait where all
  334                 * data was dequeued (accounted for in kcm_rfree).
 335                 */
 336                kcm_rcv_ready(kcm);
 337        }
 338        spin_unlock_bh(&mux->rx_lock);
 339}
 340
 341/* Lower sock lock held */
 342static void psock_data_ready(struct sock *sk)
 343{
 344        struct kcm_psock *psock;
 345
 346        read_lock_bh(&sk->sk_callback_lock);
 347
 348        psock = (struct kcm_psock *)sk->sk_user_data;
 349        if (likely(psock))
 350                strp_data_ready(&psock->strp);
 351
 352        read_unlock_bh(&sk->sk_callback_lock);
 353}
 354
 355/* Called with lower sock held */
 356static void kcm_rcv_strparser(struct strparser *strp, struct sk_buff *skb)
 357{
 358        struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
 359        struct kcm_sock *kcm;
 360
 361try_queue:
 362        kcm = reserve_rx_kcm(psock, skb);
 363        if (!kcm) {
 364                 /* Unable to reserve a KCM, message is held in psock and strp
 365                  * is paused.
 366                  */
 367                return;
 368        }
 369
 370        if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
 371                /* Should mean socket buffer full */
 372                unreserve_rx_kcm(psock, false);
 373                goto try_queue;
 374        }
 375}
 376
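    /* strparser parse callback: run the attached BPF program over the skb.
     * strparser interprets the return value as the length of the message
     * (0 means more data is needed, negative values are errors).
     */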
 377static int kcm_parse_func_strparser(struct strparser *strp, struct sk_buff *skb)
 378{
 379        struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
 380        struct bpf_prog *prog = psock->bpf_prog;
 381        int res;
 382
 383        res = bpf_prog_run_pin_on_cpu(prog, skb);
 384        return res;
 385}
 386
 387static int kcm_read_sock_done(struct strparser *strp, int err)
 388{
 389        struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
 390
 391        unreserve_rx_kcm(psock, true);
 392
 393        return err;
 394}
 395
 396static void psock_state_change(struct sock *sk)
 397{
  398        /* TCP only does an EPOLLIN for a half close. Do an EPOLLHUP here
  399         * since the application will normally not poll with EPOLLIN
  400         * on the TCP sockets.
 401         */
 402
 403        report_csk_error(sk, EPIPE);
 404}
 405
 406static void psock_write_space(struct sock *sk)
 407{
 408        struct kcm_psock *psock;
 409        struct kcm_mux *mux;
 410        struct kcm_sock *kcm;
 411
 412        read_lock_bh(&sk->sk_callback_lock);
 413
 414        psock = (struct kcm_psock *)sk->sk_user_data;
 415        if (unlikely(!psock))
 416                goto out;
 417        mux = psock->mux;
 418
 419        spin_lock_bh(&mux->lock);
 420
  421        /* Check if the psock is reserved, i.e. someone is waiting to send on it. */
 422        kcm = psock->tx_kcm;
 423        if (kcm && !unlikely(kcm->tx_stopped))
 424                queue_work(kcm_wq, &kcm->tx_work);
 425
 426        spin_unlock_bh(&mux->lock);
 427out:
 428        read_unlock_bh(&sk->sk_callback_lock);
 429}
 430
 431static void unreserve_psock(struct kcm_sock *kcm);
 432
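    /* Reserve a psock for this kcm socket's transmit path. If none is
     * available, the kcm is queued on the mux's kcm_tx_waiters list and NULL
     * is returned; tx_work is kicked later when a psock frees up (see
     * psock_now_avail).
     */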
 433/* kcm sock is locked. */
 434static struct kcm_psock *reserve_psock(struct kcm_sock *kcm)
 435{
 436        struct kcm_mux *mux = kcm->mux;
 437        struct kcm_psock *psock;
 438
 439        psock = kcm->tx_psock;
 440
 441        smp_rmb(); /* Must read tx_psock before tx_wait */
 442
 443        if (psock) {
 444                WARN_ON(kcm->tx_wait);
 445                if (unlikely(psock->tx_stopped))
 446                        unreserve_psock(kcm);
 447                else
 448                        return kcm->tx_psock;
 449        }
 450
 451        spin_lock_bh(&mux->lock);
 452
  453        /* Check again under lock to see if a psock was reserved for this
  454         * kcm socket via psock_now_avail.
 455         */
 456        psock = kcm->tx_psock;
 457        if (unlikely(psock)) {
 458                WARN_ON(kcm->tx_wait);
 459                spin_unlock_bh(&mux->lock);
 460                return kcm->tx_psock;
 461        }
 462
 463        if (!list_empty(&mux->psocks_avail)) {
 464                psock = list_first_entry(&mux->psocks_avail,
 465                                         struct kcm_psock,
 466                                         psock_avail_list);
 467                list_del(&psock->psock_avail_list);
 468                if (kcm->tx_wait) {
 469                        list_del(&kcm->wait_psock_list);
 470                        kcm->tx_wait = false;
 471                }
 472                kcm->tx_psock = psock;
 473                psock->tx_kcm = kcm;
 474                KCM_STATS_INCR(psock->stats.reserved);
 475        } else if (!kcm->tx_wait) {
 476                list_add_tail(&kcm->wait_psock_list,
 477                              &mux->kcm_tx_waiters);
 478                kcm->tx_wait = true;
 479        }
 480
 481        spin_unlock_bh(&mux->lock);
 482
 483        return psock;
 484}
 485
 486/* mux lock held */
 487static void psock_now_avail(struct kcm_psock *psock)
 488{
 489        struct kcm_mux *mux = psock->mux;
 490        struct kcm_sock *kcm;
 491
 492        if (list_empty(&mux->kcm_tx_waiters)) {
 493                list_add_tail(&psock->psock_avail_list,
 494                              &mux->psocks_avail);
 495        } else {
 496                kcm = list_first_entry(&mux->kcm_tx_waiters,
 497                                       struct kcm_sock,
 498                                       wait_psock_list);
 499                list_del(&kcm->wait_psock_list);
 500                kcm->tx_wait = false;
 501                psock->tx_kcm = kcm;
 502
 503                /* Commit before changing tx_psock since that is read in
 504                 * reserve_psock before queuing work.
 505                 */
 506                smp_mb();
 507
 508                kcm->tx_psock = psock;
 509                KCM_STATS_INCR(psock->stats.reserved);
 510                queue_work(kcm_wq, &kcm->tx_work);
 511        }
 512}
 513
 514/* kcm sock is locked. */
 515static void unreserve_psock(struct kcm_sock *kcm)
 516{
 517        struct kcm_psock *psock;
 518        struct kcm_mux *mux = kcm->mux;
 519
 520        spin_lock_bh(&mux->lock);
 521
 522        psock = kcm->tx_psock;
 523
 524        if (WARN_ON(!psock)) {
 525                spin_unlock_bh(&mux->lock);
 526                return;
 527        }
 528
 529        smp_rmb(); /* Read tx_psock before tx_wait */
 530
 531        kcm_update_tx_mux_stats(mux, psock);
 532
 533        WARN_ON(kcm->tx_wait);
 534
 535        kcm->tx_psock = NULL;
 536        psock->tx_kcm = NULL;
 537        KCM_STATS_INCR(psock->stats.unreserved);
 538
 539        if (unlikely(psock->tx_stopped)) {
 540                if (psock->done) {
 541                        /* Deferred free */
 542                        list_del(&psock->psock_list);
 543                        mux->psocks_cnt--;
 544                        sock_put(psock->sk);
 545                        fput(psock->sk->sk_socket->file);
 546                        kmem_cache_free(kcm_psockp, psock);
 547                }
 548
 549                /* Don't put back on available list */
 550
 551                spin_unlock_bh(&mux->lock);
 552
 553                return;
 554        }
 555
 556        psock_now_avail(psock);
 557
 558        spin_unlock_bh(&mux->lock);
 559}
 560
 561static void kcm_report_tx_retry(struct kcm_sock *kcm)
 562{
 563        struct kcm_mux *mux = kcm->mux;
 564
 565        spin_lock_bh(&mux->lock);
 566        KCM_STATS_INCR(mux->stats.tx_retries);
 567        spin_unlock_bh(&mux->lock);
 568}
 569
 570/* Write any messages ready on the kcm socket.  Called with kcm sock lock
 571 * held.  Return bytes actually sent or error.
 572 */
 573static int kcm_write_msgs(struct kcm_sock *kcm)
 574{
 575        struct sock *sk = &kcm->sk;
 576        struct kcm_psock *psock;
 577        struct sk_buff *skb, *head;
 578        struct kcm_tx_msg *txm;
 579        unsigned short fragidx, frag_offset;
 580        unsigned int sent, total_sent = 0;
 581        int ret = 0;
 582
 583        kcm->tx_wait_more = false;
 584        psock = kcm->tx_psock;
 585        if (unlikely(psock && psock->tx_stopped)) {
 586                /* A reserved psock was aborted asynchronously. Unreserve
 587                 * it and we'll retry the message.
 588                 */
 589                unreserve_psock(kcm);
 590                kcm_report_tx_retry(kcm);
 591                if (skb_queue_empty(&sk->sk_write_queue))
 592                        return 0;
 593
 594                kcm_tx_msg(skb_peek(&sk->sk_write_queue))->sent = 0;
 595
 596        } else if (skb_queue_empty(&sk->sk_write_queue)) {
 597                return 0;
 598        }
 599
 600        head = skb_peek(&sk->sk_write_queue);
 601        txm = kcm_tx_msg(head);
 602
 603        if (txm->sent) {
 604                /* Send of first skbuff in queue already in progress */
 605                if (WARN_ON(!psock)) {
 606                        ret = -EINVAL;
 607                        goto out;
 608                }
 609                sent = txm->sent;
 610                frag_offset = txm->frag_offset;
 611                fragidx = txm->fragidx;
 612                skb = txm->frag_skb;
 613
 614                goto do_frag;
 615        }
 616
 617try_again:
 618        psock = reserve_psock(kcm);
 619        if (!psock)
 620                goto out;
 621
 622        do {
 623                skb = head;
 624                txm = kcm_tx_msg(head);
 625                sent = 0;
 626
 627do_frag_list:
 628                if (WARN_ON(!skb_shinfo(skb)->nr_frags)) {
 629                        ret = -EINVAL;
 630                        goto out;
 631                }
 632
 633                for (fragidx = 0; fragidx < skb_shinfo(skb)->nr_frags;
 634                     fragidx++) {
 635                        skb_frag_t *frag;
 636
 637                        frag_offset = 0;
 638do_frag:
 639                        frag = &skb_shinfo(skb)->frags[fragidx];
 640                        if (WARN_ON(!skb_frag_size(frag))) {
 641                                ret = -EINVAL;
 642                                goto out;
 643                        }
 644
 645                        ret = kernel_sendpage(psock->sk->sk_socket,
 646                                              skb_frag_page(frag),
 647                                              skb_frag_off(frag) + frag_offset,
 648                                              skb_frag_size(frag) - frag_offset,
 649                                              MSG_DONTWAIT);
 650                        if (ret <= 0) {
 651                                if (ret == -EAGAIN) {
 652                                        /* Save state to try again when there's
 653                                         * write space on the socket
 654                                         */
 655                                        txm->sent = sent;
 656                                        txm->frag_offset = frag_offset;
 657                                        txm->fragidx = fragidx;
 658                                        txm->frag_skb = skb;
 659
 660                                        ret = 0;
 661                                        goto out;
 662                                }
 663
 664                                /* Hard failure in sending message, abort this
 665                                 * psock since it has lost framing
 666                                 * synchronization and retry sending the
 667                                 * message from the beginning.
 668                                 */
 669                                kcm_abort_tx_psock(psock, ret ? -ret : EPIPE,
 670                                                   true);
 671                                unreserve_psock(kcm);
 672
 673                                txm->sent = 0;
 674                                kcm_report_tx_retry(kcm);
 675                                ret = 0;
 676
 677                                goto try_again;
 678                        }
 679
 680                        sent += ret;
 681                        frag_offset += ret;
 682                        KCM_STATS_ADD(psock->stats.tx_bytes, ret);
 683                        if (frag_offset < skb_frag_size(frag)) {
 684                                /* Not finished with this frag */
 685                                goto do_frag;
 686                        }
 687                }
 688
 689                if (skb == head) {
 690                        if (skb_has_frag_list(skb)) {
 691                                skb = skb_shinfo(skb)->frag_list;
 692                                goto do_frag_list;
 693                        }
 694                } else if (skb->next) {
 695                        skb = skb->next;
 696                        goto do_frag_list;
 697                }
 698
 699                /* Successfully sent the whole packet, account for it. */
 700                skb_dequeue(&sk->sk_write_queue);
 701                kfree_skb(head);
 702                sk->sk_wmem_queued -= sent;
 703                total_sent += sent;
 704                KCM_STATS_INCR(psock->stats.tx_msgs);
 705        } while ((head = skb_peek(&sk->sk_write_queue)));
 706out:
 707        if (!head) {
 708                /* Done with all queued messages. */
 709                WARN_ON(!skb_queue_empty(&sk->sk_write_queue));
 710                unreserve_psock(kcm);
 711        }
 712
 713        /* Check if write space is available */
 714        sk->sk_write_space(sk);
 715
 716        return total_sent ? : ret;
 717}
 718
 719static void kcm_tx_work(struct work_struct *w)
 720{
 721        struct kcm_sock *kcm = container_of(w, struct kcm_sock, tx_work);
 722        struct sock *sk = &kcm->sk;
 723        int err;
 724
 725        lock_sock(sk);
 726
 727        /* Primarily for SOCK_DGRAM sockets, also handle asynchronous tx
 728         * aborts
 729         */
 730        err = kcm_write_msgs(kcm);
 731        if (err < 0) {
 732                /* Hard failure in write, report error on KCM socket */
 733                pr_warn("KCM: Hard failure on kcm_write_msgs %d\n", err);
 734                report_csk_error(&kcm->sk, -err);
 735                goto out;
 736        }
 737
 738        /* Primarily for SOCK_SEQPACKET sockets */
 739        if (likely(sk->sk_socket) &&
 740            test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) {
 741                clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 742                sk->sk_write_space(sk);
 743        }
 744
 745out:
 746        release_sock(sk);
 747}
 748
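    /* Flush any completed messages that were held back by MSG_BATCH batching
     * (tx_wait_more).
     */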
 749static void kcm_push(struct kcm_sock *kcm)
 750{
 751        if (kcm->tx_wait_more)
 752                kcm_write_msgs(kcm);
 753}
 754
 755static ssize_t kcm_sendpage(struct socket *sock, struct page *page,
 756                            int offset, size_t size, int flags)
 757
 758{
 759        struct sock *sk = sock->sk;
 760        struct kcm_sock *kcm = kcm_sk(sk);
 761        struct sk_buff *skb = NULL, *head = NULL;
 762        long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
 763        bool eor;
 764        int err = 0;
 765        int i;
 766
 767        if (flags & MSG_SENDPAGE_NOTLAST)
 768                flags |= MSG_MORE;
 769
 770        /* No MSG_EOR from splice, only look at MSG_MORE */
 771        eor = !(flags & MSG_MORE);
 772
 773        lock_sock(sk);
 774
 775        sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 776
 777        err = -EPIPE;
 778        if (sk->sk_err)
 779                goto out_error;
 780
 781        if (kcm->seq_skb) {
 782                /* Previously opened message */
 783                head = kcm->seq_skb;
 784                skb = kcm_tx_msg(head)->last_skb;
 785                i = skb_shinfo(skb)->nr_frags;
 786
 787                if (skb_can_coalesce(skb, i, page, offset)) {
 788                        skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size);
 789                        skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;
 790                        goto coalesced;
 791                }
 792
 793                if (i >= MAX_SKB_FRAGS) {
 794                        struct sk_buff *tskb;
 795
 796                        tskb = alloc_skb(0, sk->sk_allocation);
 797                        while (!tskb) {
 798                                kcm_push(kcm);
 799                                err = sk_stream_wait_memory(sk, &timeo);
 800                                if (err)
 801                                        goto out_error;
 802                        }
 803
 804                        if (head == skb)
 805                                skb_shinfo(head)->frag_list = tskb;
 806                        else
 807                                skb->next = tskb;
 808
 809                        skb = tskb;
 810                        skb->ip_summed = CHECKSUM_UNNECESSARY;
 811                        i = 0;
 812                }
 813        } else {
 814                /* Call the sk_stream functions to manage the sndbuf mem. */
 815                if (!sk_stream_memory_free(sk)) {
 816                        kcm_push(kcm);
 817                        set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 818                        err = sk_stream_wait_memory(sk, &timeo);
 819                        if (err)
 820                                goto out_error;
 821                }
 822
 823                head = alloc_skb(0, sk->sk_allocation);
 824                while (!head) {
 825                        kcm_push(kcm);
 826                        err = sk_stream_wait_memory(sk, &timeo);
 827                        if (err)
 828                                goto out_error;
 829                }
 830
 831                skb = head;
 832                i = 0;
 833        }
 834
 835        get_page(page);
 836        skb_fill_page_desc(skb, i, page, offset, size);
 837        skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG;
 838
 839coalesced:
 840        skb->len += size;
 841        skb->data_len += size;
 842        skb->truesize += size;
 843        sk->sk_wmem_queued += size;
 844        sk_mem_charge(sk, size);
 845
 846        if (head != skb) {
 847                head->len += size;
 848                head->data_len += size;
 849                head->truesize += size;
 850        }
 851
 852        if (eor) {
 853                bool not_busy = skb_queue_empty(&sk->sk_write_queue);
 854
 855                /* Message complete, queue it on send buffer */
 856                __skb_queue_tail(&sk->sk_write_queue, head);
 857                kcm->seq_skb = NULL;
 858                KCM_STATS_INCR(kcm->stats.tx_msgs);
 859
 860                if (flags & MSG_BATCH) {
 861                        kcm->tx_wait_more = true;
 862                } else if (kcm->tx_wait_more || not_busy) {
 863                        err = kcm_write_msgs(kcm);
 864                        if (err < 0) {
 865                                /* We got a hard error in write_msgs but have
 866                                 * already queued this message. Report an error
 867                                 * in the socket, but don't affect return value
 868                                 * from sendmsg
 869                                 */
 870                                pr_warn("KCM: Hard failure on kcm_write_msgs\n");
 871                                report_csk_error(&kcm->sk, -err);
 872                        }
 873                }
 874        } else {
 875                /* Message not complete, save state */
 876                kcm->seq_skb = head;
 877                kcm_tx_msg(head)->last_skb = skb;
 878        }
 879
 880        KCM_STATS_ADD(kcm->stats.tx_bytes, size);
 881
 882        release_sock(sk);
 883        return size;
 884
 885out_error:
 886        kcm_push(kcm);
 887
 888        err = sk_stream_error(sk, flags, err);
 889
 890        /* make sure we wake any epoll edge trigger waiter */
 891        if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
 892                sk->sk_write_space(sk);
 893
 894        release_sock(sk);
 895        return err;
 896}
 897
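    /* Illustrative (userspace) usage on a SOCK_DGRAM KCM socket: each send()
     * without MSG_MORE completes a message, and MSG_BATCH holds completed
     * messages back so they are flushed by a later non-batched send:
     *
     *	send(kcm_fd, msg1, len1, MSG_BATCH);	// queue message, defer transmit
     *	send(kcm_fd, msg2, len2, 0);		// transmit both messages
     */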
 898static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 899{
 900        struct sock *sk = sock->sk;
 901        struct kcm_sock *kcm = kcm_sk(sk);
 902        struct sk_buff *skb = NULL, *head = NULL;
 903        size_t copy, copied = 0;
 904        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 905        int eor = (sock->type == SOCK_DGRAM) ?
 906                  !(msg->msg_flags & MSG_MORE) : !!(msg->msg_flags & MSG_EOR);
 907        int err = -EPIPE;
 908
 909        lock_sock(sk);
 910
 911        /* Per tcp_sendmsg this should be in poll */
 912        sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 913
 914        if (sk->sk_err)
 915                goto out_error;
 916
 917        if (kcm->seq_skb) {
 918                /* Previously opened message */
 919                head = kcm->seq_skb;
 920                skb = kcm_tx_msg(head)->last_skb;
 921                goto start;
 922        }
 923
 924        /* Call the sk_stream functions to manage the sndbuf mem. */
 925        if (!sk_stream_memory_free(sk)) {
 926                kcm_push(kcm);
 927                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 928                err = sk_stream_wait_memory(sk, &timeo);
 929                if (err)
 930                        goto out_error;
 931        }
 932
 933        if (msg_data_left(msg)) {
 934                /* New message, alloc head skb */
 935                head = alloc_skb(0, sk->sk_allocation);
 936                while (!head) {
 937                        kcm_push(kcm);
 938                        err = sk_stream_wait_memory(sk, &timeo);
 939                        if (err)
 940                                goto out_error;
 941
 942                        head = alloc_skb(0, sk->sk_allocation);
 943                }
 944
 945                skb = head;
 946
 947                /* Set ip_summed to CHECKSUM_UNNECESSARY to avoid calling
 948                 * csum_and_copy_from_iter from skb_do_copy_data_nocache.
 949                 */
 950                skb->ip_summed = CHECKSUM_UNNECESSARY;
 951        }
 952
 953start:
 954        while (msg_data_left(msg)) {
 955                bool merge = true;
 956                int i = skb_shinfo(skb)->nr_frags;
 957                struct page_frag *pfrag = sk_page_frag(sk);
 958
 959                if (!sk_page_frag_refill(sk, pfrag))
 960                        goto wait_for_memory;
 961
 962                if (!skb_can_coalesce(skb, i, pfrag->page,
 963                                      pfrag->offset)) {
 964                        if (i == MAX_SKB_FRAGS) {
 965                                struct sk_buff *tskb;
 966
 967                                tskb = alloc_skb(0, sk->sk_allocation);
 968                                if (!tskb)
 969                                        goto wait_for_memory;
 970
 971                                if (head == skb)
 972                                        skb_shinfo(head)->frag_list = tskb;
 973                                else
 974                                        skb->next = tskb;
 975
 976                                skb = tskb;
 977                                skb->ip_summed = CHECKSUM_UNNECESSARY;
 978                                continue;
 979                        }
 980                        merge = false;
 981                }
 982
 983                copy = min_t(int, msg_data_left(msg),
 984                             pfrag->size - pfrag->offset);
 985
 986                if (!sk_wmem_schedule(sk, copy))
 987                        goto wait_for_memory;
 988
 989                err = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb,
 990                                               pfrag->page,
 991                                               pfrag->offset,
 992                                               copy);
 993                if (err)
 994                        goto out_error;
 995
 996                /* Update the skb. */
 997                if (merge) {
 998                        skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
 999                } else {
1000                        skb_fill_page_desc(skb, i, pfrag->page,
1001                                           pfrag->offset, copy);
1002                        get_page(pfrag->page);
1003                }
1004
1005                pfrag->offset += copy;
1006                copied += copy;
1007                if (head != skb) {
1008                        head->len += copy;
1009                        head->data_len += copy;
1010                }
1011
1012                continue;
1013
1014wait_for_memory:
1015                kcm_push(kcm);
1016                err = sk_stream_wait_memory(sk, &timeo);
1017                if (err)
1018                        goto out_error;
1019        }
1020
1021        if (eor) {
1022                bool not_busy = skb_queue_empty(&sk->sk_write_queue);
1023
1024                if (head) {
1025                        /* Message complete, queue it on send buffer */
1026                        __skb_queue_tail(&sk->sk_write_queue, head);
1027                        kcm->seq_skb = NULL;
1028                        KCM_STATS_INCR(kcm->stats.tx_msgs);
1029                }
1030
1031                if (msg->msg_flags & MSG_BATCH) {
1032                        kcm->tx_wait_more = true;
1033                } else if (kcm->tx_wait_more || not_busy) {
1034                        err = kcm_write_msgs(kcm);
1035                        if (err < 0) {
1036                                /* We got a hard error in write_msgs but have
1037                                 * already queued this message. Report an error
1038                                 * in the socket, but don't affect return value
1039                                 * from sendmsg
1040                                 */
1041                                pr_warn("KCM: Hard failure on kcm_write_msgs\n");
1042                                report_csk_error(&kcm->sk, -err);
1043                        }
1044                }
1045        } else {
1046                /* Message not complete, save state */
1047partial_message:
1048                if (head) {
1049                        kcm->seq_skb = head;
1050                        kcm_tx_msg(head)->last_skb = skb;
1051                }
1052        }
1053
1054        KCM_STATS_ADD(kcm->stats.tx_bytes, copied);
1055
1056        release_sock(sk);
1057        return copied;
1058
1059out_error:
1060        kcm_push(kcm);
1061
1062        if (copied && sock->type == SOCK_SEQPACKET) {
1063                /* Wrote some bytes before encountering an
1064                 * error, return partial success.
1065                 */
1066                goto partial_message;
1067        }
1068
1069        if (head != kcm->seq_skb)
1070                kfree_skb(head);
1071
1072        err = sk_stream_error(sk, msg->msg_flags, err);
1073
1074        /* make sure we wake any epoll edge trigger waiter */
1075        if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
1076                sk->sk_write_space(sk);
1077
1078        release_sock(sk);
1079        return err;
1080}
1081
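    /* Wait for a message to appear at the head of the receive queue, honouring
     * MSG_DONTWAIT, the receive timeout, socket errors and pending signals.
     */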
1082static struct sk_buff *kcm_wait_data(struct sock *sk, int flags,
1083                                     long timeo, int *err)
1084{
1085        struct sk_buff *skb;
1086
1087        while (!(skb = skb_peek(&sk->sk_receive_queue))) {
1088                if (sk->sk_err) {
1089                        *err = sock_error(sk);
1090                        return NULL;
1091                }
1092
1093                if (sock_flag(sk, SOCK_DONE))
1094                        return NULL;
1095
1096                if ((flags & MSG_DONTWAIT) || !timeo) {
1097                        *err = -EAGAIN;
1098                        return NULL;
1099                }
1100
1101                sk_wait_data(sk, &timeo, NULL);
1102
1103                /* Handle signals */
1104                if (signal_pending(current)) {
1105                        *err = sock_intr_errno(timeo);
1106                        return NULL;
1107                }
1108        }
1109
1110        return skb;
1111}
1112
1113static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
1114                       size_t len, int flags)
1115{
1116        struct sock *sk = sock->sk;
1117        struct kcm_sock *kcm = kcm_sk(sk);
1118        int err = 0;
1119        long timeo;
1120        struct strp_msg *stm;
1121        int copied = 0;
1122        struct sk_buff *skb;
1123
1124        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1125
1126        lock_sock(sk);
1127
1128        skb = kcm_wait_data(sk, flags, timeo, &err);
1129        if (!skb)
1130                goto out;
1131
1132        /* Okay, have a message on the receive queue */
1133
1134        stm = strp_msg(skb);
1135
1136        if (len > stm->full_len)
1137                len = stm->full_len;
1138
1139        err = skb_copy_datagram_msg(skb, stm->offset, msg, len);
1140        if (err < 0)
1141                goto out;
1142
1143        copied = len;
1144        if (likely(!(flags & MSG_PEEK))) {
1145                KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
1146                if (copied < stm->full_len) {
1147                        if (sock->type == SOCK_DGRAM) {
1148                                /* Truncated message */
1149                                msg->msg_flags |= MSG_TRUNC;
1150                                goto msg_finished;
1151                        }
1152                        stm->offset += copied;
1153                        stm->full_len -= copied;
1154                } else {
1155msg_finished:
1156                        /* Finished with message */
1157                        msg->msg_flags |= MSG_EOR;
1158                        KCM_STATS_INCR(kcm->stats.rx_msgs);
1159                        skb_unlink(skb, &sk->sk_receive_queue);
1160                        kfree_skb(skb);
1161                }
1162        }
1163
1164out:
1165        release_sock(sk);
1166
1167        return copied ? : err;
1168}
1169
1170static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
1171                               struct pipe_inode_info *pipe, size_t len,
1172                               unsigned int flags)
1173{
1174        struct sock *sk = sock->sk;
1175        struct kcm_sock *kcm = kcm_sk(sk);
1176        long timeo;
1177        struct strp_msg *stm;
1178        int err = 0;
1179        ssize_t copied;
1180        struct sk_buff *skb;
1181
 1182        /* Only support splice for SOCK_SEQPACKET */
1183
1184        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1185
1186        lock_sock(sk);
1187
1188        skb = kcm_wait_data(sk, flags, timeo, &err);
1189        if (!skb)
1190                goto err_out;
1191
1192        /* Okay, have a message on the receive queue */
1193
1194        stm = strp_msg(skb);
1195
1196        if (len > stm->full_len)
1197                len = stm->full_len;
1198
1199        copied = skb_splice_bits(skb, sk, stm->offset, pipe, len, flags);
1200        if (copied < 0) {
1201                err = copied;
1202                goto err_out;
1203        }
1204
1205        KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
1206
1207        stm->offset += copied;
1208        stm->full_len -= copied;
1209
1210        /* We have no way to return MSG_EOR. If all the bytes have been
1211         * read we still leave the message in the receive socket buffer.
1212         * A subsequent recvmsg needs to be done to return MSG_EOR and
1213         * finish reading the message.
1214         */
1215
1216        release_sock(sk);
1217
1218        return copied;
1219
1220err_out:
1221        release_sock(sk);
1222
1223        return err;
1224}
1225
1226/* kcm sock lock held */
1227static void kcm_recv_disable(struct kcm_sock *kcm)
1228{
1229        struct kcm_mux *mux = kcm->mux;
1230
1231        if (kcm->rx_disabled)
1232                return;
1233
1234        spin_lock_bh(&mux->rx_lock);
1235
1236        kcm->rx_disabled = 1;
1237
1238        /* If a psock is reserved we'll do cleanup in unreserve */
1239        if (!kcm->rx_psock) {
1240                if (kcm->rx_wait) {
1241                        list_del(&kcm->wait_rx_list);
1242                        kcm->rx_wait = false;
1243                }
1244
1245                requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue);
1246        }
1247
1248        spin_unlock_bh(&mux->rx_lock);
1249}
1250
1251/* kcm sock lock held */
1252static void kcm_recv_enable(struct kcm_sock *kcm)
1253{
1254        struct kcm_mux *mux = kcm->mux;
1255
1256        if (!kcm->rx_disabled)
1257                return;
1258
1259        spin_lock_bh(&mux->rx_lock);
1260
1261        kcm->rx_disabled = 0;
1262        kcm_rcv_ready(kcm);
1263
1264        spin_unlock_bh(&mux->rx_lock);
1265}
1266
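    /* Illustrative (userspace) use of the only KCM socket option currently
     * handled here, which stops new messages from being steered to this KCM
     * socket (kcm_fd is assumed to be an existing KCM socket):
     *
     *	int val = 1;
     *	setsockopt(kcm_fd, SOL_KCM, KCM_RECV_DISABLE, &val, sizeof(val));
     */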
1267static int kcm_setsockopt(struct socket *sock, int level, int optname,
1268                          sockptr_t optval, unsigned int optlen)
1269{
1270        struct kcm_sock *kcm = kcm_sk(sock->sk);
1271        int val, valbool;
1272        int err = 0;
1273
1274        if (level != SOL_KCM)
1275                return -ENOPROTOOPT;
1276
1277        if (optlen < sizeof(int))
1278                return -EINVAL;
1279
1280        if (copy_from_sockptr(&val, optval, sizeof(int)))
1281                return -EFAULT;
1282
1283        valbool = val ? 1 : 0;
1284
1285        switch (optname) {
1286        case KCM_RECV_DISABLE:
1287                lock_sock(&kcm->sk);
1288                if (valbool)
1289                        kcm_recv_disable(kcm);
1290                else
1291                        kcm_recv_enable(kcm);
1292                release_sock(&kcm->sk);
1293                break;
1294        default:
1295                err = -ENOPROTOOPT;
1296        }
1297
1298        return err;
1299}
1300
1301static int kcm_getsockopt(struct socket *sock, int level, int optname,
1302                          char __user *optval, int __user *optlen)
1303{
1304        struct kcm_sock *kcm = kcm_sk(sock->sk);
1305        int val, len;
1306
1307        if (level != SOL_KCM)
1308                return -ENOPROTOOPT;
1309
1310        if (get_user(len, optlen))
1311                return -EFAULT;
1312
1313        len = min_t(unsigned int, len, sizeof(int));
1314        if (len < 0)
1315                return -EINVAL;
1316
1317        switch (optname) {
1318        case KCM_RECV_DISABLE:
1319                val = kcm->rx_disabled;
1320                break;
1321        default:
1322                return -ENOPROTOOPT;
1323        }
1324
1325        if (put_user(len, optlen))
1326                return -EFAULT;
1327        if (copy_to_user(optval, &val, len))
1328                return -EFAULT;
1329        return 0;
1330}
1331
1332static void init_kcm_sock(struct kcm_sock *kcm, struct kcm_mux *mux)
1333{
1334        struct kcm_sock *tkcm;
1335        struct list_head *head;
1336        int index = 0;
1337
1338        /* For SOCK_SEQPACKET sock type, datagram_poll checks the sk_state, so
1339         * we set sk_state, otherwise epoll_wait always returns right away with
1340         * EPOLLHUP
1341         */
1342        kcm->sk.sk_state = TCP_ESTABLISHED;
1343
1344        /* Add to mux's kcm sockets list */
1345        kcm->mux = mux;
1346        spin_lock_bh(&mux->lock);
1347
1348        head = &mux->kcm_socks;
1349        list_for_each_entry(tkcm, &mux->kcm_socks, kcm_sock_list) {
1350                if (tkcm->index != index)
1351                        break;
1352                head = &tkcm->kcm_sock_list;
1353                index++;
1354        }
1355
1356        list_add(&kcm->kcm_sock_list, head);
1357        kcm->index = index;
1358
1359        mux->kcm_socks_cnt++;
1360        spin_unlock_bh(&mux->lock);
1361
1362        INIT_WORK(&kcm->tx_work, kcm_tx_work);
1363
1364        spin_lock_bh(&mux->rx_lock);
1365        kcm_rcv_ready(kcm);
1366        spin_unlock_bh(&mux->rx_lock);
1367}
1368
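    /* Attach a connected TCP socket to the mux: allocate a psock, start a
     * strparser on the TCP socket with the given BPF program as the message
     * parser, take over the socket's callbacks, and make the psock available
     * for both RX and TX on the mux.
     */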
1369static int kcm_attach(struct socket *sock, struct socket *csock,
1370                      struct bpf_prog *prog)
1371{
1372        struct kcm_sock *kcm = kcm_sk(sock->sk);
1373        struct kcm_mux *mux = kcm->mux;
1374        struct sock *csk;
1375        struct kcm_psock *psock = NULL, *tpsock;
1376        struct list_head *head;
1377        int index = 0;
1378        static const struct strp_callbacks cb = {
1379                .rcv_msg = kcm_rcv_strparser,
1380                .parse_msg = kcm_parse_func_strparser,
1381                .read_sock_done = kcm_read_sock_done,
1382        };
1383        int err = 0;
1384
1385        csk = csock->sk;
1386        if (!csk)
1387                return -EINVAL;
1388
1389        lock_sock(csk);
1390
1391        /* Only allow TCP sockets to be attached for now */
1392        if ((csk->sk_family != AF_INET && csk->sk_family != AF_INET6) ||
1393            csk->sk_protocol != IPPROTO_TCP) {
1394                err = -EOPNOTSUPP;
1395                goto out;
1396        }
1397
1398        /* Don't allow listeners or closed sockets */
1399        if (csk->sk_state == TCP_LISTEN || csk->sk_state == TCP_CLOSE) {
1400                err = -EOPNOTSUPP;
1401                goto out;
1402        }
1403
1404        psock = kmem_cache_zalloc(kcm_psockp, GFP_KERNEL);
1405        if (!psock) {
1406                err = -ENOMEM;
1407                goto out;
1408        }
1409
1410        psock->mux = mux;
1411        psock->sk = csk;
1412        psock->bpf_prog = prog;
1413
1414        err = strp_init(&psock->strp, csk, &cb);
1415        if (err) {
1416                kmem_cache_free(kcm_psockp, psock);
1417                goto out;
1418        }
1419
1420        write_lock_bh(&csk->sk_callback_lock);
1421
 1422        /* Check if sk_user_data is already in use by KCM or someone else.
1423         * Must be done under lock to prevent race conditions.
1424         */
1425        if (csk->sk_user_data) {
1426                write_unlock_bh(&csk->sk_callback_lock);
1427                strp_stop(&psock->strp);
1428                strp_done(&psock->strp);
1429                kmem_cache_free(kcm_psockp, psock);
1430                err = -EALREADY;
1431                goto out;
1432        }
1433
1434        psock->save_data_ready = csk->sk_data_ready;
1435        psock->save_write_space = csk->sk_write_space;
1436        psock->save_state_change = csk->sk_state_change;
1437        csk->sk_user_data = psock;
1438        csk->sk_data_ready = psock_data_ready;
1439        csk->sk_write_space = psock_write_space;
1440        csk->sk_state_change = psock_state_change;
1441
1442        write_unlock_bh(&csk->sk_callback_lock);
1443
1444        sock_hold(csk);
1445
1446        /* Finished initialization, now add the psock to the MUX. */
1447        spin_lock_bh(&mux->lock);
1448        head = &mux->psocks;
1449        list_for_each_entry(tpsock, &mux->psocks, psock_list) {
1450                if (tpsock->index != index)
1451                        break;
1452                head = &tpsock->psock_list;
1453                index++;
1454        }
1455
1456        list_add(&psock->psock_list, head);
1457        psock->index = index;
1458
1459        KCM_STATS_INCR(mux->stats.psock_attach);
1460        mux->psocks_cnt++;
1461        psock_now_avail(psock);
1462        spin_unlock_bh(&mux->lock);
1463
1464        /* Schedule RX work in case there are already bytes queued */
1465        strp_check_rcv(&psock->strp);
1466
1467out:
1468        release_sock(csk);
1469
1470        return err;
1471}
1472
1473static int kcm_attach_ioctl(struct socket *sock, struct kcm_attach *info)
1474{
1475        struct socket *csock;
1476        struct bpf_prog *prog;
1477        int err;
1478
1479        csock = sockfd_lookup(info->fd, &err);
1480        if (!csock)
1481                return -ENOENT;
1482
1483        prog = bpf_prog_get_type(info->bpf_fd, BPF_PROG_TYPE_SOCKET_FILTER);
1484        if (IS_ERR(prog)) {
1485                err = PTR_ERR(prog);
1486                goto out;
1487        }
1488
1489        err = kcm_attach(sock, csock, prog);
1490        if (err) {
1491                bpf_prog_put(prog);
1492                goto out;
1493        }
1494
1495        /* Keep reference on file also */
1496
1497        return 0;
1498out:
1499        sockfd_put(csock);
1500        return err;
1501}
1502
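    /* Detach a psock from its underlying TCP socket: restore the saved socket
     * callbacks, stop the strparser, drop any held ready message, release the
     * BPF program, and either free the psock now or defer the cleanup via
     * tx_work if it is still reserved for transmit.
     */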
1503static void kcm_unattach(struct kcm_psock *psock)
1504{
1505        struct sock *csk = psock->sk;
1506        struct kcm_mux *mux = psock->mux;
1507
1508        lock_sock(csk);
1509
1510        /* Stop getting callbacks from TCP socket. After this there should
1511         * be no way to reserve a kcm for this psock.
1512         */
1513        write_lock_bh(&csk->sk_callback_lock);
1514        csk->sk_user_data = NULL;
1515        csk->sk_data_ready = psock->save_data_ready;
1516        csk->sk_write_space = psock->save_write_space;
1517        csk->sk_state_change = psock->save_state_change;
1518        strp_stop(&psock->strp);
1519
1520        if (WARN_ON(psock->rx_kcm)) {
1521                write_unlock_bh(&csk->sk_callback_lock);
1522                release_sock(csk);
1523                return;
1524        }
1525
1526        spin_lock_bh(&mux->rx_lock);
1527
1528        /* Stop receiver activities. After this point psock should not be
1529         * able to get onto ready list either through callbacks or work.
1530         */
1531        if (psock->ready_rx_msg) {
1532                list_del(&psock->psock_ready_list);
1533                kfree_skb(psock->ready_rx_msg);
1534                psock->ready_rx_msg = NULL;
1535                KCM_STATS_INCR(mux->stats.rx_ready_drops);
1536        }
1537
1538        spin_unlock_bh(&mux->rx_lock);
1539
1540        write_unlock_bh(&csk->sk_callback_lock);
1541
1542        /* Call strp_done without sock lock */
1543        release_sock(csk);
1544        strp_done(&psock->strp);
1545        lock_sock(csk);
1546
1547        bpf_prog_put(psock->bpf_prog);
1548
1549        spin_lock_bh(&mux->lock);
1550
1551        aggregate_psock_stats(&psock->stats, &mux->aggregate_psock_stats);
1552        save_strp_stats(&psock->strp, &mux->aggregate_strp_stats);
1553
1554        KCM_STATS_INCR(mux->stats.psock_unattach);
1555
1556        if (psock->tx_kcm) {
1557                /* psock was reserved.  Just mark it finished and we will clean
 1558                 * up in the kcm paths; we need the kcm lock, which cannot be
1559                 * acquired here.
1560                 */
1561                KCM_STATS_INCR(mux->stats.psock_unattach_rsvd);
1562                spin_unlock_bh(&mux->lock);
1563
1564                /* We are unattaching a socket that is reserved. Abort the
1565                 * socket since we may be out of sync in sending on it. We need
1566                 * to do this without the mux lock.
1567                 */
1568                kcm_abort_tx_psock(psock, EPIPE, false);
1569
1570                spin_lock_bh(&mux->lock);
1571                if (!psock->tx_kcm) {
 1572                        /* psock was unreserved in the window the mux was unlocked */
1573                        goto no_reserved;
1574                }
1575                psock->done = 1;
1576
1577                /* Commit done before queuing work to process it */
1578                smp_mb();
1579
1580                /* Queue tx work to make sure psock->done is handled */
1581                queue_work(kcm_wq, &psock->tx_kcm->tx_work);
1582                spin_unlock_bh(&mux->lock);
1583        } else {
1584no_reserved:
1585                if (!psock->tx_stopped)
1586                        list_del(&psock->psock_avail_list);
1587                list_del(&psock->psock_list);
1588                mux->psocks_cnt--;
1589                spin_unlock_bh(&mux->lock);
1590
1591                sock_put(csk);
1592                fput(csk->sk_socket->file);
1593                kmem_cache_free(kcm_psockp, psock);
1594        }
1595
1596        release_sock(csk);
1597}
1598
1599static int kcm_unattach_ioctl(struct socket *sock, struct kcm_unattach *info)
1600{
1601        struct kcm_sock *kcm = kcm_sk(sock->sk);
1602        struct kcm_mux *mux = kcm->mux;
1603        struct kcm_psock *psock;
1604        struct socket *csock;
1605        struct sock *csk;
1606        int err;
1607
1608        csock = sockfd_lookup(info->fd, &err);
1609        if (!csock)
1610                return -ENOENT;
1611
1612        csk = csock->sk;
1613        if (!csk) {
1614                err = -EINVAL;
1615                goto out;
1616        }
1617
1618        err = -ENOENT;
1619
1620        spin_lock_bh(&mux->lock);
1621
1622        list_for_each_entry(psock, &mux->psocks, psock_list) {
1623                if (psock->sk != csk)
1624                        continue;
1625
1626                /* Found the matching psock */
1627
1628                if (psock->unattaching || WARN_ON(psock->done)) {
1629                        err = -EALREADY;
1630                        break;
1631                }
1632
1633                psock->unattaching = 1;
1634
1635                spin_unlock_bh(&mux->lock);
1636
1637                /* Lower socket lock should already be held */
1638                kcm_unattach(psock);
1639
1640                err = 0;
1641                goto out;
1642        }
1643
1644        spin_unlock_bh(&mux->lock);
1645
1646out:
1647        sockfd_put(csock);
1648        return err;
1649}
1650
1651static struct proto kcm_proto = {
1652        .name   = "KCM",
1653        .owner  = THIS_MODULE,
1654        .obj_size = sizeof(struct kcm_sock),
1655};
1656
1657/* Clone a kcm socket; the new socket shares the mux of the original. */
1658static struct file *kcm_clone(struct socket *osock)
1659{
1660        struct socket *newsock;
1661        struct sock *newsk;
1662
1663        newsock = sock_alloc();
1664        if (!newsock)
1665                return ERR_PTR(-ENFILE);
1666
1667        newsock->type = osock->type;
1668        newsock->ops = osock->ops;
1669
1670        __module_get(newsock->ops->owner);
1671
1672        newsk = sk_alloc(sock_net(osock->sk), PF_KCM, GFP_KERNEL,
1673                         &kcm_proto, false);
1674        if (!newsk) {
1675                sock_release(newsock);
1676                return ERR_PTR(-ENOMEM);
1677        }
1678        sock_init_data(newsock, newsk);
1679        init_kcm_sock(kcm_sk(newsk), kcm_sk(osock->sk)->mux);
1680
1681        return sock_alloc_file(newsock, 0, osock->sk->sk_prot_creator->name);
1682}
1683
1684static int kcm_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
1685{
1686        int err;
1687
1688        switch (cmd) {
1689        case SIOCKCMATTACH: {
1690                struct kcm_attach info;
1691
1692                if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
1693                        return -EFAULT;
1694
1695                err = kcm_attach_ioctl(sock, &info);
1696
1697                break;
1698        }
1699        case SIOCKCMUNATTACH: {
1700                struct kcm_unattach info;
1701
1702                if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
1703                        return -EFAULT;
1704
1705                err = kcm_unattach_ioctl(sock, &info);
1706
1707                break;
1708        }
1709        case SIOCKCMCLONE: {
1710                struct kcm_clone info;
1711                struct file *file;
1712
1713                info.fd = get_unused_fd_flags(0);
1714                if (unlikely(info.fd < 0))
1715                        return info.fd;
1716
1717                file = kcm_clone(sock);
1718                if (IS_ERR(file)) {
1719                        put_unused_fd(info.fd);
1720                        return PTR_ERR(file);
1721                }
1722                if (copy_to_user((void __user *)arg, &info,
1723                                 sizeof(info))) {
1724                        put_unused_fd(info.fd);
1725                        fput(file);
1726                        return -EFAULT;
1727                }
1728                fd_install(info.fd, file);
1729                err = 0;
1730                break;
1731        }
1732        default:
1733                err = -ENOIOCTLCMD;
1734                break;
1735        }
1736
1737        return err;
1738}
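
    /* Example user-space usage of these ioctls (illustrative sketch only, not
     * part of the kernel sources). It assumes "tcp_fd" is a connected TCP
     * socket and "bpf_fd" is a loaded BPF_PROG_TYPE_SOCKET_FILTER program
     * returning each message's length; both descriptors are placeholders.
     * Needs <sys/socket.h>, <sys/ioctl.h>, <stdio.h> and <linux/kcm.h>.
     *
     *    int kcm_fd = socket(AF_KCM, SOCK_SEQPACKET, KCMPROTO_CONNECTED);
     *    struct kcm_attach attach = { .fd = tcp_fd, .bpf_fd = bpf_fd };
     *    struct kcm_unattach unattach = { .fd = tcp_fd };
     *    struct kcm_clone clone_info;
     *
     *    if (ioctl(kcm_fd, SIOCKCMATTACH, &attach) < 0)
     *            perror("SIOCKCMATTACH");
     *
     *    if (ioctl(kcm_fd, SIOCKCMCLONE, &clone_info) == 0)
     *            printf("cloned KCM socket: fd %d\n", clone_info.fd);
     *
     *    if (ioctl(kcm_fd, SIOCKCMUNATTACH, &unattach) < 0)
     *            perror("SIOCKCMUNATTACH");
     */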
1739
1740static void free_mux(struct rcu_head *rcu)
1741{
1742        struct kcm_mux *mux = container_of(rcu,
1743                                           struct kcm_mux, rcu);
1744
1745        kmem_cache_free(kcm_muxp, mux);
1746}
1747
1748static void release_mux(struct kcm_mux *mux)
1749{
1750        struct kcm_net *knet = mux->knet;
1751        struct kcm_psock *psock, *tmp_psock;
1752
1753        /* Release psocks */
1754        list_for_each_entry_safe(psock, tmp_psock,
1755                                 &mux->psocks, psock_list) {
1756                if (!WARN_ON(psock->unattaching))
1757                        kcm_unattach(psock);
1758        }
1759
1760        if (WARN_ON(mux->psocks_cnt))
1761                return;
1762
1763        __skb_queue_purge(&mux->rx_hold_queue);
1764
1765        mutex_lock(&knet->mutex);
1766        aggregate_mux_stats(&mux->stats, &knet->aggregate_mux_stats);
1767        aggregate_psock_stats(&mux->aggregate_psock_stats,
1768                              &knet->aggregate_psock_stats);
1769        aggregate_strp_stats(&mux->aggregate_strp_stats,
1770                             &knet->aggregate_strp_stats);
1771        list_del_rcu(&mux->kcm_mux_list);
1772        knet->count--;
1773        mutex_unlock(&knet->mutex);
1774
1775        call_rcu(&mux->rcu, free_mux);
1776}
1777
1778static void kcm_done(struct kcm_sock *kcm)
1779{
1780        struct kcm_mux *mux = kcm->mux;
1781        struct sock *sk = &kcm->sk;
1782        int socks_cnt;
1783
1784        spin_lock_bh(&mux->rx_lock);
1785        if (kcm->rx_psock) {
1786                /* Cleanup in unreserve_rx_kcm */
1787                WARN_ON(kcm->done);
1788                kcm->rx_disabled = 1;
1789                kcm->done = 1;
1790                spin_unlock_bh(&mux->rx_lock);
1791                return;
1792        }
1793
1794        if (kcm->rx_wait) {
1795                list_del(&kcm->wait_rx_list);
1796                kcm->rx_wait = false;
1797        }
1798        /* Move any pending receive messages to other kcm sockets */
1799        requeue_rx_msgs(mux, &sk->sk_receive_queue);
1800
1801        spin_unlock_bh(&mux->rx_lock);
1802
1803        if (WARN_ON(sk_rmem_alloc_get(sk)))
1804                return;
1805
1806        /* Detach from MUX */
1807        spin_lock_bh(&mux->lock);
1808
1809        list_del(&kcm->kcm_sock_list);
1810        mux->kcm_socks_cnt--;
1811        socks_cnt = mux->kcm_socks_cnt;
1812
1813        spin_unlock_bh(&mux->lock);
1814
1815        if (!socks_cnt) {
1816                /* We are done with the mux now. */
1817                release_mux(mux);
1818        }
1819
1820        WARN_ON(kcm->rx_wait);
1821
1822        sock_put(&kcm->sk);
1823}
1824
1825/* Release (close) a KCM socket.
1826 * If this is the last KCM socket on the MUX, destroy the MUX.
1827 */
1828static int kcm_release(struct socket *sock)
1829{
1830        struct sock *sk = sock->sk;
1831        struct kcm_sock *kcm;
1832        struct kcm_mux *mux;
1833        struct kcm_psock *psock;
1834
1835        if (!sk)
1836                return 0;
1837
1838        kcm = kcm_sk(sk);
1839        mux = kcm->mux;
1840
1841        sock_orphan(sk);
1842        kfree_skb(kcm->seq_skb);
1843
1844        lock_sock(sk);
1845        /* Purge queue under lock to avoid race condition with tx_work trying
1846         * to act when queue is nonempty. If tx_work runs after this point
1847         * it will just return.
1848         */
1849        __skb_queue_purge(&sk->sk_write_queue);
1850
1851        /* Set tx_stopped. This is checked when psock is bound to a kcm and we
1852         * get a writespace callback. This prevents further work being queued
1853         * from the callback (unbinding the psock occurs after canceling work).
1854         */
1855        kcm->tx_stopped = 1;
1856
1857        release_sock(sk);
1858
1859        spin_lock_bh(&mux->lock);
1860        if (kcm->tx_wait) {
1861                /* Take off the tx_wait list; after this point there should be no way
1862                 * that a psock will be assigned to this kcm.
1863                 */
1864                list_del(&kcm->wait_psock_list);
1865                kcm->tx_wait = false;
1866        }
1867        spin_unlock_bh(&mux->lock);
1868
1869        /* Cancel work. After this point there should be no outside references
1870         * to the kcm socket.
1871         */
1872        cancel_work_sync(&kcm->tx_work);
1873
1874        lock_sock(sk);
1875        psock = kcm->tx_psock;
1876        if (psock) {
1877                /* A psock was reserved, so we need to kill it since it
1878                 * may already have some bytes queued from a message. We
1879                 * need to do this after removing kcm from tx_wait list.
1880                 */
1881                kcm_abort_tx_psock(psock, EPIPE, false);
1882                unreserve_psock(kcm);
1883        }
1884        release_sock(sk);
1885
1886        WARN_ON(kcm->tx_wait);
1887        WARN_ON(kcm->tx_psock);
1888
1889        sock->sk = NULL;
1890
1891        kcm_done(kcm);
1892
1893        return 0;
1894}
1895
1896static const struct proto_ops kcm_dgram_ops = {
1897        .family =       PF_KCM,
1898        .owner =        THIS_MODULE,
1899        .release =      kcm_release,
1900        .bind =         sock_no_bind,
1901        .connect =      sock_no_connect,
1902        .socketpair =   sock_no_socketpair,
1903        .accept =       sock_no_accept,
1904        .getname =      sock_no_getname,
1905        .poll =         datagram_poll,
1906        .ioctl =        kcm_ioctl,
1907        .listen =       sock_no_listen,
1908        .shutdown =     sock_no_shutdown,
1909        .setsockopt =   kcm_setsockopt,
1910        .getsockopt =   kcm_getsockopt,
1911        .sendmsg =      kcm_sendmsg,
1912        .recvmsg =      kcm_recvmsg,
1913        .mmap =         sock_no_mmap,
1914        .sendpage =     kcm_sendpage,
1915};
1916
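    /* The SOCK_SEQPACKET ops below are identical to the SOCK_DGRAM ops above
     * except that they additionally provide splice_read.
     */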
1917static const struct proto_ops kcm_seqpacket_ops = {
1918        .family =       PF_KCM,
1919        .owner =        THIS_MODULE,
1920        .release =      kcm_release,
1921        .bind =         sock_no_bind,
1922        .connect =      sock_no_connect,
1923        .socketpair =   sock_no_socketpair,
1924        .accept =       sock_no_accept,
1925        .getname =      sock_no_getname,
1926        .poll =         datagram_poll,
1927        .ioctl =        kcm_ioctl,
1928        .listen =       sock_no_listen,
1929        .shutdown =     sock_no_shutdown,
1930        .setsockopt =   kcm_setsockopt,
1931        .getsockopt =   kcm_getsockopt,
1932        .sendmsg =      kcm_sendmsg,
1933        .recvmsg =      kcm_recvmsg,
1934        .mmap =         sock_no_mmap,
1935        .sendpage =     kcm_sendpage,
1936        .splice_read =  kcm_splice_read,
1937};
1938
1939/* Create a KCM socket and allocate a new mux for it */
1940static int kcm_create(struct net *net, struct socket *sock,
1941                      int protocol, int kern)
1942{
1943        struct kcm_net *knet = net_generic(net, kcm_net_id);
1944        struct sock *sk;
1945        struct kcm_mux *mux;
1946
1947        switch (sock->type) {
1948        case SOCK_DGRAM:
1949                sock->ops = &kcm_dgram_ops;
1950                break;
1951        case SOCK_SEQPACKET:
1952                sock->ops = &kcm_seqpacket_ops;
1953                break;
1954        default:
1955                return -ESOCKTNOSUPPORT;
1956        }
1957
1958        if (protocol != KCMPROTO_CONNECTED)
1959                return -EPROTONOSUPPORT;
1960
1961        sk = sk_alloc(net, PF_KCM, GFP_KERNEL, &kcm_proto, kern);
1962        if (!sk)
1963                return -ENOMEM;
1964
1965        /* Allocate a kcm mux, shared between KCM sockets */
1966        mux = kmem_cache_zalloc(kcm_muxp, GFP_KERNEL);
1967        if (!mux) {
1968                sk_free(sk);
1969                return -ENOMEM;
1970        }
1971
1972        spin_lock_init(&mux->lock);
1973        spin_lock_init(&mux->rx_lock);
1974        INIT_LIST_HEAD(&mux->kcm_socks);
1975        INIT_LIST_HEAD(&mux->kcm_rx_waiters);
1976        INIT_LIST_HEAD(&mux->kcm_tx_waiters);
1977
1978        INIT_LIST_HEAD(&mux->psocks);
1979        INIT_LIST_HEAD(&mux->psocks_ready);
1980        INIT_LIST_HEAD(&mux->psocks_avail);
1981
1982        mux->knet = knet;
1983
1984        /* Add new MUX to list */
1985        mutex_lock(&knet->mutex);
1986        list_add_rcu(&mux->kcm_mux_list, &knet->mux_list);
1987        knet->count++;
1988        mutex_unlock(&knet->mutex);
1989
1990        skb_queue_head_init(&mux->rx_hold_queue);
1991
1992        /* Init KCM socket */
1993        sock_init_data(sock, sk);
1994        init_kcm_sock(kcm_sk(sk), mux);
1995
1996        return 0;
1997}
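
    /* Note: each socket(AF_KCM, ...) call handled above allocates its own mux;
     * additional KCM sockets on an existing mux are obtained with the
     * SIOCKCMCLONE ioctl (see kcm_clone()), which attaches the new socket to
     * the mux of the original socket.
     */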
1998
1999static const struct net_proto_family kcm_family_ops = {
2000        .family = PF_KCM,
2001        .create = kcm_create,
2002        .owner  = THIS_MODULE,
2003};
2004
2005static __net_init int kcm_init_net(struct net *net)
2006{
2007        struct kcm_net *knet = net_generic(net, kcm_net_id);
2008
2009        INIT_LIST_HEAD_RCU(&knet->mux_list);
2010        mutex_init(&knet->mutex);
2011
2012        return 0;
2013}
2014
2015static __net_exit void kcm_exit_net(struct net *net)
2016{
2017        struct kcm_net *knet = net_generic(net, kcm_net_id);
2018
2019        /* All KCM sockets should be closed at this point, which should mean
2020         * that all multiplexors and psocks have been destroyed.
2021         */
2022        WARN_ON(!list_empty(&knet->mux_list));
2023}
2024
2025static struct pernet_operations kcm_net_ops = {
2026        .init = kcm_init_net,
2027        .exit = kcm_exit_net,
2028        .id   = &kcm_net_id,
2029        .size = sizeof(struct kcm_net),
2030};
2031
2032static int __init kcm_init(void)
2033{
2034        int err = -ENOMEM;
2035
2036        kcm_muxp = kmem_cache_create("kcm_mux_cache",
2037                                     sizeof(struct kcm_mux), 0,
2038                                     SLAB_HWCACHE_ALIGN, NULL);
2039        if (!kcm_muxp)
2040                goto fail;
2041
2042        kcm_psockp = kmem_cache_create("kcm_psock_cache",
2043                                       sizeof(struct kcm_psock), 0,
2044                                       SLAB_HWCACHE_ALIGN, NULL);
2045        if (!kcm_psockp)
2046                goto fail;
2047
2048        kcm_wq = create_singlethread_workqueue("kkcmd");
2049        if (!kcm_wq)
2050                goto fail;
2051
2052        err = proto_register(&kcm_proto, 1);
2053        if (err)
2054                goto fail;
2055
2056        err = register_pernet_device(&kcm_net_ops);
2057        if (err)
2058                goto net_ops_fail;
2059
2060        err = sock_register(&kcm_family_ops);
2061        if (err)
2062                goto sock_register_fail;
2063
2064        err = kcm_proc_init();
2065        if (err)
2066                goto proc_init_fail;
2067
2068        return 0;
2069
2070proc_init_fail:
2071        sock_unregister(PF_KCM);
2072
2073sock_register_fail:
2074        unregister_pernet_device(&kcm_net_ops);
2075
2076net_ops_fail:
2077        proto_unregister(&kcm_proto);
2078
2079fail:
2080        kmem_cache_destroy(kcm_muxp);
2081        kmem_cache_destroy(kcm_psockp);
2082
2083        if (kcm_wq)
2084                destroy_workqueue(kcm_wq);
2085
2086        return err;
2087}
2088
2089static void __exit kcm_exit(void)
2090{
2091        kcm_proc_exit();
2092        sock_unregister(PF_KCM);
2093        unregister_pernet_device(&kcm_net_ops);
2094        proto_unregister(&kcm_proto);
2095        destroy_workqueue(kcm_wq);
2096
2097        kmem_cache_destroy(kcm_muxp);
2098        kmem_cache_destroy(kcm_psockp);
2099}
2100
2101module_init(kcm_init);
2102module_exit(kcm_exit);
2103
2104MODULE_LICENSE("GPL");
2105MODULE_ALIAS_NETPROTO(PF_KCM);
2106