linux/net/vmw_vsock/virtio_transport_common.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * common code for virtio vsock
   4 *
   5 * Copyright (C) 2013-2015 Red Hat, Inc.
   6 * Author: Asias He <asias@redhat.com>
   7 *         Stefan Hajnoczi <stefanha@redhat.com>
   8 */
   9#include <linux/spinlock.h>
  10#include <linux/module.h>
  11#include <linux/sched/signal.h>
  12#include <linux/ctype.h>
  13#include <linux/list.h>
  14#include <linux/virtio_vsock.h>
  15#include <uapi/linux/vsockmon.h>
  16
  17#include <net/sock.h>
  18#include <net/af_vsock.h>
  19
  20#define CREATE_TRACE_POINTS
  21#include <trace/events/vsock_virtio_transport_common.h>
  22
  23/* How long to wait for graceful shutdown of a connection */
  24#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
  25
  26/* Threshold for detecting small packets to copy */
  27#define GOOD_COPY_LEN  128
  28
  29static const struct virtio_transport *
  30virtio_transport_get_ops(struct vsock_sock *vsk)
  31{
  32        const struct vsock_transport *t = vsock_core_get_transport(vsk);
  33
  34        if (WARN_ON(!t))
  35                return NULL;
  36
  37        return container_of(t, struct virtio_transport, transport);
  38}
  39
  40static struct virtio_vsock_pkt *
  41virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
  42                           size_t len,
  43                           u32 src_cid,
  44                           u32 src_port,
  45                           u32 dst_cid,
  46                           u32 dst_port)
  47{
  48        struct virtio_vsock_pkt *pkt;
  49        int err;
  50
  51        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
  52        if (!pkt)
  53                return NULL;
  54
  55        pkt->hdr.type           = cpu_to_le16(info->type);
  56        pkt->hdr.op             = cpu_to_le16(info->op);
  57        pkt->hdr.src_cid        = cpu_to_le64(src_cid);
  58        pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
  59        pkt->hdr.src_port       = cpu_to_le32(src_port);
  60        pkt->hdr.dst_port       = cpu_to_le32(dst_port);
  61        pkt->hdr.flags          = cpu_to_le32(info->flags);
  62        pkt->len                = len;
  63        pkt->hdr.len            = cpu_to_le32(len);
  64        pkt->reply              = info->reply;
  65        pkt->vsk                = info->vsk;
  66
  67        if (info->msg && len > 0) {
  68                pkt->buf = kmalloc(len, GFP_KERNEL);
  69                if (!pkt->buf)
  70                        goto out_pkt;
  71
  72                pkt->buf_len = len;
  73
  74                err = memcpy_from_msg(pkt->buf, info->msg, len);
  75                if (err)
  76                        goto out;
  77
  78                if (msg_data_left(info->msg) == 0 &&
  79                    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
  80                        pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);
  81
  82                        if (info->msg->msg_flags & MSG_EOR)
  83                                pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
  84                }
  85        }
  86
  87        trace_virtio_transport_alloc_pkt(src_cid, src_port,
  88                                         dst_cid, dst_port,
  89                                         len,
  90                                         info->type,
  91                                         info->op,
  92                                         info->flags);
  93
  94        return pkt;
  95
  96out:
  97        kfree(pkt->buf);
  98out_pkt:
  99        kfree(pkt);
 100        return NULL;
 101}
 102
 103/* Packet capture */
 104static struct sk_buff *virtio_transport_build_skb(void *opaque)
 105{
 106        struct virtio_vsock_pkt *pkt = opaque;
 107        struct af_vsockmon_hdr *hdr;
 108        struct sk_buff *skb;
 109        size_t payload_len;
 110        void *payload_buf;
 111
 112        /* A packet could be split to fit the RX buffer, so we can retrieve
 113         * the payload length from the header and the buffer pointer taking
 114         * care of the offset in the original packet.
 115         */
 116        payload_len = le32_to_cpu(pkt->hdr.len);
 117        payload_buf = pkt->buf + pkt->off;
 118
 119        skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
 120                        GFP_ATOMIC);
 121        if (!skb)
 122                return NULL;
 123
 124        hdr = skb_put(skb, sizeof(*hdr));
 125
 126        /* pkt->hdr is little-endian so no need to byteswap here */
 127        hdr->src_cid = pkt->hdr.src_cid;
 128        hdr->src_port = pkt->hdr.src_port;
 129        hdr->dst_cid = pkt->hdr.dst_cid;
 130        hdr->dst_port = pkt->hdr.dst_port;
 131
 132        hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
 133        hdr->len = cpu_to_le16(sizeof(pkt->hdr));
 134        memset(hdr->reserved, 0, sizeof(hdr->reserved));
 135
 136        switch (le16_to_cpu(pkt->hdr.op)) {
 137        case VIRTIO_VSOCK_OP_REQUEST:
 138        case VIRTIO_VSOCK_OP_RESPONSE:
 139                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
 140                break;
 141        case VIRTIO_VSOCK_OP_RST:
 142        case VIRTIO_VSOCK_OP_SHUTDOWN:
 143                hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
 144                break;
 145        case VIRTIO_VSOCK_OP_RW:
 146                hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
 147                break;
 148        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
 149        case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
 150                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
 151                break;
 152        default:
 153                hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
 154                break;
 155        }
 156
 157        skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
 158
 159        if (payload_len) {
 160                skb_put_data(skb, payload_buf, payload_len);
 161        }
 162
 163        return skb;
 164}
 165
 166void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
 167{
 168        if (pkt->tap_delivered)
 169                return;
 170
 171        vsock_deliver_tap(virtio_transport_build_skb, pkt);
 172        pkt->tap_delivered = true;
 173}
 174EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
 175
 176static u16 virtio_transport_get_type(struct sock *sk)
 177{
 178        if (sk->sk_type == SOCK_STREAM)
 179                return VIRTIO_VSOCK_TYPE_STREAM;
 180        else
 181                return VIRTIO_VSOCK_TYPE_SEQPACKET;
 182}
 183
 184/* This function can only be used on connecting/connected sockets,
 185 * since a socket assigned to a transport is required.
 186 *
 187 * Do not use on listener sockets!
 188 */
 189static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
 190                                          struct virtio_vsock_pkt_info *info)
 191{
 192        u32 src_cid, src_port, dst_cid, dst_port;
 193        const struct virtio_transport *t_ops;
 194        struct virtio_vsock_sock *vvs;
 195        struct virtio_vsock_pkt *pkt;
 196        u32 pkt_len = info->pkt_len;
 197
 198        info->type = virtio_transport_get_type(sk_vsock(vsk));
 199
 200        t_ops = virtio_transport_get_ops(vsk);
 201        if (unlikely(!t_ops))
 202                return -EFAULT;
 203
 204        src_cid = t_ops->transport.get_local_cid();
 205        src_port = vsk->local_addr.svm_port;
 206        if (!info->remote_cid) {
 207                dst_cid = vsk->remote_addr.svm_cid;
 208                dst_port = vsk->remote_addr.svm_port;
 209        } else {
 210                dst_cid = info->remote_cid;
 211                dst_port = info->remote_port;
 212        }
 213
 214        vvs = vsk->trans;
 215
 216        /* we can send less than pkt_len bytes */
 217        if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
 218                pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
 219
 220        /* virtio_transport_get_credit might return less than pkt_len credit */
 221        pkt_len = virtio_transport_get_credit(vvs, pkt_len);
 222
 223        /* Do not send zero length OP_RW pkt */
 224        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
 225                return pkt_len;
 226
 227        pkt = virtio_transport_alloc_pkt(info, pkt_len,
 228                                         src_cid, src_port,
 229                                         dst_cid, dst_port);
 230        if (!pkt) {
 231                virtio_transport_put_credit(vvs, pkt_len);
 232                return -ENOMEM;
 233        }
 234
 235        virtio_transport_inc_tx_pkt(vvs, pkt);
 236
 237        return t_ops->send_pkt(pkt);
 238}
 239
 240static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
 241                                        struct virtio_vsock_pkt *pkt)
 242{
 243        if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
 244                return false;
 245
 246        vvs->rx_bytes += pkt->len;
 247        return true;
 248}
 249
 250static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
 251                                        struct virtio_vsock_pkt *pkt)
 252{
 253        vvs->rx_bytes -= pkt->len;
 254        vvs->fwd_cnt += pkt->len;
 255}
 256
 257void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
 258{
 259        spin_lock_bh(&vvs->rx_lock);
 260        vvs->last_fwd_cnt = vvs->fwd_cnt;
 261        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
 262        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
 263        spin_unlock_bh(&vvs->rx_lock);
 264}
 265EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
 266
 267u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
 268{
 269        u32 ret;
 270
 271        spin_lock_bh(&vvs->tx_lock);
 272        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
 273        if (ret > credit)
 274                ret = credit;
 275        vvs->tx_cnt += ret;
 276        spin_unlock_bh(&vvs->tx_lock);
 277
 278        return ret;
 279}
 280EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
 281
 282void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
 283{
 284        spin_lock_bh(&vvs->tx_lock);
 285        vvs->tx_cnt -= credit;
 286        spin_unlock_bh(&vvs->tx_lock);
 287}
 288EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
 289
 290static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
 291{
 292        struct virtio_vsock_pkt_info info = {
 293                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
 294                .vsk = vsk,
 295        };
 296
 297        return virtio_transport_send_pkt_info(vsk, &info);
 298}
 299
 300static ssize_t
 301virtio_transport_stream_do_peek(struct vsock_sock *vsk,
 302                                struct msghdr *msg,
 303                                size_t len)
 304{
 305        struct virtio_vsock_sock *vvs = vsk->trans;
 306        struct virtio_vsock_pkt *pkt;
 307        size_t bytes, total = 0, off;
 308        int err = -EFAULT;
 309
 310        spin_lock_bh(&vvs->rx_lock);
 311
 312        list_for_each_entry(pkt, &vvs->rx_queue, list) {
 313                off = pkt->off;
 314
 315                if (total == len)
 316                        break;
 317
 318                while (total < len && off < pkt->len) {
 319                        bytes = len - total;
 320                        if (bytes > pkt->len - off)
 321                                bytes = pkt->len - off;
 322
 323                        /* sk_lock is held by caller so no one else can dequeue.
 324                         * Unlock rx_lock since memcpy_to_msg() may sleep.
 325                         */
 326                        spin_unlock_bh(&vvs->rx_lock);
 327
 328                        err = memcpy_to_msg(msg, pkt->buf + off, bytes);
 329                        if (err)
 330                                goto out;
 331
 332                        spin_lock_bh(&vvs->rx_lock);
 333
 334                        total += bytes;
 335                        off += bytes;
 336                }
 337        }
 338
 339        spin_unlock_bh(&vvs->rx_lock);
 340
 341        return total;
 342
 343out:
 344        if (total)
 345                err = total;
 346        return err;
 347}
 348
 349static ssize_t
 350virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
 351                                   struct msghdr *msg,
 352                                   size_t len)
 353{
 354        struct virtio_vsock_sock *vvs = vsk->trans;
 355        struct virtio_vsock_pkt *pkt;
 356        size_t bytes, total = 0;
 357        u32 free_space;
 358        int err = -EFAULT;
 359
 360        spin_lock_bh(&vvs->rx_lock);
 361        while (total < len && !list_empty(&vvs->rx_queue)) {
 362                pkt = list_first_entry(&vvs->rx_queue,
 363                                       struct virtio_vsock_pkt, list);
 364
 365                bytes = len - total;
 366                if (bytes > pkt->len - pkt->off)
 367                        bytes = pkt->len - pkt->off;
 368
 369                /* sk_lock is held by caller so no one else can dequeue.
 370                 * Unlock rx_lock since memcpy_to_msg() may sleep.
 371                 */
 372                spin_unlock_bh(&vvs->rx_lock);
 373
 374                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
 375                if (err)
 376                        goto out;
 377
 378                spin_lock_bh(&vvs->rx_lock);
 379
 380                total += bytes;
 381                pkt->off += bytes;
 382                if (pkt->off == pkt->len) {
 383                        virtio_transport_dec_rx_pkt(vvs, pkt);
 384                        list_del(&pkt->list);
 385                        virtio_transport_free_pkt(pkt);
 386                }
 387        }
 388
 389        free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);
 390
 391        spin_unlock_bh(&vvs->rx_lock);
 392
 393        /* To reduce the number of credit update messages,
 394         * don't update credits as long as lots of space is available.
 395         * Note: the limit chosen here is arbitrary. Setting the limit
 396         * too high causes extra messages. Too low causes transmitter
 397         * stalls. As stalls are in theory more expensive than extra
 398         * messages, we set the limit to a high value. TODO: experiment
 399         * with different values.
 400         */
 401        if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
 402                virtio_transport_send_credit_update(vsk);
 403
 404        return total;
 405
 406out:
 407        if (total)
 408                err = total;
 409        return err;
 410}
 411
 412static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
 413                                                 struct msghdr *msg,
 414                                                 int flags)
 415{
 416        struct virtio_vsock_sock *vvs = vsk->trans;
 417        struct virtio_vsock_pkt *pkt;
 418        int dequeued_len = 0;
 419        size_t user_buf_len = msg_data_left(msg);
 420        bool msg_ready = false;
 421
 422        spin_lock_bh(&vvs->rx_lock);
 423
 424        if (vvs->msg_count == 0) {
 425                spin_unlock_bh(&vvs->rx_lock);
 426                return 0;
 427        }
 428
 429        while (!msg_ready) {
 430                pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
 431
 432                if (dequeued_len >= 0) {
 433                        size_t pkt_len;
 434                        size_t bytes_to_copy;
 435
 436                        pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
 437                        bytes_to_copy = min(user_buf_len, pkt_len);
 438
 439                        if (bytes_to_copy) {
 440                                int err;
 441
 442                                /* sk_lock is held by caller so no one else can dequeue.
 443                                 * Unlock rx_lock since memcpy_to_msg() may sleep.
 444                                 */
 445                                spin_unlock_bh(&vvs->rx_lock);
 446
 447                                err = memcpy_to_msg(msg, pkt->buf, bytes_to_copy);
 448                                if (err) {
 449                                        /* Copy of message failed. Rest of
 450                                         * fragments will be freed without copy.
 451                                         */
 452                                        dequeued_len = err;
 453                                } else {
 454                                        user_buf_len -= bytes_to_copy;
 455                                }
 456
 457                                spin_lock_bh(&vvs->rx_lock);
 458                        }
 459
 460                        if (dequeued_len >= 0)
 461                                dequeued_len += pkt_len;
 462                }
 463
 464                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM) {
 465                        msg_ready = true;
 466                        vvs->msg_count--;
 467
 468                        if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR)
 469                                msg->msg_flags |= MSG_EOR;
 470                }
 471
 472                virtio_transport_dec_rx_pkt(vvs, pkt);
 473                list_del(&pkt->list);
 474                virtio_transport_free_pkt(pkt);
 475        }
 476
 477        spin_unlock_bh(&vvs->rx_lock);
 478
 479        virtio_transport_send_credit_update(vsk);
 480
 481        return dequeued_len;
 482}
 483
 484ssize_t
 485virtio_transport_stream_dequeue(struct vsock_sock *vsk,
 486                                struct msghdr *msg,
 487                                size_t len, int flags)
 488{
 489        if (flags & MSG_PEEK)
 490                return virtio_transport_stream_do_peek(vsk, msg, len);
 491        else
 492                return virtio_transport_stream_do_dequeue(vsk, msg, len);
 493}
 494EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
 495
 496ssize_t
 497virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 498                                   struct msghdr *msg,
 499                                   int flags)
 500{
 501        if (flags & MSG_PEEK)
 502                return -EOPNOTSUPP;
 503
 504        return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
 505}
 506EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
 507
 508int
 509virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
 510                                   struct msghdr *msg,
 511                                   size_t len)
 512{
 513        struct virtio_vsock_sock *vvs = vsk->trans;
 514
 515        spin_lock_bh(&vvs->tx_lock);
 516
 517        if (len > vvs->peer_buf_alloc) {
 518                spin_unlock_bh(&vvs->tx_lock);
 519                return -EMSGSIZE;
 520        }
 521
 522        spin_unlock_bh(&vvs->tx_lock);
 523
 524        return virtio_transport_stream_enqueue(vsk, msg, len);
 525}
 526EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
 527
 528int
 529virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
 530                               struct msghdr *msg,
 531                               size_t len, int flags)
 532{
 533        return -EOPNOTSUPP;
 534}
 535EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
 536
 537s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
 538{
 539        struct virtio_vsock_sock *vvs = vsk->trans;
 540        s64 bytes;
 541
 542        spin_lock_bh(&vvs->rx_lock);
 543        bytes = vvs->rx_bytes;
 544        spin_unlock_bh(&vvs->rx_lock);
 545
 546        return bytes;
 547}
 548EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
 549
 550u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
 551{
 552        struct virtio_vsock_sock *vvs = vsk->trans;
 553        u32 msg_count;
 554
 555        spin_lock_bh(&vvs->rx_lock);
 556        msg_count = vvs->msg_count;
 557        spin_unlock_bh(&vvs->rx_lock);
 558
 559        return msg_count;
 560}
 561EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
 562
 563static s64 virtio_transport_has_space(struct vsock_sock *vsk)
 564{
 565        struct virtio_vsock_sock *vvs = vsk->trans;
 566        s64 bytes;
 567
 568        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
 569        if (bytes < 0)
 570                bytes = 0;
 571
 572        return bytes;
 573}
 574
 575s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
 576{
 577        struct virtio_vsock_sock *vvs = vsk->trans;
 578        s64 bytes;
 579
 580        spin_lock_bh(&vvs->tx_lock);
 581        bytes = virtio_transport_has_space(vsk);
 582        spin_unlock_bh(&vvs->tx_lock);
 583
 584        return bytes;
 585}
 586EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
 587
 588int virtio_transport_do_socket_init(struct vsock_sock *vsk,
 589                                    struct vsock_sock *psk)
 590{
 591        struct virtio_vsock_sock *vvs;
 592
 593        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
 594        if (!vvs)
 595                return -ENOMEM;
 596
 597        vsk->trans = vvs;
 598        vvs->vsk = vsk;
 599        if (psk && psk->trans) {
 600                struct virtio_vsock_sock *ptrans = psk->trans;
 601
 602                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
 603        }
 604
 605        if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
 606                vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
 607
 608        vvs->buf_alloc = vsk->buffer_size;
 609
 610        spin_lock_init(&vvs->rx_lock);
 611        spin_lock_init(&vvs->tx_lock);
 612        INIT_LIST_HEAD(&vvs->rx_queue);
 613
 614        return 0;
 615}
 616EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
 617
 618/* sk_lock held by the caller */
 619void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
 620{
 621        struct virtio_vsock_sock *vvs = vsk->trans;
 622
 623        if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 624                *val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 625
 626        vvs->buf_alloc = *val;
 627
 628        virtio_transport_send_credit_update(vsk);
 629}
 630EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
 631
 632int
 633virtio_transport_notify_poll_in(struct vsock_sock *vsk,
 634                                size_t target,
 635                                bool *data_ready_now)
 636{
 637        if (vsock_stream_has_data(vsk))
 638                *data_ready_now = true;
 639        else
 640                *data_ready_now = false;
 641
 642        return 0;
 643}
 644EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
 645
 646int
 647virtio_transport_notify_poll_out(struct vsock_sock *vsk,
 648                                 size_t target,
 649                                 bool *space_avail_now)
 650{
 651        s64 free_space;
 652
 653        free_space = vsock_stream_has_space(vsk);
 654        if (free_space > 0)
 655                *space_avail_now = true;
 656        else if (free_space == 0)
 657                *space_avail_now = false;
 658
 659        return 0;
 660}
 661EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
 662
 663int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
 664        size_t target, struct vsock_transport_recv_notify_data *data)
 665{
 666        return 0;
 667}
 668EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
 669
 670int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
 671        size_t target, struct vsock_transport_recv_notify_data *data)
 672{
 673        return 0;
 674}
 675EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
 676
 677int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
 678        size_t target, struct vsock_transport_recv_notify_data *data)
 679{
 680        return 0;
 681}
 682EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
 683
 684int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
 685        size_t target, ssize_t copied, bool data_read,
 686        struct vsock_transport_recv_notify_data *data)
 687{
 688        return 0;
 689}
 690EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
 691
 692int virtio_transport_notify_send_init(struct vsock_sock *vsk,
 693        struct vsock_transport_send_notify_data *data)
 694{
 695        return 0;
 696}
 697EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
 698
 699int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
 700        struct vsock_transport_send_notify_data *data)
 701{
 702        return 0;
 703}
 704EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
 705
 706int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
 707        struct vsock_transport_send_notify_data *data)
 708{
 709        return 0;
 710}
 711EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
 712
 713int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
 714        ssize_t written, struct vsock_transport_send_notify_data *data)
 715{
 716        return 0;
 717}
 718EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
 719
 720u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
 721{
 722        return vsk->buffer_size;
 723}
 724EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
 725
 726bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
 727{
 728        return true;
 729}
 730EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
 731
 732bool virtio_transport_stream_allow(u32 cid, u32 port)
 733{
 734        return true;
 735}
 736EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
 737
 738int virtio_transport_dgram_bind(struct vsock_sock *vsk,
 739                                struct sockaddr_vm *addr)
 740{
 741        return -EOPNOTSUPP;
 742}
 743EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
 744
 745bool virtio_transport_dgram_allow(u32 cid, u32 port)
 746{
 747        return false;
 748}
 749EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
 750
 751int virtio_transport_connect(struct vsock_sock *vsk)
 752{
 753        struct virtio_vsock_pkt_info info = {
 754                .op = VIRTIO_VSOCK_OP_REQUEST,
 755                .vsk = vsk,
 756        };
 757
 758        return virtio_transport_send_pkt_info(vsk, &info);
 759}
 760EXPORT_SYMBOL_GPL(virtio_transport_connect);
 761
 762int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
 763{
 764        struct virtio_vsock_pkt_info info = {
 765                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
 766                .flags = (mode & RCV_SHUTDOWN ?
 767                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
 768                         (mode & SEND_SHUTDOWN ?
 769                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
 770                .vsk = vsk,
 771        };
 772
 773        return virtio_transport_send_pkt_info(vsk, &info);
 774}
 775EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
 776
 777int
 778virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
 779                               struct sockaddr_vm *remote_addr,
 780                               struct msghdr *msg,
 781                               size_t dgram_len)
 782{
 783        return -EOPNOTSUPP;
 784}
 785EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
 786
 787ssize_t
 788virtio_transport_stream_enqueue(struct vsock_sock *vsk,
 789                                struct msghdr *msg,
 790                                size_t len)
 791{
 792        struct virtio_vsock_pkt_info info = {
 793                .op = VIRTIO_VSOCK_OP_RW,
 794                .msg = msg,
 795                .pkt_len = len,
 796                .vsk = vsk,
 797        };
 798
 799        return virtio_transport_send_pkt_info(vsk, &info);
 800}
 801EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
 802
 803void virtio_transport_destruct(struct vsock_sock *vsk)
 804{
 805        struct virtio_vsock_sock *vvs = vsk->trans;
 806
 807        kfree(vvs);
 808}
 809EXPORT_SYMBOL_GPL(virtio_transport_destruct);
 810
 811static int virtio_transport_reset(struct vsock_sock *vsk,
 812                                  struct virtio_vsock_pkt *pkt)
 813{
 814        struct virtio_vsock_pkt_info info = {
 815                .op = VIRTIO_VSOCK_OP_RST,
 816                .reply = !!pkt,
 817                .vsk = vsk,
 818        };
 819
 820        /* Send RST only if the original pkt is not a RST pkt */
 821        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 822                return 0;
 823
 824        return virtio_transport_send_pkt_info(vsk, &info);
 825}
 826
 827/* Normally packets are associated with a socket.  There may be no socket if an
 828 * attempt was made to connect to a socket that does not exist.
 829 */
 830static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
 831                                          struct virtio_vsock_pkt *pkt)
 832{
 833        struct virtio_vsock_pkt *reply;
 834        struct virtio_vsock_pkt_info info = {
 835                .op = VIRTIO_VSOCK_OP_RST,
 836                .type = le16_to_cpu(pkt->hdr.type),
 837                .reply = true,
 838        };
 839
 840        /* Send RST only if the original pkt is not a RST pkt */
 841        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 842                return 0;
 843
 844        reply = virtio_transport_alloc_pkt(&info, 0,
 845                                           le64_to_cpu(pkt->hdr.dst_cid),
 846                                           le32_to_cpu(pkt->hdr.dst_port),
 847                                           le64_to_cpu(pkt->hdr.src_cid),
 848                                           le32_to_cpu(pkt->hdr.src_port));
 849        if (!reply)
 850                return -ENOMEM;
 851
 852        if (!t) {
 853                virtio_transport_free_pkt(reply);
 854                return -ENOTCONN;
 855        }
 856
 857        return t->send_pkt(reply);
 858}
 859
 860/* This function should be called with sk_lock held and SOCK_DONE set */
 861static void virtio_transport_remove_sock(struct vsock_sock *vsk)
 862{
 863        struct virtio_vsock_sock *vvs = vsk->trans;
 864        struct virtio_vsock_pkt *pkt, *tmp;
 865
 866        /* We don't need to take rx_lock, as the socket is closing and we are
 867         * removing it.
 868         */
 869        list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
 870                list_del(&pkt->list);
 871                virtio_transport_free_pkt(pkt);
 872        }
 873
 874        vsock_remove_sock(vsk);
 875}
 876
 877static void virtio_transport_wait_close(struct sock *sk, long timeout)
 878{
 879        if (timeout) {
 880                DEFINE_WAIT_FUNC(wait, woken_wake_function);
 881
 882                add_wait_queue(sk_sleep(sk), &wait);
 883
 884                do {
 885                        if (sk_wait_event(sk, &timeout,
 886                                          sock_flag(sk, SOCK_DONE), &wait))
 887                                break;
 888                } while (!signal_pending(current) && timeout);
 889
 890                remove_wait_queue(sk_sleep(sk), &wait);
 891        }
 892}
 893
 894static void virtio_transport_do_close(struct vsock_sock *vsk,
 895                                      bool cancel_timeout)
 896{
 897        struct sock *sk = sk_vsock(vsk);
 898
 899        sock_set_flag(sk, SOCK_DONE);
 900        vsk->peer_shutdown = SHUTDOWN_MASK;
 901        if (vsock_stream_has_data(vsk) <= 0)
 902                sk->sk_state = TCP_CLOSING;
 903        sk->sk_state_change(sk);
 904
 905        if (vsk->close_work_scheduled &&
 906            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
 907                vsk->close_work_scheduled = false;
 908
 909                virtio_transport_remove_sock(vsk);
 910
 911                /* Release refcnt obtained when we scheduled the timeout */
 912                sock_put(sk);
 913        }
 914}
 915
 916static void virtio_transport_close_timeout(struct work_struct *work)
 917{
 918        struct vsock_sock *vsk =
 919                container_of(work, struct vsock_sock, close_work.work);
 920        struct sock *sk = sk_vsock(vsk);
 921
 922        sock_hold(sk);
 923        lock_sock(sk);
 924
 925        if (!sock_flag(sk, SOCK_DONE)) {
 926                (void)virtio_transport_reset(vsk, NULL);
 927
 928                virtio_transport_do_close(vsk, false);
 929        }
 930
 931        vsk->close_work_scheduled = false;
 932
 933        release_sock(sk);
 934        sock_put(sk);
 935}
 936
 937/* User context, vsk->sk is locked */
 938static bool virtio_transport_close(struct vsock_sock *vsk)
 939{
 940        struct sock *sk = &vsk->sk;
 941
 942        if (!(sk->sk_state == TCP_ESTABLISHED ||
 943              sk->sk_state == TCP_CLOSING))
 944                return true;
 945
 946        /* Already received SHUTDOWN from peer, reply with RST */
 947        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
 948                (void)virtio_transport_reset(vsk, NULL);
 949                return true;
 950        }
 951
 952        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
 953                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
 954
 955        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
 956                virtio_transport_wait_close(sk, sk->sk_lingertime);
 957
 958        if (sock_flag(sk, SOCK_DONE)) {
 959                return true;
 960        }
 961
 962        sock_hold(sk);
 963        INIT_DELAYED_WORK(&vsk->close_work,
 964                          virtio_transport_close_timeout);
 965        vsk->close_work_scheduled = true;
 966        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
 967        return false;
 968}
 969
 970void virtio_transport_release(struct vsock_sock *vsk)
 971{
 972        struct sock *sk = &vsk->sk;
 973        bool remove_sock = true;
 974
 975        if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
 976                remove_sock = virtio_transport_close(vsk);
 977
 978        if (remove_sock) {
 979                sock_set_flag(sk, SOCK_DONE);
 980                virtio_transport_remove_sock(vsk);
 981        }
 982}
 983EXPORT_SYMBOL_GPL(virtio_transport_release);
 984
 985static int
 986virtio_transport_recv_connecting(struct sock *sk,
 987                                 struct virtio_vsock_pkt *pkt)
 988{
 989        struct vsock_sock *vsk = vsock_sk(sk);
 990        int err;
 991        int skerr;
 992
 993        switch (le16_to_cpu(pkt->hdr.op)) {
 994        case VIRTIO_VSOCK_OP_RESPONSE:
 995                sk->sk_state = TCP_ESTABLISHED;
 996                sk->sk_socket->state = SS_CONNECTED;
 997                vsock_insert_connected(vsk);
 998                sk->sk_state_change(sk);
 999                break;
1000        case VIRTIO_VSOCK_OP_INVALID:
1001                break;
1002        case VIRTIO_VSOCK_OP_RST:
1003                skerr = ECONNRESET;
1004                err = 0;
1005                goto destroy;
1006        default:
1007                skerr = EPROTO;
1008                err = -EINVAL;
1009                goto destroy;
1010        }
1011        return 0;
1012
1013destroy:
1014        virtio_transport_reset(vsk, pkt);
1015        sk->sk_state = TCP_CLOSE;
1016        sk->sk_err = skerr;
1017        sk_error_report(sk);
1018        return err;
1019}
1020
1021static void
1022virtio_transport_recv_enqueue(struct vsock_sock *vsk,
1023                              struct virtio_vsock_pkt *pkt)
1024{
1025        struct virtio_vsock_sock *vvs = vsk->trans;
1026        bool can_enqueue, free_pkt = false;
1027
1028        pkt->len = le32_to_cpu(pkt->hdr.len);
1029        pkt->off = 0;
1030
1031        spin_lock_bh(&vvs->rx_lock);
1032
1033        can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
1034        if (!can_enqueue) {
1035                free_pkt = true;
1036                goto out;
1037        }
1038
1039        if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM)
1040                vvs->msg_count++;
1041
1042        /* Try to copy small packets into the buffer of last packet queued,
1043         * to avoid wasting memory queueing the entire buffer with a small
1044         * payload.
1045         */
1046        if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
1047                struct virtio_vsock_pkt *last_pkt;
1048
1049                last_pkt = list_last_entry(&vvs->rx_queue,
1050                                           struct virtio_vsock_pkt, list);
1051
1052                /* If there is space in the last packet queued, we copy the
1053                 * new packet in its buffer. We avoid this if the last packet
1054                 * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is
1055                 * delimiter of SEQPACKET message, so 'pkt' is the first packet
1056                 * of a new message.
1057                 */
1058                if ((pkt->len <= last_pkt->buf_len - last_pkt->len) &&
1059                    !(le32_to_cpu(last_pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOM)) {
1060                        memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
1061                               pkt->len);
1062                        last_pkt->len += pkt->len;
1063                        free_pkt = true;
1064                        last_pkt->hdr.flags |= pkt->hdr.flags;
1065                        goto out;
1066                }
1067        }
1068
1069        list_add_tail(&pkt->list, &vvs->rx_queue);
1070
1071out:
1072        spin_unlock_bh(&vvs->rx_lock);
1073        if (free_pkt)
1074                virtio_transport_free_pkt(pkt);
1075}
1076
1077static int
1078virtio_transport_recv_connected(struct sock *sk,
1079                                struct virtio_vsock_pkt *pkt)
1080{
1081        struct vsock_sock *vsk = vsock_sk(sk);
1082        int err = 0;
1083
1084        switch (le16_to_cpu(pkt->hdr.op)) {
1085        case VIRTIO_VSOCK_OP_RW:
1086                virtio_transport_recv_enqueue(vsk, pkt);
1087                sk->sk_data_ready(sk);
1088                return err;
1089        case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
1090                virtio_transport_send_credit_update(vsk);
1091                break;
1092        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
1093                sk->sk_write_space(sk);
1094                break;
1095        case VIRTIO_VSOCK_OP_SHUTDOWN:
1096                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
1097                        vsk->peer_shutdown |= RCV_SHUTDOWN;
1098                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
1099                        vsk->peer_shutdown |= SEND_SHUTDOWN;
1100                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
1101                    vsock_stream_has_data(vsk) <= 0 &&
1102                    !sock_flag(sk, SOCK_DONE)) {
1103                        (void)virtio_transport_reset(vsk, NULL);
1104
1105                        virtio_transport_do_close(vsk, true);
1106                }
1107                if (le32_to_cpu(pkt->hdr.flags))
1108                        sk->sk_state_change(sk);
1109                break;
1110        case VIRTIO_VSOCK_OP_RST:
1111                virtio_transport_do_close(vsk, true);
1112                break;
1113        default:
1114                err = -EINVAL;
1115                break;
1116        }
1117
1118        virtio_transport_free_pkt(pkt);
1119        return err;
1120}
1121
1122static void
1123virtio_transport_recv_disconnecting(struct sock *sk,
1124                                    struct virtio_vsock_pkt *pkt)
1125{
1126        struct vsock_sock *vsk = vsock_sk(sk);
1127
1128        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
1129                virtio_transport_do_close(vsk, true);
1130}
1131
1132static int
1133virtio_transport_send_response(struct vsock_sock *vsk,
1134                               struct virtio_vsock_pkt *pkt)
1135{
1136        struct virtio_vsock_pkt_info info = {
1137                .op = VIRTIO_VSOCK_OP_RESPONSE,
1138                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
1139                .remote_port = le32_to_cpu(pkt->hdr.src_port),
1140                .reply = true,
1141                .vsk = vsk,
1142        };
1143
1144        return virtio_transport_send_pkt_info(vsk, &info);
1145}
1146
1147static bool virtio_transport_space_update(struct sock *sk,
1148                                          struct virtio_vsock_pkt *pkt)
1149{
1150        struct vsock_sock *vsk = vsock_sk(sk);
1151        struct virtio_vsock_sock *vvs = vsk->trans;
1152        bool space_available;
1153
1154        /* Listener sockets are not associated with any transport, so we are
1155         * not able to take the state to see if there is space available in the
1156         * remote peer, but since they are only used to receive requests, we
1157         * can assume that there is always space available in the other peer.
1158         */
1159        if (!vvs)
1160                return true;
1161
1162        /* buf_alloc and fwd_cnt is always included in the hdr */
1163        spin_lock_bh(&vvs->tx_lock);
1164        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
1165        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
1166        space_available = virtio_transport_has_space(vsk);
1167        spin_unlock_bh(&vvs->tx_lock);
1168        return space_available;
1169}
1170
1171/* Handle server socket */
1172static int
1173virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
1174                             struct virtio_transport *t)
1175{
1176        struct vsock_sock *vsk = vsock_sk(sk);
1177        struct vsock_sock *vchild;
1178        struct sock *child;
1179        int ret;
1180
1181        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
1182                virtio_transport_reset_no_sock(t, pkt);
1183                return -EINVAL;
1184        }
1185
1186        if (sk_acceptq_is_full(sk)) {
1187                virtio_transport_reset_no_sock(t, pkt);
1188                return -ENOMEM;
1189        }
1190
1191        child = vsock_create_connected(sk);
1192        if (!child) {
1193                virtio_transport_reset_no_sock(t, pkt);
1194                return -ENOMEM;
1195        }
1196
1197        sk_acceptq_added(sk);
1198
1199        lock_sock_nested(child, SINGLE_DEPTH_NESTING);
1200
1201        child->sk_state = TCP_ESTABLISHED;
1202
1203        vchild = vsock_sk(child);
1204        vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
1205                        le32_to_cpu(pkt->hdr.dst_port));
1206        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
1207                        le32_to_cpu(pkt->hdr.src_port));
1208
1209        ret = vsock_assign_transport(vchild, vsk);
1210        /* Transport assigned (looking at remote_addr) must be the same
1211         * where we received the request.
1212         */
1213        if (ret || vchild->transport != &t->transport) {
1214                release_sock(child);
1215                virtio_transport_reset_no_sock(t, pkt);
1216                sock_put(child);
1217                return ret;
1218        }
1219
1220        if (virtio_transport_space_update(child, pkt))
1221                child->sk_write_space(child);
1222
1223        vsock_insert_connected(vchild);
1224        vsock_enqueue_accept(sk, child);
1225        virtio_transport_send_response(vchild, pkt);
1226
1227        release_sock(child);
1228
1229        sk->sk_data_ready(sk);
1230        return 0;
1231}
1232
1233static bool virtio_transport_valid_type(u16 type)
1234{
1235        return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
1236               (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
1237}
1238
1239/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
1240 * lock.
1241 */
1242void virtio_transport_recv_pkt(struct virtio_transport *t,
1243                               struct virtio_vsock_pkt *pkt)
1244{
1245        struct sockaddr_vm src, dst;
1246        struct vsock_sock *vsk;
1247        struct sock *sk;
1248        bool space_available;
1249
1250        vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
1251                        le32_to_cpu(pkt->hdr.src_port));
1252        vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
1253                        le32_to_cpu(pkt->hdr.dst_port));
1254
1255        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1256                                        dst.svm_cid, dst.svm_port,
1257                                        le32_to_cpu(pkt->hdr.len),
1258                                        le16_to_cpu(pkt->hdr.type),
1259                                        le16_to_cpu(pkt->hdr.op),
1260                                        le32_to_cpu(pkt->hdr.flags),
1261                                        le32_to_cpu(pkt->hdr.buf_alloc),
1262                                        le32_to_cpu(pkt->hdr.fwd_cnt));
1263
1264        if (!virtio_transport_valid_type(le16_to_cpu(pkt->hdr.type))) {
1265                (void)virtio_transport_reset_no_sock(t, pkt);
1266                goto free_pkt;
1267        }
1268
1269        /* The socket must be in connected or bound table
1270         * otherwise send reset back
1271         */
1272        sk = vsock_find_connected_socket(&src, &dst);
1273        if (!sk) {
1274                sk = vsock_find_bound_socket(&dst);
1275                if (!sk) {
1276                        (void)virtio_transport_reset_no_sock(t, pkt);
1277                        goto free_pkt;
1278                }
1279        }
1280
1281        if (virtio_transport_get_type(sk) != le16_to_cpu(pkt->hdr.type)) {
1282                (void)virtio_transport_reset_no_sock(t, pkt);
1283                sock_put(sk);
1284                goto free_pkt;
1285        }
1286
1287        vsk = vsock_sk(sk);
1288
1289        lock_sock(sk);
1290
1291        /* Check if sk has been closed before lock_sock */
1292        if (sock_flag(sk, SOCK_DONE)) {
1293                (void)virtio_transport_reset_no_sock(t, pkt);
1294                release_sock(sk);
1295                sock_put(sk);
1296                goto free_pkt;
1297        }
1298
1299        space_available = virtio_transport_space_update(sk, pkt);
1300
1301        /* Update CID in case it has changed after a transport reset event */
1302        if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
1303                vsk->local_addr.svm_cid = dst.svm_cid;
1304
1305        if (space_available)
1306                sk->sk_write_space(sk);
1307
1308        switch (sk->sk_state) {
1309        case TCP_LISTEN:
1310                virtio_transport_recv_listen(sk, pkt, t);
1311                virtio_transport_free_pkt(pkt);
1312                break;
1313        case TCP_SYN_SENT:
1314                virtio_transport_recv_connecting(sk, pkt);
1315                virtio_transport_free_pkt(pkt);
1316                break;
1317        case TCP_ESTABLISHED:
1318                virtio_transport_recv_connected(sk, pkt);
1319                break;
1320        case TCP_CLOSING:
1321                virtio_transport_recv_disconnecting(sk, pkt);
1322                virtio_transport_free_pkt(pkt);
1323                break;
1324        default:
1325                (void)virtio_transport_reset_no_sock(t, pkt);
1326                virtio_transport_free_pkt(pkt);
1327                break;
1328        }
1329
1330        release_sock(sk);
1331
1332        /* Release refcnt obtained when we fetched this socket out of the
1333         * bound or connected list.
1334         */
1335        sock_put(sk);
1336        return;
1337
1338free_pkt:
1339        virtio_transport_free_pkt(pkt);
1340}
1341EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1342
1343void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1344{
1345        kfree(pkt->buf);
1346        kfree(pkt);
1347}
1348EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1349
1350MODULE_LICENSE("GPL v2");
1351MODULE_AUTHOR("Asias He");
1352MODULE_DESCRIPTION("common code for virtio vsock");
1353