linux/net/kcm/kcmsock.c
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Kernel Connection Multiplexor
   4 *
   5 * Copyright (c) 2016 Tom Herbert <tom@herbertland.com>
   6 */
   7
   8#include <linux/bpf.h>
   9#include <linux/errno.h>
  10#include <linux/errqueue.h>
  11#include <linux/file.h>
  12#include <linux/in.h>
  13#include <linux/kernel.h>
  14#include <linux/module.h>
  15#include <linux/net.h>
  16#include <linux/netdevice.h>
  17#include <linux/poll.h>
  18#include <linux/rculist.h>
  19#include <linux/skbuff.h>
  20#include <linux/socket.h>
  21#include <linux/uaccess.h>
  22#include <linux/workqueue.h>
  23#include <linux/syscalls.h>
  24#include <linux/sched/signal.h>
  25
  26#include <net/kcm.h>
  27#include <net/netns/generic.h>
  28#include <net/sock.h>
  29#include <uapi/linux/kcm.h>
  30
  31unsigned int kcm_net_id;
  32
  33static struct kmem_cache *kcm_psockp __read_mostly;
  34static struct kmem_cache *kcm_muxp __read_mostly;
  35static struct workqueue_struct *kcm_wq;
  36
  37static inline struct kcm_sock *kcm_sk(const struct sock *sk)
  38{
  39        return (struct kcm_sock *)sk;
  40}
  41
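     /* Per-message transmit state (bytes sent, current frag and frag offset)
      * is kept in the head skb's control buffer while the message sits on the
      * write queue.
      */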
  42static inline struct kcm_tx_msg *kcm_tx_msg(struct sk_buff *skb)
  43{
  44        return (struct kcm_tx_msg *)skb->cb;
  45}
  46
  47static void report_csk_error(struct sock *csk, int err)
  48{
  49        csk->sk_err = EPIPE;
  50        csk->sk_error_report(csk);
  51}
  52
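     /* Mark a psock's transmit side as failed: take it off the available list
      * (or kick the reserving KCM's tx_work so the failure is handled there)
      * and report the error on the underlying socket.
      */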
  53static void kcm_abort_tx_psock(struct kcm_psock *psock, int err,
  54                               bool wakeup_kcm)
  55{
  56        struct sock *csk = psock->sk;
  57        struct kcm_mux *mux = psock->mux;
  58
  59        /* Unrecoverable error in transmit */
  60
  61        spin_lock_bh(&mux->lock);
  62
  63        if (psock->tx_stopped) {
  64                spin_unlock_bh(&mux->lock);
  65                return;
  66        }
  67
  68        psock->tx_stopped = 1;
  69        KCM_STATS_INCR(psock->stats.tx_aborts);
  70
  71        if (!psock->tx_kcm) {
  72                /* Take off psocks_avail list */
  73                list_del(&psock->psock_avail_list);
  74        } else if (wakeup_kcm) {
  75                /* In this case psock is being aborted while outside of
  76                 * write_msgs and psock is reserved. Schedule tx_work
  77                 * to handle the failure there. Need to commit tx_stopped
  78                 * before queuing work.
  79                 */
  80                smp_mb();
  81
  82                queue_work(kcm_wq, &psock->tx_kcm->tx_work);
  83        }
  84
  85        spin_unlock_bh(&mux->lock);
  86
  87        /* Report error on lower socket */
  88        report_csk_error(csk, err);
  89}
  90
  91/* RX mux lock held. */
  92static void kcm_update_rx_mux_stats(struct kcm_mux *mux,
  93                                    struct kcm_psock *psock)
  94{
  95        STRP_STATS_ADD(mux->stats.rx_bytes,
  96                       psock->strp.stats.bytes -
  97                       psock->saved_rx_bytes);
  98        mux->stats.rx_msgs +=
  99                psock->strp.stats.msgs - psock->saved_rx_msgs;
 100        psock->saved_rx_msgs = psock->strp.stats.msgs;
 101        psock->saved_rx_bytes = psock->strp.stats.bytes;
 102}
 103
 104static void kcm_update_tx_mux_stats(struct kcm_mux *mux,
 105                                    struct kcm_psock *psock)
 106{
 107        KCM_STATS_ADD(mux->stats.tx_bytes,
 108                      psock->stats.tx_bytes - psock->saved_tx_bytes);
 109        mux->stats.tx_msgs +=
 110                psock->stats.tx_msgs - psock->saved_tx_msgs;
 111        psock->saved_tx_msgs = psock->stats.tx_msgs;
 112        psock->saved_tx_bytes = psock->stats.tx_bytes;
 113}
 114
 115static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb);
 116
  117/* KCM is ready to receive messages on its queue: either the KCM is new or
  118 * has become unblocked after being blocked on a full socket buffer. Queue any
 119 * pending ready messages on a psock. RX mux lock held.
 120 */
 121static void kcm_rcv_ready(struct kcm_sock *kcm)
 122{
 123        struct kcm_mux *mux = kcm->mux;
 124        struct kcm_psock *psock;
 125        struct sk_buff *skb;
 126
 127        if (unlikely(kcm->rx_wait || kcm->rx_psock || kcm->rx_disabled))
 128                return;
 129
 130        while (unlikely((skb = __skb_dequeue(&mux->rx_hold_queue)))) {
 131                if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
 132                        /* Assuming buffer limit has been reached */
 133                        skb_queue_head(&mux->rx_hold_queue, skb);
 134                        WARN_ON(!sk_rmem_alloc_get(&kcm->sk));
 135                        return;
 136                }
 137        }
 138
 139        while (!list_empty(&mux->psocks_ready)) {
 140                psock = list_first_entry(&mux->psocks_ready, struct kcm_psock,
 141                                         psock_ready_list);
 142
 143                if (kcm_queue_rcv_skb(&kcm->sk, psock->ready_rx_msg)) {
 144                        /* Assuming buffer limit has been reached */
 145                        WARN_ON(!sk_rmem_alloc_get(&kcm->sk));
 146                        return;
 147                }
 148
 149                /* Consumed the ready message on the psock. Schedule rx_work to
 150                 * get more messages.
 151                 */
 152                list_del(&psock->psock_ready_list);
 153                psock->ready_rx_msg = NULL;
 154                /* Commit clearing of ready_rx_msg for queuing work */
 155                smp_mb();
 156
 157                strp_unpause(&psock->strp);
 158                strp_check_rcv(&psock->strp);
 159        }
 160
 161        /* Buffer limit is okay now, add to ready list */
 162        list_add_tail(&kcm->wait_rx_list,
 163                      &kcm->mux->kcm_rx_waiters);
 164        kcm->rx_wait = true;
 165}
 166
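     /* Destructor for skbs charged to a KCM socket's receive queue. Uncharges
      * the memory and, if the KCM is neither waiting nor reserved and its
      * receive allocation has dropped below sk_rcvlowat, makes the KCM ready
      * to receive again via kcm_rcv_ready().
      */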
 167static void kcm_rfree(struct sk_buff *skb)
 168{
 169        struct sock *sk = skb->sk;
 170        struct kcm_sock *kcm = kcm_sk(sk);
 171        struct kcm_mux *mux = kcm->mux;
 172        unsigned int len = skb->truesize;
 173
 174        sk_mem_uncharge(sk, len);
 175        atomic_sub(len, &sk->sk_rmem_alloc);
 176
 177        /* For reading rx_wait and rx_psock without holding lock */
 178        smp_mb__after_atomic();
 179
 180        if (!kcm->rx_wait && !kcm->rx_psock &&
 181            sk_rmem_alloc_get(sk) < sk->sk_rcvlowat) {
 182                spin_lock_bh(&mux->rx_lock);
 183                kcm_rcv_ready(kcm);
 184                spin_unlock_bh(&mux->rx_lock);
 185        }
 186}
 187
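     /* Charge an skb to a KCM socket's receive queue, enforcing sk_rcvbuf and
      * socket memory accounting. Returns -ENOMEM or -ENOBUFS when limits are
      * hit so callers can hold or requeue the message instead.
      */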
 188static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 189{
 190        struct sk_buff_head *list = &sk->sk_receive_queue;
 191
 192        if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf)
 193                return -ENOMEM;
 194
 195        if (!sk_rmem_schedule(sk, skb, skb->truesize))
 196                return -ENOBUFS;
 197
 198        skb->dev = NULL;
 199
 200        skb_orphan(skb);
 201        skb->sk = sk;
 202        skb->destructor = kcm_rfree;
 203        atomic_add(skb->truesize, &sk->sk_rmem_alloc);
 204        sk_mem_charge(sk, skb->truesize);
 205
 206        skb_queue_tail(list, skb);
 207
 208        if (!sock_flag(sk, SOCK_DEAD))
 209                sk->sk_data_ready(sk);
 210
 211        return 0;
 212}
 213
 214/* Requeue received messages for a kcm socket to other kcm sockets. This is
  215 * called when a kcm socket is receive disabled.
 216 * RX mux lock held.
 217 */
 218static void requeue_rx_msgs(struct kcm_mux *mux, struct sk_buff_head *head)
 219{
 220        struct sk_buff *skb;
 221        struct kcm_sock *kcm;
 222
 223        while ((skb = __skb_dequeue(head))) {
 224                /* Reset destructor to avoid calling kcm_rcv_ready */
 225                skb->destructor = sock_rfree;
 226                skb_orphan(skb);
 227try_again:
 228                if (list_empty(&mux->kcm_rx_waiters)) {
 229                        skb_queue_tail(&mux->rx_hold_queue, skb);
 230                        continue;
 231                }
 232
 233                kcm = list_first_entry(&mux->kcm_rx_waiters,
 234                                       struct kcm_sock, wait_rx_list);
 235
 236                if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
 237                        /* Should mean socket buffer full */
 238                        list_del(&kcm->wait_rx_list);
 239                        kcm->rx_wait = false;
 240
 241                        /* Commit rx_wait to read in kcm_free */
 242                        smp_wmb();
 243
 244                        goto try_again;
 245                }
 246        }
 247}
 248
 249/* Lower sock lock held */
 250static struct kcm_sock *reserve_rx_kcm(struct kcm_psock *psock,
 251                                       struct sk_buff *head)
 252{
 253        struct kcm_mux *mux = psock->mux;
 254        struct kcm_sock *kcm;
 255
 256        WARN_ON(psock->ready_rx_msg);
 257
 258        if (psock->rx_kcm)
 259                return psock->rx_kcm;
 260
 261        spin_lock_bh(&mux->rx_lock);
 262
 263        if (psock->rx_kcm) {
 264                spin_unlock_bh(&mux->rx_lock);
 265                return psock->rx_kcm;
 266        }
 267
 268        kcm_update_rx_mux_stats(mux, psock);
 269
 270        if (list_empty(&mux->kcm_rx_waiters)) {
 271                psock->ready_rx_msg = head;
 272                strp_pause(&psock->strp);
 273                list_add_tail(&psock->psock_ready_list,
 274                              &mux->psocks_ready);
 275                spin_unlock_bh(&mux->rx_lock);
 276                return NULL;
 277        }
 278
 279        kcm = list_first_entry(&mux->kcm_rx_waiters,
 280                               struct kcm_sock, wait_rx_list);
 281        list_del(&kcm->wait_rx_list);
 282        kcm->rx_wait = false;
 283
 284        psock->rx_kcm = kcm;
 285        kcm->rx_psock = psock;
 286
 287        spin_unlock_bh(&mux->rx_lock);
 288
 289        return kcm;
 290}
 291
 292static void kcm_done(struct kcm_sock *kcm);
 293
 294static void kcm_done_work(struct work_struct *w)
 295{
 296        kcm_done(container_of(w, struct kcm_sock, done_work));
 297}
 298
 299/* Lower sock held */
 300static void unreserve_rx_kcm(struct kcm_psock *psock,
 301                             bool rcv_ready)
 302{
 303        struct kcm_sock *kcm = psock->rx_kcm;
 304        struct kcm_mux *mux = psock->mux;
 305
 306        if (!kcm)
 307                return;
 308
 309        spin_lock_bh(&mux->rx_lock);
 310
 311        psock->rx_kcm = NULL;
 312        kcm->rx_psock = NULL;
 313
 314        /* Commit kcm->rx_psock before sk_rmem_alloc_get to sync with
 315         * kcm_rfree
 316         */
 317        smp_mb();
 318
 319        if (unlikely(kcm->done)) {
 320                spin_unlock_bh(&mux->rx_lock);
 321
  322                /* Need to run kcm_done in a task since we need to acquire
 323                 * callback locks which may already be held here.
 324                 */
 325                INIT_WORK(&kcm->done_work, kcm_done_work);
 326                schedule_work(&kcm->done_work);
 327                return;
 328        }
 329
 330        if (unlikely(kcm->rx_disabled)) {
 331                requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue);
 332        } else if (rcv_ready || unlikely(!sk_rmem_alloc_get(&kcm->sk))) {
  333                /* Check for the degenerate race with rx_wait where all
  334                 * data was already dequeued (accounted for in kcm_rfree).
 335                 */
 336                kcm_rcv_ready(kcm);
 337        }
 338        spin_unlock_bh(&mux->rx_lock);
 339}
 340
 341/* Lower sock lock held */
 342static void psock_data_ready(struct sock *sk)
 343{
 344        struct kcm_psock *psock;
 345
 346        read_lock_bh(&sk->sk_callback_lock);
 347
 348        psock = (struct kcm_psock *)sk->sk_user_data;
 349        if (likely(psock))
 350                strp_data_ready(&psock->strp);
 351
 352        read_unlock_bh(&sk->sk_callback_lock);
 353}
 354
 355/* Called with lower sock held */
 356static void kcm_rcv_strparser(struct strparser *strp, struct sk_buff *skb)
 357{
 358        struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
 359        struct kcm_sock *kcm;
 360
 361try_queue:
 362        kcm = reserve_rx_kcm(psock, skb);
 363        if (!kcm) {
 364                 /* Unable to reserve a KCM, message is held in psock and strp
 365                  * is paused.
 366                  */
 367                return;
 368        }
 369
 370        if (kcm_queue_rcv_skb(&kcm->sk, skb)) {
 371                /* Should mean socket buffer full */
 372                unreserve_rx_kcm(psock, false);
 373                goto try_queue;
 374        }
 375}
 376
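     /* strparser parse_msg callback: run the attached BPF program on the skb
      * to determine the length of the next message in the byte stream.
      */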
 377static int kcm_parse_func_strparser(struct strparser *strp, struct sk_buff *skb)
 378{
 379        struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
 380        struct bpf_prog *prog = psock->bpf_prog;
 381        int res;
 382
 383        preempt_disable();
 384        res = BPF_PROG_RUN(prog, skb);
 385        preempt_enable();
 386        return res;
 387}
 388
 389static int kcm_read_sock_done(struct strparser *strp, int err)
 390{
 391        struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp);
 392
 393        unreserve_rx_kcm(psock, true);
 394
 395        return err;
 396}
 397
 398static void psock_state_change(struct sock *sk)
 399{
  400        /* TCP only does an EPOLLIN for a half close. Do an EPOLLHUP here
  401         * since the application will normally not poll with EPOLLIN
 402         * on the TCP sockets.
 403         */
 404
 405        report_csk_error(sk, EPIPE);
 406}
 407
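     /* Write-space callback on the lower socket. If a KCM currently has this
      * psock reserved for transmit, schedule its tx_work to resume sending.
      */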
 408static void psock_write_space(struct sock *sk)
 409{
 410        struct kcm_psock *psock;
 411        struct kcm_mux *mux;
 412        struct kcm_sock *kcm;
 413
 414        read_lock_bh(&sk->sk_callback_lock);
 415
 416        psock = (struct kcm_psock *)sk->sk_user_data;
 417        if (unlikely(!psock))
 418                goto out;
 419        mux = psock->mux;
 420
 421        spin_lock_bh(&mux->lock);
 422
  423        /* Check if the socket is reserved; if so, a KCM is waiting to send on it. */
 424        kcm = psock->tx_kcm;
 425        if (kcm && !unlikely(kcm->tx_stopped))
 426                queue_work(kcm_wq, &kcm->tx_work);
 427
 428        spin_unlock_bh(&mux->lock);
 429out:
 430        read_unlock_bh(&sk->sk_callback_lock);
 431}
 432
 433static void unreserve_psock(struct kcm_sock *kcm);
 434
 435/* kcm sock is locked. */
 436static struct kcm_psock *reserve_psock(struct kcm_sock *kcm)
 437{
 438        struct kcm_mux *mux = kcm->mux;
 439        struct kcm_psock *psock;
 440
 441        psock = kcm->tx_psock;
 442
 443        smp_rmb(); /* Must read tx_psock before tx_wait */
 444
 445        if (psock) {
 446                WARN_ON(kcm->tx_wait);
 447                if (unlikely(psock->tx_stopped))
 448                        unreserve_psock(kcm);
 449                else
 450                        return kcm->tx_psock;
 451        }
 452
 453        spin_lock_bh(&mux->lock);
 454
  455        /* Check again under lock to see if a psock was already reserved for
  456         * this kcm (e.g. via psock_now_avail()).
 457         */
 458        psock = kcm->tx_psock;
 459        if (unlikely(psock)) {
 460                WARN_ON(kcm->tx_wait);
 461                spin_unlock_bh(&mux->lock);
 462                return kcm->tx_psock;
 463        }
 464
 465        if (!list_empty(&mux->psocks_avail)) {
 466                psock = list_first_entry(&mux->psocks_avail,
 467                                         struct kcm_psock,
 468                                         psock_avail_list);
 469                list_del(&psock->psock_avail_list);
 470                if (kcm->tx_wait) {
 471                        list_del(&kcm->wait_psock_list);
 472                        kcm->tx_wait = false;
 473                }
 474                kcm->tx_psock = psock;
 475                psock->tx_kcm = kcm;
 476                KCM_STATS_INCR(psock->stats.reserved);
 477        } else if (!kcm->tx_wait) {
 478                list_add_tail(&kcm->wait_psock_list,
 479                              &mux->kcm_tx_waiters);
 480                kcm->tx_wait = true;
 481        }
 482
 483        spin_unlock_bh(&mux->lock);
 484
 485        return psock;
 486}
 487
 488/* mux lock held */
 489static void psock_now_avail(struct kcm_psock *psock)
 490{
 491        struct kcm_mux *mux = psock->mux;
 492        struct kcm_sock *kcm;
 493
 494        if (list_empty(&mux->kcm_tx_waiters)) {
 495                list_add_tail(&psock->psock_avail_list,
 496                              &mux->psocks_avail);
 497        } else {
 498                kcm = list_first_entry(&mux->kcm_tx_waiters,
 499                                       struct kcm_sock,
 500                                       wait_psock_list);
 501                list_del(&kcm->wait_psock_list);
 502                kcm->tx_wait = false;
 503                psock->tx_kcm = kcm;
 504
 505                /* Commit before changing tx_psock since that is read in
 506                 * reserve_psock before queuing work.
 507                 */
 508                smp_mb();
 509
 510                kcm->tx_psock = psock;
 511                KCM_STATS_INCR(psock->stats.reserved);
 512                queue_work(kcm_wq, &kcm->tx_work);
 513        }
 514}
 515
 516/* kcm sock is locked. */
 517static void unreserve_psock(struct kcm_sock *kcm)
 518{
 519        struct kcm_psock *psock;
 520        struct kcm_mux *mux = kcm->mux;
 521
 522        spin_lock_bh(&mux->lock);
 523
 524        psock = kcm->tx_psock;
 525
 526        if (WARN_ON(!psock)) {
 527                spin_unlock_bh(&mux->lock);
 528                return;
 529        }
 530
 531        smp_rmb(); /* Read tx_psock before tx_wait */
 532
 533        kcm_update_tx_mux_stats(mux, psock);
 534
 535        WARN_ON(kcm->tx_wait);
 536
 537        kcm->tx_psock = NULL;
 538        psock->tx_kcm = NULL;
 539        KCM_STATS_INCR(psock->stats.unreserved);
 540
 541        if (unlikely(psock->tx_stopped)) {
 542                if (psock->done) {
 543                        /* Deferred free */
 544                        list_del(&psock->psock_list);
 545                        mux->psocks_cnt--;
 546                        sock_put(psock->sk);
 547                        fput(psock->sk->sk_socket->file);
 548                        kmem_cache_free(kcm_psockp, psock);
 549                }
 550
 551                /* Don't put back on available list */
 552
 553                spin_unlock_bh(&mux->lock);
 554
 555                return;
 556        }
 557
 558        psock_now_avail(psock);
 559
 560        spin_unlock_bh(&mux->lock);
 561}
 562
 563static void kcm_report_tx_retry(struct kcm_sock *kcm)
 564{
 565        struct kcm_mux *mux = kcm->mux;
 566
 567        spin_lock_bh(&mux->lock);
 568        KCM_STATS_INCR(mux->stats.tx_retries);
 569        spin_unlock_bh(&mux->lock);
 570}
 571
 572/* Write any messages ready on the kcm socket.  Called with kcm sock lock
 573 * held.  Return bytes actually sent or error.
 574 */
 575static int kcm_write_msgs(struct kcm_sock *kcm)
 576{
 577        struct sock *sk = &kcm->sk;
 578        struct kcm_psock *psock;
 579        struct sk_buff *skb, *head;
 580        struct kcm_tx_msg *txm;
 581        unsigned short fragidx, frag_offset;
 582        unsigned int sent, total_sent = 0;
 583        int ret = 0;
 584
 585        kcm->tx_wait_more = false;
 586        psock = kcm->tx_psock;
 587        if (unlikely(psock && psock->tx_stopped)) {
 588                /* A reserved psock was aborted asynchronously. Unreserve
 589                 * it and we'll retry the message.
 590                 */
 591                unreserve_psock(kcm);
 592                kcm_report_tx_retry(kcm);
 593                if (skb_queue_empty(&sk->sk_write_queue))
 594                        return 0;
 595
 596                kcm_tx_msg(skb_peek(&sk->sk_write_queue))->sent = 0;
 597
 598        } else if (skb_queue_empty(&sk->sk_write_queue)) {
 599                return 0;
 600        }
 601
 602        head = skb_peek(&sk->sk_write_queue);
 603        txm = kcm_tx_msg(head);
 604
 605        if (txm->sent) {
 606                /* Send of first skbuff in queue already in progress */
 607                if (WARN_ON(!psock)) {
 608                        ret = -EINVAL;
 609                        goto out;
 610                }
 611                sent = txm->sent;
 612                frag_offset = txm->frag_offset;
 613                fragidx = txm->fragidx;
 614                skb = txm->frag_skb;
 615
 616                goto do_frag;
 617        }
 618
 619try_again:
 620        psock = reserve_psock(kcm);
 621        if (!psock)
 622                goto out;
 623
 624        do {
 625                skb = head;
 626                txm = kcm_tx_msg(head);
 627                sent = 0;
 628
 629do_frag_list:
 630                if (WARN_ON(!skb_shinfo(skb)->nr_frags)) {
 631                        ret = -EINVAL;
 632                        goto out;
 633                }
 634
 635                for (fragidx = 0; fragidx < skb_shinfo(skb)->nr_frags;
 636                     fragidx++) {
 637                        skb_frag_t *frag;
 638
 639                        frag_offset = 0;
 640do_frag:
 641                        frag = &skb_shinfo(skb)->frags[fragidx];
 642                        if (WARN_ON(!skb_frag_size(frag))) {
 643                                ret = -EINVAL;
 644                                goto out;
 645                        }
 646
 647                        ret = kernel_sendpage(psock->sk->sk_socket,
 648                                              skb_frag_page(frag),
 649                                              skb_frag_off(frag) + frag_offset,
 650                                              skb_frag_size(frag) - frag_offset,
 651                                              MSG_DONTWAIT);
 652                        if (ret <= 0) {
 653                                if (ret == -EAGAIN) {
 654                                        /* Save state to try again when there's
 655                                         * write space on the socket
 656                                         */
 657                                        txm->sent = sent;
 658                                        txm->frag_offset = frag_offset;
 659                                        txm->fragidx = fragidx;
 660                                        txm->frag_skb = skb;
 661
 662                                        ret = 0;
 663                                        goto out;
 664                                }
 665
 666                                /* Hard failure in sending message, abort this
 667                                 * psock since it has lost framing
  668                                 * synchronization and retry sending the
 669                                 * message from the beginning.
 670                                 */
 671                                kcm_abort_tx_psock(psock, ret ? -ret : EPIPE,
 672                                                   true);
 673                                unreserve_psock(kcm);
 674
 675                                txm->sent = 0;
 676                                kcm_report_tx_retry(kcm);
 677                                ret = 0;
 678
 679                                goto try_again;
 680                        }
 681
 682                        sent += ret;
 683                        frag_offset += ret;
 684                        KCM_STATS_ADD(psock->stats.tx_bytes, ret);
 685                        if (frag_offset < skb_frag_size(frag)) {
 686                                /* Not finished with this frag */
 687                                goto do_frag;
 688                        }
 689                }
 690
 691                if (skb == head) {
 692                        if (skb_has_frag_list(skb)) {
 693                                skb = skb_shinfo(skb)->frag_list;
 694                                goto do_frag_list;
 695                        }
 696                } else if (skb->next) {
 697                        skb = skb->next;
 698                        goto do_frag_list;
 699                }
 700
 701                /* Successfully sent the whole packet, account for it. */
 702                skb_dequeue(&sk->sk_write_queue);
 703                kfree_skb(head);
 704                sk->sk_wmem_queued -= sent;
 705                total_sent += sent;
 706                KCM_STATS_INCR(psock->stats.tx_msgs);
 707        } while ((head = skb_peek(&sk->sk_write_queue)));
 708out:
 709        if (!head) {
 710                /* Done with all queued messages. */
 711                WARN_ON(!skb_queue_empty(&sk->sk_write_queue));
 712                unreserve_psock(kcm);
 713        }
 714
 715        /* Check if write space is available */
 716        sk->sk_write_space(sk);
 717
 718        return total_sent ? : ret;
 719}
 720
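     /* Workqueue handler for a KCM socket's tx_work: retry kcm_write_msgs()
      * (e.g. after write space becomes available or a reserved psock was
      * aborted) and wake any writer blocked on SOCK_NOSPACE.
      */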
 721static void kcm_tx_work(struct work_struct *w)
 722{
 723        struct kcm_sock *kcm = container_of(w, struct kcm_sock, tx_work);
 724        struct sock *sk = &kcm->sk;
 725        int err;
 726
 727        lock_sock(sk);
 728
 729        /* Primarily for SOCK_DGRAM sockets, also handle asynchronous tx
 730         * aborts
 731         */
 732        err = kcm_write_msgs(kcm);
 733        if (err < 0) {
 734                /* Hard failure in write, report error on KCM socket */
 735                pr_warn("KCM: Hard failure on kcm_write_msgs %d\n", err);
 736                report_csk_error(&kcm->sk, -err);
 737                goto out;
 738        }
 739
 740        /* Primarily for SOCK_SEQPACKET sockets */
 741        if (likely(sk->sk_socket) &&
 742            test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) {
 743                clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 744                sk->sk_write_space(sk);
 745        }
 746
 747out:
 748        release_sock(sk);
 749}
 750
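     /* Flush messages that were held back for batching (tx_wait_more is set
      * when MSG_BATCH was used on the previous send).
      */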
 751static void kcm_push(struct kcm_sock *kcm)
 752{
 753        if (kcm->tx_wait_more)
 754                kcm_write_msgs(kcm);
 755}
 756
 757static ssize_t kcm_sendpage(struct socket *sock, struct page *page,
 758                            int offset, size_t size, int flags)
 759
 760{
 761        struct sock *sk = sock->sk;
 762        struct kcm_sock *kcm = kcm_sk(sk);
 763        struct sk_buff *skb = NULL, *head = NULL;
 764        long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
 765        bool eor;
 766        int err = 0;
 767        int i;
 768
 769        if (flags & MSG_SENDPAGE_NOTLAST)
 770                flags |= MSG_MORE;
 771
 772        /* No MSG_EOR from splice, only look at MSG_MORE */
 773        eor = !(flags & MSG_MORE);
 774
 775        lock_sock(sk);
 776
 777        sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 778
 779        err = -EPIPE;
 780        if (sk->sk_err)
 781                goto out_error;
 782
 783        if (kcm->seq_skb) {
 784                /* Previously opened message */
 785                head = kcm->seq_skb;
 786                skb = kcm_tx_msg(head)->last_skb;
 787                i = skb_shinfo(skb)->nr_frags;
 788
 789                if (skb_can_coalesce(skb, i, page, offset)) {
 790                        skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size);
 791                        skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
 792                        goto coalesced;
 793                }
 794
 795                if (i >= MAX_SKB_FRAGS) {
 796                        struct sk_buff *tskb;
 797
 798                        tskb = alloc_skb(0, sk->sk_allocation);
 799                        while (!tskb) {
 800                                kcm_push(kcm);
 801                                err = sk_stream_wait_memory(sk, &timeo);
 802                                if (err)
 803                                        goto out_error;
 804                        }
 805
 806                        if (head == skb)
 807                                skb_shinfo(head)->frag_list = tskb;
 808                        else
 809                                skb->next = tskb;
 810
 811                        skb = tskb;
 812                        skb->ip_summed = CHECKSUM_UNNECESSARY;
 813                        i = 0;
 814                }
 815        } else {
 816                /* Call the sk_stream functions to manage the sndbuf mem. */
 817                if (!sk_stream_memory_free(sk)) {
 818                        kcm_push(kcm);
 819                        set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 820                        err = sk_stream_wait_memory(sk, &timeo);
 821                        if (err)
 822                                goto out_error;
 823                }
 824
 825                head = alloc_skb(0, sk->sk_allocation);
 826                while (!head) {
 827                        kcm_push(kcm);
 828                        err = sk_stream_wait_memory(sk, &timeo);
 829                        if (err)
 830                                goto out_error;
 831                }
 832
 833                skb = head;
 834                i = 0;
 835        }
 836
 837        get_page(page);
 838        skb_fill_page_desc(skb, i, page, offset, size);
 839        skb_shinfo(skb)->tx_flags |= SKBTX_SHARED_FRAG;
 840
 841coalesced:
 842        skb->len += size;
 843        skb->data_len += size;
 844        skb->truesize += size;
 845        sk->sk_wmem_queued += size;
 846        sk_mem_charge(sk, size);
 847
 848        if (head != skb) {
 849                head->len += size;
 850                head->data_len += size;
 851                head->truesize += size;
 852        }
 853
 854        if (eor) {
 855                bool not_busy = skb_queue_empty(&sk->sk_write_queue);
 856
 857                /* Message complete, queue it on send buffer */
 858                __skb_queue_tail(&sk->sk_write_queue, head);
 859                kcm->seq_skb = NULL;
 860                KCM_STATS_INCR(kcm->stats.tx_msgs);
 861
 862                if (flags & MSG_BATCH) {
 863                        kcm->tx_wait_more = true;
 864                } else if (kcm->tx_wait_more || not_busy) {
 865                        err = kcm_write_msgs(kcm);
 866                        if (err < 0) {
 867                                /* We got a hard error in write_msgs but have
 868                                 * already queued this message. Report an error
 869                                 * in the socket, but don't affect return value
 870                                 * from sendmsg
 871                                 */
 872                                pr_warn("KCM: Hard failure on kcm_write_msgs\n");
 873                                report_csk_error(&kcm->sk, -err);
 874                        }
 875                }
 876        } else {
 877                /* Message not complete, save state */
 878                kcm->seq_skb = head;
 879                kcm_tx_msg(head)->last_skb = skb;
 880        }
 881
 882        KCM_STATS_ADD(kcm->stats.tx_bytes, size);
 883
 884        release_sock(sk);
 885        return size;
 886
 887out_error:
 888        kcm_push(kcm);
 889
 890        err = sk_stream_error(sk, flags, err);
 891
 892        /* make sure we wake any epoll edge trigger waiter */
 893        if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
 894                sk->sk_write_space(sk);
 895
 896        release_sock(sk);
 897        return err;
 898}
 899
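     /* Build a KCM message from the user's iov: data is copied into page
      * frags on a head skb, chaining extra skbs on the frag list once
      * MAX_SKB_FRAGS is reached. When the message is complete (MSG_EOR, or
      * no MSG_MORE for SOCK_DGRAM) it is queued and kcm_write_msgs() is
      * kicked unless MSG_BATCH defers transmission.
      */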
 900static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 901{
 902        struct sock *sk = sock->sk;
 903        struct kcm_sock *kcm = kcm_sk(sk);
 904        struct sk_buff *skb = NULL, *head = NULL;
 905        size_t copy, copied = 0;
 906        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 907        int eor = (sock->type == SOCK_DGRAM) ?
 908                  !(msg->msg_flags & MSG_MORE) : !!(msg->msg_flags & MSG_EOR);
 909        int err = -EPIPE;
 910
 911        lock_sock(sk);
 912
 913        /* Per tcp_sendmsg this should be in poll */
 914        sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 915
 916        if (sk->sk_err)
 917                goto out_error;
 918
 919        if (kcm->seq_skb) {
 920                /* Previously opened message */
 921                head = kcm->seq_skb;
 922                skb = kcm_tx_msg(head)->last_skb;
 923                goto start;
 924        }
 925
 926        /* Call the sk_stream functions to manage the sndbuf mem. */
 927        if (!sk_stream_memory_free(sk)) {
 928                kcm_push(kcm);
 929                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 930                err = sk_stream_wait_memory(sk, &timeo);
 931                if (err)
 932                        goto out_error;
 933        }
 934
 935        if (msg_data_left(msg)) {
 936                /* New message, alloc head skb */
 937                head = alloc_skb(0, sk->sk_allocation);
 938                while (!head) {
 939                        kcm_push(kcm);
 940                        err = sk_stream_wait_memory(sk, &timeo);
 941                        if (err)
 942                                goto out_error;
 943
 944                        head = alloc_skb(0, sk->sk_allocation);
 945                }
 946
 947                skb = head;
 948
 949                /* Set ip_summed to CHECKSUM_UNNECESSARY to avoid calling
 950                 * csum_and_copy_from_iter from skb_do_copy_data_nocache.
 951                 */
 952                skb->ip_summed = CHECKSUM_UNNECESSARY;
 953        }
 954
 955start:
 956        while (msg_data_left(msg)) {
 957                bool merge = true;
 958                int i = skb_shinfo(skb)->nr_frags;
 959                struct page_frag *pfrag = sk_page_frag(sk);
 960
 961                if (!sk_page_frag_refill(sk, pfrag))
 962                        goto wait_for_memory;
 963
 964                if (!skb_can_coalesce(skb, i, pfrag->page,
 965                                      pfrag->offset)) {
 966                        if (i == MAX_SKB_FRAGS) {
 967                                struct sk_buff *tskb;
 968
 969                                tskb = alloc_skb(0, sk->sk_allocation);
 970                                if (!tskb)
 971                                        goto wait_for_memory;
 972
 973                                if (head == skb)
 974                                        skb_shinfo(head)->frag_list = tskb;
 975                                else
 976                                        skb->next = tskb;
 977
 978                                skb = tskb;
 979                                skb->ip_summed = CHECKSUM_UNNECESSARY;
 980                                continue;
 981                        }
 982                        merge = false;
 983                }
 984
 985                copy = min_t(int, msg_data_left(msg),
 986                             pfrag->size - pfrag->offset);
 987
 988                if (!sk_wmem_schedule(sk, copy))
 989                        goto wait_for_memory;
 990
 991                err = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb,
 992                                               pfrag->page,
 993                                               pfrag->offset,
 994                                               copy);
 995                if (err)
 996                        goto out_error;
 997
 998                /* Update the skb. */
 999                if (merge) {
1000                        skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy);
1001                } else {
1002                        skb_fill_page_desc(skb, i, pfrag->page,
1003                                           pfrag->offset, copy);
1004                        get_page(pfrag->page);
1005                }
1006
1007                pfrag->offset += copy;
1008                copied += copy;
1009                if (head != skb) {
1010                        head->len += copy;
1011                        head->data_len += copy;
1012                }
1013
1014                continue;
1015
1016wait_for_memory:
1017                kcm_push(kcm);
1018                err = sk_stream_wait_memory(sk, &timeo);
1019                if (err)
1020                        goto out_error;
1021        }
1022
1023        if (eor) {
1024                bool not_busy = skb_queue_empty(&sk->sk_write_queue);
1025
1026                if (head) {
1027                        /* Message complete, queue it on send buffer */
1028                        __skb_queue_tail(&sk->sk_write_queue, head);
1029                        kcm->seq_skb = NULL;
1030                        KCM_STATS_INCR(kcm->stats.tx_msgs);
1031                }
1032
1033                if (msg->msg_flags & MSG_BATCH) {
1034                        kcm->tx_wait_more = true;
1035                } else if (kcm->tx_wait_more || not_busy) {
1036                        err = kcm_write_msgs(kcm);
1037                        if (err < 0) {
1038                                /* We got a hard error in write_msgs but have
1039                                 * already queued this message. Report an error
1040                                 * in the socket, but don't affect return value
1041                                 * from sendmsg
1042                                 */
1043                                pr_warn("KCM: Hard failure on kcm_write_msgs\n");
1044                                report_csk_error(&kcm->sk, -err);
1045                        }
1046                }
1047        } else {
1048                /* Message not complete, save state */
1049partial_message:
1050                if (head) {
1051                        kcm->seq_skb = head;
1052                        kcm_tx_msg(head)->last_skb = skb;
1053                }
1054        }
1055
1056        KCM_STATS_ADD(kcm->stats.tx_bytes, copied);
1057
1058        release_sock(sk);
1059        return copied;
1060
1061out_error:
1062        kcm_push(kcm);
1063
1064        if (copied && sock->type == SOCK_SEQPACKET) {
1065                /* Wrote some bytes before encountering an
1066                 * error, return partial success.
1067                 */
1068                goto partial_message;
1069        }
1070
1071        if (head != kcm->seq_skb)
1072                kfree_skb(head);
1073
1074        err = sk_stream_error(sk, msg->msg_flags, err);
1075
1076        /* make sure we wake any epoll edge trigger waiter */
1077        if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN))
1078                sk->sk_write_space(sk);
1079
1080        release_sock(sk);
1081        return err;
1082}
1083
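     /* Wait for a message to appear on the receive queue, honouring
      * MSG_DONTWAIT, the receive timeout and pending signals. Returns the
      * skb at the head of the queue, or NULL (with *err set if there is an
      * error to report).
      */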
1084static struct sk_buff *kcm_wait_data(struct sock *sk, int flags,
1085                                     long timeo, int *err)
1086{
1087        struct sk_buff *skb;
1088
1089        while (!(skb = skb_peek(&sk->sk_receive_queue))) {
1090                if (sk->sk_err) {
1091                        *err = sock_error(sk);
1092                        return NULL;
1093                }
1094
1095                if (sock_flag(sk, SOCK_DONE))
1096                        return NULL;
1097
1098                if ((flags & MSG_DONTWAIT) || !timeo) {
1099                        *err = -EAGAIN;
1100                        return NULL;
1101                }
1102
1103                sk_wait_data(sk, &timeo, NULL);
1104
1105                /* Handle signals */
1106                if (signal_pending(current)) {
1107                        *err = sock_intr_errno(timeo);
1108                        return NULL;
1109                }
1110        }
1111
1112        return skb;
1113}
1114
1115static int kcm_recvmsg(struct socket *sock, struct msghdr *msg,
1116                       size_t len, int flags)
1117{
1118        struct sock *sk = sock->sk;
1119        struct kcm_sock *kcm = kcm_sk(sk);
1120        int err = 0;
1121        long timeo;
1122        struct strp_msg *stm;
1123        int copied = 0;
1124        struct sk_buff *skb;
1125
1126        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1127
1128        lock_sock(sk);
1129
1130        skb = kcm_wait_data(sk, flags, timeo, &err);
1131        if (!skb)
1132                goto out;
1133
1134        /* Okay, have a message on the receive queue */
1135
1136        stm = strp_msg(skb);
1137
1138        if (len > stm->full_len)
1139                len = stm->full_len;
1140
1141        err = skb_copy_datagram_msg(skb, stm->offset, msg, len);
1142        if (err < 0)
1143                goto out;
1144
1145        copied = len;
1146        if (likely(!(flags & MSG_PEEK))) {
1147                KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
1148                if (copied < stm->full_len) {
1149                        if (sock->type == SOCK_DGRAM) {
1150                                /* Truncated message */
1151                                msg->msg_flags |= MSG_TRUNC;
1152                                goto msg_finished;
1153                        }
1154                        stm->offset += copied;
1155                        stm->full_len -= copied;
1156                } else {
1157msg_finished:
1158                        /* Finished with message */
1159                        msg->msg_flags |= MSG_EOR;
1160                        KCM_STATS_INCR(kcm->stats.rx_msgs);
1161                        skb_unlink(skb, &sk->sk_receive_queue);
1162                        kfree_skb(skb);
1163                }
1164        }
1165
1166out:
1167        release_sock(sk);
1168
1169        return copied ? : err;
1170}
1171
1172static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos,
1173                               struct pipe_inode_info *pipe, size_t len,
1174                               unsigned int flags)
1175{
1176        struct sock *sk = sock->sk;
1177        struct kcm_sock *kcm = kcm_sk(sk);
1178        long timeo;
1179        struct strp_msg *stm;
1180        int err = 0;
1181        ssize_t copied;
1182        struct sk_buff *skb;
1183
 1184        /* Only support splice for SOCK_SEQPACKET */
1185
1186        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1187
1188        lock_sock(sk);
1189
1190        skb = kcm_wait_data(sk, flags, timeo, &err);
1191        if (!skb)
1192                goto err_out;
1193
1194        /* Okay, have a message on the receive queue */
1195
1196        stm = strp_msg(skb);
1197
1198        if (len > stm->full_len)
1199                len = stm->full_len;
1200
1201        copied = skb_splice_bits(skb, sk, stm->offset, pipe, len, flags);
1202        if (copied < 0) {
1203                err = copied;
1204                goto err_out;
1205        }
1206
1207        KCM_STATS_ADD(kcm->stats.rx_bytes, copied);
1208
1209        stm->offset += copied;
1210        stm->full_len -= copied;
1211
1212        /* We have no way to return MSG_EOR. If all the bytes have been
1213         * read we still leave the message in the receive socket buffer.
1214         * A subsequent recvmsg needs to be done to return MSG_EOR and
1215         * finish reading the message.
1216         */
1217
1218        release_sock(sk);
1219
1220        return copied;
1221
1222err_out:
1223        release_sock(sk);
1224
1225        return err;
1226}
1227
1228/* kcm sock lock held */
1229static void kcm_recv_disable(struct kcm_sock *kcm)
1230{
1231        struct kcm_mux *mux = kcm->mux;
1232
1233        if (kcm->rx_disabled)
1234                return;
1235
1236        spin_lock_bh(&mux->rx_lock);
1237
1238        kcm->rx_disabled = 1;
1239
1240        /* If a psock is reserved we'll do cleanup in unreserve */
1241        if (!kcm->rx_psock) {
1242                if (kcm->rx_wait) {
1243                        list_del(&kcm->wait_rx_list);
1244                        kcm->rx_wait = false;
1245                }
1246
1247                requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue);
1248        }
1249
1250        spin_unlock_bh(&mux->rx_lock);
1251}
1252
1253/* kcm sock lock held */
1254static void kcm_recv_enable(struct kcm_sock *kcm)
1255{
1256        struct kcm_mux *mux = kcm->mux;
1257
1258        if (!kcm->rx_disabled)
1259                return;
1260
1261        spin_lock_bh(&mux->rx_lock);
1262
1263        kcm->rx_disabled = 0;
1264        kcm_rcv_ready(kcm);
1265
1266        spin_unlock_bh(&mux->rx_lock);
1267}
1268
1269static int kcm_setsockopt(struct socket *sock, int level, int optname,
1270                          char __user *optval, unsigned int optlen)
1271{
1272        struct kcm_sock *kcm = kcm_sk(sock->sk);
1273        int val, valbool;
1274        int err = 0;
1275
1276        if (level != SOL_KCM)
1277                return -ENOPROTOOPT;
1278
1279        if (optlen < sizeof(int))
1280                return -EINVAL;
1281
1282        if (get_user(val, (int __user *)optval))
1283                return -EINVAL;
1284
1285        valbool = val ? 1 : 0;
1286
1287        switch (optname) {
1288        case KCM_RECV_DISABLE:
1289                lock_sock(&kcm->sk);
1290                if (valbool)
1291                        kcm_recv_disable(kcm);
1292                else
1293                        kcm_recv_enable(kcm);
1294                release_sock(&kcm->sk);
1295                break;
1296        default:
1297                err = -ENOPROTOOPT;
1298        }
1299
1300        return err;
1301}
1302
1303static int kcm_getsockopt(struct socket *sock, int level, int optname,
1304                          char __user *optval, int __user *optlen)
1305{
1306        struct kcm_sock *kcm = kcm_sk(sock->sk);
1307        int val, len;
1308
1309        if (level != SOL_KCM)
1310                return -ENOPROTOOPT;
1311
1312        if (get_user(len, optlen))
1313                return -EFAULT;
1314
1315        len = min_t(unsigned int, len, sizeof(int));
1316        if (len < 0)
1317                return -EINVAL;
1318
1319        switch (optname) {
1320        case KCM_RECV_DISABLE:
1321                val = kcm->rx_disabled;
1322                break;
1323        default:
1324                return -ENOPROTOOPT;
1325        }
1326
1327        if (put_user(len, optlen))
1328                return -EFAULT;
1329        if (copy_to_user(optval, &val, len))
1330                return -EFAULT;
1331        return 0;
1332}
1333
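     /* Attach a newly created kcm socket to its mux: assign the lowest free
      * index, link it into the mux's kcm socket list and make it eligible to
      * receive messages.
      */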
1334static void init_kcm_sock(struct kcm_sock *kcm, struct kcm_mux *mux)
1335{
1336        struct kcm_sock *tkcm;
1337        struct list_head *head;
1338        int index = 0;
1339
1340        /* For SOCK_SEQPACKET sock type, datagram_poll checks the sk_state, so
1341         * we set sk_state, otherwise epoll_wait always returns right away with
1342         * EPOLLHUP
1343         */
1344        kcm->sk.sk_state = TCP_ESTABLISHED;
1345
1346        /* Add to mux's kcm sockets list */
1347        kcm->mux = mux;
1348        spin_lock_bh(&mux->lock);
1349
1350        head = &mux->kcm_socks;
1351        list_for_each_entry(tkcm, &mux->kcm_socks, kcm_sock_list) {
1352                if (tkcm->index != index)
1353                        break;
1354                head = &tkcm->kcm_sock_list;
1355                index++;
1356        }
1357
1358        list_add(&kcm->kcm_sock_list, head);
1359        kcm->index = index;
1360
1361        mux->kcm_socks_cnt++;
1362        spin_unlock_bh(&mux->lock);
1363
1364        INIT_WORK(&kcm->tx_work, kcm_tx_work);
1365
1366        spin_lock_bh(&mux->rx_lock);
1367        kcm_rcv_ready(kcm);
1368        spin_unlock_bh(&mux->rx_lock);
1369}
1370
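     /* Attach a lower TCP socket to the mux as a psock: allocate the psock,
      * start a strparser on the socket, take over its data_ready/write_space/
      * state_change callbacks and add it to the mux so it can carry KCM
      * messages.
      */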
1371static int kcm_attach(struct socket *sock, struct socket *csock,
1372                      struct bpf_prog *prog)
1373{
1374        struct kcm_sock *kcm = kcm_sk(sock->sk);
1375        struct kcm_mux *mux = kcm->mux;
1376        struct sock *csk;
1377        struct kcm_psock *psock = NULL, *tpsock;
1378        struct list_head *head;
1379        int index = 0;
1380        static const struct strp_callbacks cb = {
1381                .rcv_msg = kcm_rcv_strparser,
1382                .parse_msg = kcm_parse_func_strparser,
1383                .read_sock_done = kcm_read_sock_done,
1384        };
1385        int err = 0;
1386
1387        csk = csock->sk;
1388        if (!csk)
1389                return -EINVAL;
1390
1391        lock_sock(csk);
1392
1393        /* Only allow TCP sockets to be attached for now */
1394        if ((csk->sk_family != AF_INET && csk->sk_family != AF_INET6) ||
1395            csk->sk_protocol != IPPROTO_TCP) {
1396                err = -EOPNOTSUPP;
1397                goto out;
1398        }
1399
1400        /* Don't allow listeners or closed sockets */
1401        if (csk->sk_state == TCP_LISTEN || csk->sk_state == TCP_CLOSE) {
1402                err = -EOPNOTSUPP;
1403                goto out;
1404        }
1405
1406        psock = kmem_cache_zalloc(kcm_psockp, GFP_KERNEL);
1407        if (!psock) {
1408                err = -ENOMEM;
1409                goto out;
1410        }
1411
1412        psock->mux = mux;
1413        psock->sk = csk;
1414        psock->bpf_prog = prog;
1415
1416        err = strp_init(&psock->strp, csk, &cb);
1417        if (err) {
1418                kmem_cache_free(kcm_psockp, psock);
1419                goto out;
1420        }
1421
1422        write_lock_bh(&csk->sk_callback_lock);
1423
 1424        /* Check if sk_user_data is already in use by KCM or someone else.
1425         * Must be done under lock to prevent race conditions.
1426         */
1427        if (csk->sk_user_data) {
1428                write_unlock_bh(&csk->sk_callback_lock);
1429                strp_stop(&psock->strp);
1430                strp_done(&psock->strp);
1431                kmem_cache_free(kcm_psockp, psock);
1432                err = -EALREADY;
1433                goto out;
1434        }
1435
1436        psock->save_data_ready = csk->sk_data_ready;
1437        psock->save_write_space = csk->sk_write_space;
1438        psock->save_state_change = csk->sk_state_change;
1439        csk->sk_user_data = psock;
1440        csk->sk_data_ready = psock_data_ready;
1441        csk->sk_write_space = psock_write_space;
1442        csk->sk_state_change = psock_state_change;
1443
1444        write_unlock_bh(&csk->sk_callback_lock);
1445
1446        sock_hold(csk);
1447
1448        /* Finished initialization, now add the psock to the MUX. */
1449        spin_lock_bh(&mux->lock);
1450        head = &mux->psocks;
1451        list_for_each_entry(tpsock, &mux->psocks, psock_list) {
1452                if (tpsock->index != index)
1453                        break;
1454                head = &tpsock->psock_list;
1455                index++;
1456        }
1457
1458        list_add(&psock->psock_list, head);
1459        psock->index = index;
1460
1461        KCM_STATS_INCR(mux->stats.psock_attach);
1462        mux->psocks_cnt++;
1463        psock_now_avail(psock);
1464        spin_unlock_bh(&mux->lock);
1465
1466        /* Schedule RX work in case there are already bytes queued */
1467        strp_check_rcv(&psock->strp);
1468
1469out:
1470        release_sock(csk);
1471
1472        return err;
1473}
1474
1475static int kcm_attach_ioctl(struct socket *sock, struct kcm_attach *info)
1476{
1477        struct socket *csock;
1478        struct bpf_prog *prog;
1479        int err;
1480
1481        csock = sockfd_lookup(info->fd, &err);
1482        if (!csock)
1483                return -ENOENT;
1484
1485        prog = bpf_prog_get_type(info->bpf_fd, BPF_PROG_TYPE_SOCKET_FILTER);
1486        if (IS_ERR(prog)) {
1487                err = PTR_ERR(prog);
1488                goto out;
1489        }
1490
1491        err = kcm_attach(sock, csock, prog);
1492        if (err) {
1493                bpf_prog_put(prog);
1494                goto out;
1495        }
1496
1497        /* Keep reference on file also */
1498
1499        return 0;
1500out:
1501        fput(csock->file);
1502        return err;
1503}
1504
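     /* Detach a psock from its mux: restore the lower socket's callbacks,
      * stop the strparser, drop any message held ready on the psock and
      * either free the psock or, if it is still reserved for transmit, defer
      * cleanup to the owning KCM's tx_work.
      */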
1505static void kcm_unattach(struct kcm_psock *psock)
1506{
1507        struct sock *csk = psock->sk;
1508        struct kcm_mux *mux = psock->mux;
1509
1510        lock_sock(csk);
1511
1512        /* Stop getting callbacks from TCP socket. After this there should
1513         * be no way to reserve a kcm for this psock.
1514         */
1515        write_lock_bh(&csk->sk_callback_lock);
1516        csk->sk_user_data = NULL;
1517        csk->sk_data_ready = psock->save_data_ready;
1518        csk->sk_write_space = psock->save_write_space;
1519        csk->sk_state_change = psock->save_state_change;
1520        strp_stop(&psock->strp);
1521
1522        if (WARN_ON(psock->rx_kcm)) {
1523                write_unlock_bh(&csk->sk_callback_lock);
1524                release_sock(csk);
1525                return;
1526        }
1527
1528        spin_lock_bh(&mux->rx_lock);
1529
1530        /* Stop receiver activities. After this point psock should not be
 1531         * able to get onto the ready list either through callbacks or work.
1532         */
1533        if (psock->ready_rx_msg) {
1534                list_del(&psock->psock_ready_list);
1535                kfree_skb(psock->ready_rx_msg);
1536                psock->ready_rx_msg = NULL;
1537                KCM_STATS_INCR(mux->stats.rx_ready_drops);
1538        }
1539
1540        spin_unlock_bh(&mux->rx_lock);
1541
1542        write_unlock_bh(&csk->sk_callback_lock);
1543
1544        /* Call strp_done without sock lock */
1545        release_sock(csk);
1546        strp_done(&psock->strp);
1547        lock_sock(csk);
1548
1549        bpf_prog_put(psock->bpf_prog);
1550
1551        spin_lock_bh(&mux->lock);
1552
1553        aggregate_psock_stats(&psock->stats, &mux->aggregate_psock_stats);
1554        save_strp_stats(&psock->strp, &mux->aggregate_strp_stats);
1555
1556        KCM_STATS_INCR(mux->stats.psock_unattach);
1557
1558        if (psock->tx_kcm) {
 1559                /* psock was reserved. Just mark it finished and we will clean
 1560                 * up in the kcm paths; we need the kcm lock, which cannot be
1561                 * acquired here.
1562                 */
1563                KCM_STATS_INCR(mux->stats.psock_unattach_rsvd);
1564                spin_unlock_bh(&mux->lock);
1565
1566                /* We are unattaching a socket that is reserved. Abort the
1567                 * socket since we may be out of sync in sending on it. We need
1568                 * to do this without the mux lock.
1569                 */
1570                kcm_abort_tx_psock(psock, EPIPE, false);
1571
1572                spin_lock_bh(&mux->lock);
1573                if (!psock->tx_kcm) {
 1574                        /* psock was unreserved in the window where the mux was unlocked */
1575                        goto no_reserved;
1576                }
1577                psock->done = 1;
1578
1579                /* Commit done before queuing work to process it */
1580                smp_mb();
1581
1582                /* Queue tx work to make sure psock->done is handled */
1583                queue_work(kcm_wq, &psock->tx_kcm->tx_work);
1584                spin_unlock_bh(&mux->lock);
1585        } else {
1586no_reserved:
1587                if (!psock->tx_stopped)
1588                        list_del(&psock->psock_avail_list);
1589                list_del(&psock->psock_list);
1590                mux->psocks_cnt--;
1591                spin_unlock_bh(&mux->lock);
1592
1593                sock_put(csk);
1594                fput(csk->sk_socket->file);
1595                kmem_cache_free(kcm_psockp, psock);
1596        }
1597
1598        release_sock(csk);
1599}
1600
1601static int kcm_unattach_ioctl(struct socket *sock, struct kcm_unattach *info)
1602{
1603        struct kcm_sock *kcm = kcm_sk(sock->sk);
1604        struct kcm_mux *mux = kcm->mux;
1605        struct kcm_psock *psock;
1606        struct socket *csock;
1607        struct sock *csk;
1608        int err;
1609
1610        csock = sockfd_lookup(info->fd, &err);
1611        if (!csock)
1612                return -ENOENT;
1613
1614        csk = csock->sk;
1615        if (!csk) {
1616                err = -EINVAL;
1617                goto out;
1618        }
1619
1620        err = -ENOENT;
1621
1622        spin_lock_bh(&mux->lock);
1623
1624        list_for_each_entry(psock, &mux->psocks, psock_list) {
1625                if (psock->sk != csk)
1626                        continue;
1627
1628                /* Found the matching psock */
1629
1630                if (psock->unattaching || WARN_ON(psock->done)) {
1631                        err = -EALREADY;
1632                        break;
1633                }
1634
1635                psock->unattaching = 1;
1636
1637                spin_unlock_bh(&mux->lock);
1638
1639                /* Lower socket lock should already be held */
1640                kcm_unattach(psock);
1641
1642                err = 0;
1643                goto out;
1644        }
1645
1646        spin_unlock_bh(&mux->lock);
1647
1648out:
1649        fput(csock->file);
1650        return err;
1651}
1652
1653static struct proto kcm_proto = {
1654        .name   = "KCM",
1655        .owner  = THIS_MODULE,
1656        .obj_size = sizeof(struct kcm_sock),
1657};
1658
1659/* Clone a kcm socket: allocate a new socket and file on the same mux. */
1660static struct file *kcm_clone(struct socket *osock)
1661{
1662        struct socket *newsock;
1663        struct sock *newsk;
1664
1665        newsock = sock_alloc();
1666        if (!newsock)
1667                return ERR_PTR(-ENFILE);
1668
1669        newsock->type = osock->type;
1670        newsock->ops = osock->ops;
1671
1672        __module_get(newsock->ops->owner);
1673
1674        newsk = sk_alloc(sock_net(osock->sk), PF_KCM, GFP_KERNEL,
1675                         &kcm_proto, false);
1676        if (!newsk) {
1677                sock_release(newsock);
1678                return ERR_PTR(-ENOMEM);
1679        }
1680        sock_init_data(newsock, newsk);
1681        init_kcm_sock(kcm_sk(newsk), kcm_sk(osock->sk)->mux);
1682
1683        return sock_alloc_file(newsock, 0, osock->sk->sk_prot_creator->name);
1684}
1685
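    /* KCM-specific ioctls: SIOCKCMATTACH and SIOCKCMUNATTACH attach/detach a
     * connected lower socket (together with a message-parsing BPF program)
     * to/from the mux, and SIOCKCMCLONE creates another KCM socket on the
     * same mux, returning its file descriptor in the caller's struct
     * kcm_clone.
     *
     * Illustrative userspace sketch (assumes csock_fd is a connected TCP
     * socket and bpf_prog_fd a loaded socket-filter BPF program fd; error
     * handling omitted, not an excerpt from any particular application):
     *
     *        int kcm_fd = socket(AF_KCM, SOCK_DGRAM, KCMPROTO_CONNECTED);
     *
     *        struct kcm_attach attach = {
     *                .fd = csock_fd,
     *                .bpf_fd = bpf_prog_fd,
     *        };
     *        ioctl(kcm_fd, SIOCKCMATTACH, &attach);
     *
     *        struct kcm_clone clone_info;
     *        ioctl(kcm_fd, SIOCKCMCLONE, &clone_info);
     *
     * After the clone ioctl, clone_info.fd refers to a second KCM socket on
     * the same mux.
     */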
1686static int kcm_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
1687{
1688        int err;
1689
1690        switch (cmd) {
1691        case SIOCKCMATTACH: {
1692                struct kcm_attach info;
1693
1694                if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
1695                        return -EFAULT;
1696
1697                err = kcm_attach_ioctl(sock, &info);
1698
1699                break;
1700        }
1701        case SIOCKCMUNATTACH: {
1702                struct kcm_unattach info;
1703
1704                if (copy_from_user(&info, (void __user *)arg, sizeof(info)))
1705                        return -EFAULT;
1706
1707                err = kcm_unattach_ioctl(sock, &info);
1708
1709                break;
1710        }
1711        case SIOCKCMCLONE: {
1712                struct kcm_clone info;
1713                struct file *file;
1714
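                    /* Reserve the new fd and build the cloned socket before
                     * copying info back to userspace; fd_install() only runs
                     * after the copy succeeds, so a failed copy cannot leave
                     * a stray descriptor visible to the process.
                     */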
1715                info.fd = get_unused_fd_flags(0);
1716                if (unlikely(info.fd < 0))
1717                        return info.fd;
1718
1719                file = kcm_clone(sock);
1720                if (IS_ERR(file)) {
1721                        put_unused_fd(info.fd);
1722                        return PTR_ERR(file);
1723                }
1724                if (copy_to_user((void __user *)arg, &info,
1725                                 sizeof(info))) {
1726                        put_unused_fd(info.fd);
1727                        fput(file);
1728                        return -EFAULT;
1729                }
1730                fd_install(info.fd, file);
1731                err = 0;
1732                break;
1733        }
1734        default:
1735                err = -ENOIOCTLCMD;
1736                break;
1737        }
1738
1739        return err;
1740}
1741
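    /* RCU callback used by release_mux() to free the mux once readers of the
     * per-netns mux list can no longer see it.
     */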
1742static void free_mux(struct rcu_head *rcu)
1743{
1744        struct kcm_mux *mux = container_of(rcu,
1745            struct kcm_mux, rcu);
1746
1747        kmem_cache_free(kcm_muxp, mux);
1748}
1749
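    /* Final teardown of a mux: unattach any remaining psocks, purge the rx
     * hold queue, fold the mux's stats into the per-netns aggregates, unlink
     * it from the netns mux list, and free it after an RCU grace period.
     */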
1750static void release_mux(struct kcm_mux *mux)
1751{
1752        struct kcm_net *knet = mux->knet;
1753        struct kcm_psock *psock, *tmp_psock;
1754
1755        /* Release psocks */
1756        list_for_each_entry_safe(psock, tmp_psock,
1757                                 &mux->psocks, psock_list) {
1758                if (!WARN_ON(psock->unattaching))
1759                        kcm_unattach(psock);
1760        }
1761
1762        if (WARN_ON(mux->psocks_cnt))
1763                return;
1764
1765        __skb_queue_purge(&mux->rx_hold_queue);
1766
1767        mutex_lock(&knet->mutex);
1768        aggregate_mux_stats(&mux->stats, &knet->aggregate_mux_stats);
1769        aggregate_psock_stats(&mux->aggregate_psock_stats,
1770                              &knet->aggregate_psock_stats);
1771        aggregate_strp_stats(&mux->aggregate_strp_stats,
1772                             &knet->aggregate_strp_stats);
1773        list_del_rcu(&mux->kcm_mux_list);
1774        knet->count--;
1775        mutex_unlock(&knet->mutex);
1776
1777        call_rcu(&mux->rcu, free_mux);
1778}
1779
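    /* Final cleanup of a KCM socket: detach it from the mux, requeue any
     * received messages still on its queue to other KCM sockets, and release
     * the mux itself when this was the last KCM socket using it. If a psock
     * is currently reserved for rx, teardown is deferred to
     * unreserve_rx_kcm().
     */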
1780static void kcm_done(struct kcm_sock *kcm)
1781{
1782        struct kcm_mux *mux = kcm->mux;
1783        struct sock *sk = &kcm->sk;
1784        int socks_cnt;
1785
1786        spin_lock_bh(&mux->rx_lock);
1787        if (kcm->rx_psock) {
1788                /* Cleanup in unreserve_rx_kcm */
1789                WARN_ON(kcm->done);
1790                kcm->rx_disabled = 1;
1791                kcm->done = 1;
1792                spin_unlock_bh(&mux->rx_lock);
1793                return;
1794        }
1795
1796        if (kcm->rx_wait) {
1797                list_del(&kcm->wait_rx_list);
1798                kcm->rx_wait = false;
1799        }
1800        /* Move any pending receive messages to other kcm sockets */
1801        requeue_rx_msgs(mux, &sk->sk_receive_queue);
1802
1803        spin_unlock_bh(&mux->rx_lock);
1804
1805        if (WARN_ON(sk_rmem_alloc_get(sk)))
1806                return;
1807
1808        /* Detach from MUX */
1809        spin_lock_bh(&mux->lock);
1810
1811        list_del(&kcm->kcm_sock_list);
1812        mux->kcm_socks_cnt--;
1813        socks_cnt = mux->kcm_socks_cnt;
1814
1815        spin_unlock_bh(&mux->lock);
1816
1817        if (!socks_cnt) {
1818                /* We are done with the mux now. */
1819                release_mux(mux);
1820        }
1821
1822        WARN_ON(kcm->rx_wait);
1823
1824        sock_put(&kcm->sk);
1825}
1826
1827/* Close a KCM socket: stop transmit activity, drop any reserved psock, and
1828 * finish via kcm_done(); the MUX is destroyed when its last socket closes.
1829 */
1830static int kcm_release(struct socket *sock)
1831{
1832        struct sock *sk = sock->sk;
1833        struct kcm_sock *kcm;
1834        struct kcm_mux *mux;
1835        struct kcm_psock *psock;
1836
1837        if (!sk)
1838                return 0;
1839
1840        kcm = kcm_sk(sk);
1841        mux = kcm->mux;
1842
1843        sock_orphan(sk);
1844        kfree_skb(kcm->seq_skb);
1845
1846        lock_sock(sk);
1847        /* Purge queue under lock to avoid race condition with tx_work trying
1848         * to act when queue is nonempty. If tx_work runs after this point
1849         * it will just return.
1850         */
1851        __skb_queue_purge(&sk->sk_write_queue);
1852
1853        /* Set tx_stopped. This is checked when psock is bound to a kcm and we
1854         * get a writespace callback. This prevents further work being queued
1855         * from the callback (unbinding the psock occurs after canceling work).
1856         */
1857        kcm->tx_stopped = 1;
1858
1859        release_sock(sk);
1860
1861        spin_lock_bh(&mux->lock);
1862        if (kcm->tx_wait) {
1863                /* Take off the tx_wait list; after this point there should be no way
1864                 * that a psock will be assigned to this kcm.
1865                 */
1866                list_del(&kcm->wait_psock_list);
1867                kcm->tx_wait = false;
1868        }
1869        spin_unlock_bh(&mux->lock);
1870
1871        /* Cancel work. After this point there should be no outside references
1872         * to the kcm socket.
1873         */
1874        cancel_work_sync(&kcm->tx_work);
1875
1876        lock_sock(sk);
1877        psock = kcm->tx_psock;
1878        if (psock) {
1879                /* A psock was reserved, so we need to kill it since it
1880                 * may already have some bytes queued from a message. We
1881                 * need to do this after removing kcm from tx_wait list.
1882                 */
1883                kcm_abort_tx_psock(psock, EPIPE, false);
1884                unreserve_psock(kcm);
1885        }
1886        release_sock(sk);
1887
1888        WARN_ON(kcm->tx_wait);
1889        WARN_ON(kcm->tx_psock);
1890
1891        sock->sk = NULL;
1892
1893        kcm_done(kcm);
1894
1895        return 0;
1896}
1897
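    /* SOCK_DGRAM and SOCK_SEQPACKET KCM sockets share the same proto_ops
     * except that the seqpacket variant also supports splice_read. KCM
     * sockets have no addressing of their own, so bind/connect/accept and
     * friends are stubbed out with the sock_no_* helpers.
     */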
1898static const struct proto_ops kcm_dgram_ops = {
1899        .family =       PF_KCM,
1900        .owner =        THIS_MODULE,
1901        .release =      kcm_release,
1902        .bind =         sock_no_bind,
1903        .connect =      sock_no_connect,
1904        .socketpair =   sock_no_socketpair,
1905        .accept =       sock_no_accept,
1906        .getname =      sock_no_getname,
1907        .poll =         datagram_poll,
1908        .ioctl =        kcm_ioctl,
1909        .listen =       sock_no_listen,
1910        .shutdown =     sock_no_shutdown,
1911        .setsockopt =   kcm_setsockopt,
1912        .getsockopt =   kcm_getsockopt,
1913        .sendmsg =      kcm_sendmsg,
1914        .recvmsg =      kcm_recvmsg,
1915        .mmap =         sock_no_mmap,
1916        .sendpage =     kcm_sendpage,
1917};
1918
1919static const struct proto_ops kcm_seqpacket_ops = {
1920        .family =       PF_KCM,
1921        .owner =        THIS_MODULE,
1922        .release =      kcm_release,
1923        .bind =         sock_no_bind,
1924        .connect =      sock_no_connect,
1925        .socketpair =   sock_no_socketpair,
1926        .accept =       sock_no_accept,
1927        .getname =      sock_no_getname,
1928        .poll =         datagram_poll,
1929        .ioctl =        kcm_ioctl,
1930        .listen =       sock_no_listen,
1931        .shutdown =     sock_no_shutdown,
1932        .setsockopt =   kcm_setsockopt,
1933        .getsockopt =   kcm_getsockopt,
1934        .sendmsg =      kcm_sendmsg,
1935        .recvmsg =      kcm_recvmsg,
1936        .mmap =         sock_no_mmap,
1937        .sendpage =     kcm_sendpage,
1938        .splice_read =  kcm_splice_read,
1939};
1940
1941/* Create a new KCM socket along with a new mux (clones share the mux) */
1942static int kcm_create(struct net *net, struct socket *sock,
1943                      int protocol, int kern)
1944{
1945        struct kcm_net *knet = net_generic(net, kcm_net_id);
1946        struct sock *sk;
1947        struct kcm_mux *mux;
1948
1949        switch (sock->type) {
1950        case SOCK_DGRAM:
1951                sock->ops = &kcm_dgram_ops;
1952                break;
1953        case SOCK_SEQPACKET:
1954                sock->ops = &kcm_seqpacket_ops;
1955                break;
1956        default:
1957                return -ESOCKTNOSUPPORT;
1958        }
1959
1960        if (protocol != KCMPROTO_CONNECTED)
1961                return -EPROTONOSUPPORT;
1962
1963        sk = sk_alloc(net, PF_KCM, GFP_KERNEL, &kcm_proto, kern);
1964        if (!sk)
1965                return -ENOMEM;
1966
1967        /* Allocate a kcm mux, shared between KCM sockets */
1968        mux = kmem_cache_zalloc(kcm_muxp, GFP_KERNEL);
1969        if (!mux) {
1970                sk_free(sk);
1971                return -ENOMEM;
1972        }
1973
1974        spin_lock_init(&mux->lock);
1975        spin_lock_init(&mux->rx_lock);
1976        INIT_LIST_HEAD(&mux->kcm_socks);
1977        INIT_LIST_HEAD(&mux->kcm_rx_waiters);
1978        INIT_LIST_HEAD(&mux->kcm_tx_waiters);
1979
1980        INIT_LIST_HEAD(&mux->psocks);
1981        INIT_LIST_HEAD(&mux->psocks_ready);
1982        INIT_LIST_HEAD(&mux->psocks_avail);
1983
1984        mux->knet = knet;
1985
1986        /* Add new MUX to list */
1987        mutex_lock(&knet->mutex);
1988        list_add_rcu(&mux->kcm_mux_list, &knet->mux_list);
1989        knet->count++;
1990        mutex_unlock(&knet->mutex);
1991
1992        skb_queue_head_init(&mux->rx_hold_queue);
1993
1994        /* Init KCM socket */
1995        sock_init_data(sock, sk);
1996        init_kcm_sock(kcm_sk(sk), mux);
1997
1998        return 0;
1999}
2000
2001static const struct net_proto_family kcm_family_ops = {
2002        .family = PF_KCM,
2003        .create = kcm_create,
2004        .owner  = THIS_MODULE,
2005};
2006
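    /* Per-network-namespace state: each netns keeps its own list of muxes
     * and aggregate statistics. kcm_exit_net() expects every mux to have
     * been released by the time the namespace goes away.
     */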
2007static __net_init int kcm_init_net(struct net *net)
2008{
2009        struct kcm_net *knet = net_generic(net, kcm_net_id);
2010
2011        INIT_LIST_HEAD_RCU(&knet->mux_list);
2012        mutex_init(&knet->mutex);
2013
2014        return 0;
2015}
2016
2017static __net_exit void kcm_exit_net(struct net *net)
2018{
2019        struct kcm_net *knet = net_generic(net, kcm_net_id);
2020
2021        /* All KCM sockets should be closed at this point, which should mean
2022         * that all multiplexors and psocks have been destroyed.
2023         */
2024        WARN_ON(!list_empty(&knet->mux_list));
2025}
2026
2027static struct pernet_operations kcm_net_ops = {
2028        .init = kcm_init_net,
2029        .exit = kcm_exit_net,
2030        .id   = &kcm_net_id,
2031        .size = sizeof(struct kcm_net),
2032};
2033
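    /* Module init: create the slab caches and single-threaded workqueue,
     * then register the protocol, pernet ops, PF_KCM socket family and proc
     * entries. On failure, everything registered so far is unwound in
     * reverse order.
     */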
2034static int __init kcm_init(void)
2035{
2036        int err = -ENOMEM;
2037
2038        kcm_muxp = kmem_cache_create("kcm_mux_cache",
2039                                     sizeof(struct kcm_mux), 0,
2040                                     SLAB_HWCACHE_ALIGN, NULL);
2041        if (!kcm_muxp)
2042                goto fail;
2043
2044        kcm_psockp = kmem_cache_create("kcm_psock_cache",
2045                                       sizeof(struct kcm_psock), 0,
2046                                       SLAB_HWCACHE_ALIGN, NULL);
2047        if (!kcm_psockp)
2048                goto fail;
2049
2050        kcm_wq = create_singlethread_workqueue("kkcmd");
2051        if (!kcm_wq)
2052                goto fail;
2053
2054        err = proto_register(&kcm_proto, 1);
2055        if (err)
2056                goto fail;
2057
2058        err = register_pernet_device(&kcm_net_ops);
2059        if (err)
2060                goto net_ops_fail;
2061
2062        err = sock_register(&kcm_family_ops);
2063        if (err)
2064                goto sock_register_fail;
2065
2066        err = kcm_proc_init();
2067        if (err)
2068                goto proc_init_fail;
2069
2070        return 0;
2071
2072proc_init_fail:
2073        sock_unregister(PF_KCM);
2074
2075sock_register_fail:
2076        unregister_pernet_device(&kcm_net_ops);
2077
2078net_ops_fail:
2079        proto_unregister(&kcm_proto);
2080
2081fail:
2082        kmem_cache_destroy(kcm_muxp);
2083        kmem_cache_destroy(kcm_psockp);
2084
2085        if (kcm_wq)
2086                destroy_workqueue(kcm_wq);
2087
2088        return err;
2089}
2090
2091static void __exit kcm_exit(void)
2092{
2093        kcm_proc_exit();
2094        sock_unregister(PF_KCM);
2095        unregister_pernet_device(&kcm_net_ops);
2096        proto_unregister(&kcm_proto);
2097        destroy_workqueue(kcm_wq);
2098
2099        kmem_cache_destroy(kcm_muxp);
2100        kmem_cache_destroy(kcm_psockp);
2101}
2102
2103module_init(kcm_init);
2104module_exit(kcm_exit);
2105
2106MODULE_LICENSE("GPL");
2107MODULE_ALIAS_NETPROTO(PF_KCM);
2108