linux/net/vmw_vsock/virtio_transport_common.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * common code for virtio vsock
   4 *
   5 * Copyright (C) 2013-2015 Red Hat, Inc.
   6 * Author: Asias He <asias@redhat.com>
   7 *         Stefan Hajnoczi <stefanha@redhat.com>
   8 */
   9#include <linux/spinlock.h>
  10#include <linux/module.h>
  11#include <linux/sched/signal.h>
  12#include <linux/ctype.h>
  13#include <linux/list.h>
  14#include <linux/virtio.h>
  15#include <linux/virtio_ids.h>
  16#include <linux/virtio_config.h>
  17#include <linux/virtio_vsock.h>
  18#include <uapi/linux/vsockmon.h>
  19
  20#include <net/sock.h>
  21#include <net/af_vsock.h>
  22
  23#define CREATE_TRACE_POINTS
  24#include <trace/events/vsock_virtio_transport_common.h>
  25
  26/* How long to wait for graceful shutdown of a connection */
  27#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
  28
  29static const struct virtio_transport *virtio_transport_get_ops(void)
  30{
  31        const struct vsock_transport *t = vsock_core_get_transport();
  32
  33        return container_of(t, struct virtio_transport, transport);
  34}
  35
  36static struct virtio_vsock_pkt *
  37virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
  38                           size_t len,
  39                           u32 src_cid,
  40                           u32 src_port,
  41                           u32 dst_cid,
  42                           u32 dst_port)
  43{
  44        struct virtio_vsock_pkt *pkt;
  45        int err;
  46
  47        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
  48        if (!pkt)
  49                return NULL;
  50
  51        pkt->hdr.type           = cpu_to_le16(info->type);
  52        pkt->hdr.op             = cpu_to_le16(info->op);
  53        pkt->hdr.src_cid        = cpu_to_le64(src_cid);
  54        pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
  55        pkt->hdr.src_port       = cpu_to_le32(src_port);
  56        pkt->hdr.dst_port       = cpu_to_le32(dst_port);
  57        pkt->hdr.flags          = cpu_to_le32(info->flags);
  58        pkt->len                = len;
  59        pkt->hdr.len            = cpu_to_le32(len);
  60        pkt->reply              = info->reply;
  61        pkt->vsk                = info->vsk;
  62
  63        if (info->msg && len > 0) {
  64                pkt->buf = kmalloc(len, GFP_KERNEL);
  65                if (!pkt->buf)
  66                        goto out_pkt;
  67                err = memcpy_from_msg(pkt->buf, info->msg, len);
  68                if (err)
  69                        goto out;
  70        }
  71
  72        trace_virtio_transport_alloc_pkt(src_cid, src_port,
  73                                         dst_cid, dst_port,
  74                                         len,
  75                                         info->type,
  76                                         info->op,
  77                                         info->flags);
  78
  79        return pkt;
  80
  81out:
  82        kfree(pkt->buf);
  83out_pkt:
  84        kfree(pkt);
  85        return NULL;
  86}
  87
  88/* Packet capture */
  89static struct sk_buff *virtio_transport_build_skb(void *opaque)
  90{
  91        struct virtio_vsock_pkt *pkt = opaque;
  92        struct af_vsockmon_hdr *hdr;
  93        struct sk_buff *skb;
  94
  95        skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
  96                        GFP_ATOMIC);
  97        if (!skb)
  98                return NULL;
  99
 100        hdr = skb_put(skb, sizeof(*hdr));
 101
 102        /* pkt->hdr is little-endian so no need to byteswap here */
 103        hdr->src_cid = pkt->hdr.src_cid;
 104        hdr->src_port = pkt->hdr.src_port;
 105        hdr->dst_cid = pkt->hdr.dst_cid;
 106        hdr->dst_port = pkt->hdr.dst_port;
 107
 108        hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
 109        hdr->len = cpu_to_le16(sizeof(pkt->hdr));
 110        memset(hdr->reserved, 0, sizeof(hdr->reserved));
 111
 112        switch (le16_to_cpu(pkt->hdr.op)) {
 113        case VIRTIO_VSOCK_OP_REQUEST:
 114        case VIRTIO_VSOCK_OP_RESPONSE:
 115                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
 116                break;
 117        case VIRTIO_VSOCK_OP_RST:
 118        case VIRTIO_VSOCK_OP_SHUTDOWN:
 119                hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
 120                break;
 121        case VIRTIO_VSOCK_OP_RW:
 122                hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
 123                break;
 124        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
 125        case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
 126                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
 127                break;
 128        default:
 129                hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
 130                break;
 131        }
 132
 133        skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
 134
 135        if (pkt->len) {
 136                skb_put_data(skb, pkt->buf, pkt->len);
 137        }
 138
 139        return skb;
 140}
 141
 142void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
 143{
 144        vsock_deliver_tap(virtio_transport_build_skb, pkt);
 145}
 146EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
 147
 148static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
 149                                          struct virtio_vsock_pkt_info *info)
 150{
 151        u32 src_cid, src_port, dst_cid, dst_port;
 152        struct virtio_vsock_sock *vvs;
 153        struct virtio_vsock_pkt *pkt;
 154        u32 pkt_len = info->pkt_len;
 155
 156        src_cid = vm_sockets_get_local_cid();
 157        src_port = vsk->local_addr.svm_port;
 158        if (!info->remote_cid) {
 159                dst_cid = vsk->remote_addr.svm_cid;
 160                dst_port = vsk->remote_addr.svm_port;
 161        } else {
 162                dst_cid = info->remote_cid;
 163                dst_port = info->remote_port;
 164        }
 165
 166        vvs = vsk->trans;
 167
 168        /* we can send less than pkt_len bytes */
 169        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
 170                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
 171
 172        /* virtio_transport_get_credit might return less than pkt_len credit */
 173        pkt_len = virtio_transport_get_credit(vvs, pkt_len);
 174
 175        /* Do not send zero length OP_RW pkt */
 176        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
 177                return pkt_len;
 178
 179        pkt = virtio_transport_alloc_pkt(info, pkt_len,
 180                                         src_cid, src_port,
 181                                         dst_cid, dst_port);
 182        if (!pkt) {
 183                virtio_transport_put_credit(vvs, pkt_len);
 184                return -ENOMEM;
 185        }
 186
 187        virtio_transport_inc_tx_pkt(vvs, pkt);
 188
 189        return virtio_transport_get_ops()->send_pkt(pkt);
 190}
 191
 192static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
 193                                        struct virtio_vsock_pkt *pkt)
 194{
 195        vvs->rx_bytes += pkt->len;
 196}
 197
 198static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
 199                                        struct virtio_vsock_pkt *pkt)
 200{
 201        vvs->rx_bytes -= pkt->len;
 202        vvs->fwd_cnt += pkt->len;
 203}
 204
 205void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
 206{
 207        spin_lock_bh(&vvs->tx_lock);
 208        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
 209        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
 210        spin_unlock_bh(&vvs->tx_lock);
 211}
 212EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
 213
 214u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
 215{
 216        u32 ret;
 217
 218        spin_lock_bh(&vvs->tx_lock);
 219        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
 220        if (ret > credit)
 221                ret = credit;
 222        vvs->tx_cnt += ret;
 223        spin_unlock_bh(&vvs->tx_lock);
 224
 225        return ret;
 226}
 227EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
 228
 229void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
 230{
 231        spin_lock_bh(&vvs->tx_lock);
 232        vvs->tx_cnt -= credit;
 233        spin_unlock_bh(&vvs->tx_lock);
 234}
 235EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
 236
 237static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
 238                                               int type,
 239                                               struct virtio_vsock_hdr *hdr)
 240{
 241        struct virtio_vsock_pkt_info info = {
 242                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
 243                .type = type,
 244                .vsk = vsk,
 245        };
 246
 247        return virtio_transport_send_pkt_info(vsk, &info);
 248}
 249
 250static ssize_t
 251virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
 252                                   struct msghdr *msg,
 253                                   size_t len)
 254{
 255        struct virtio_vsock_sock *vvs = vsk->trans;
 256        struct virtio_vsock_pkt *pkt;
 257        size_t bytes, total = 0;
 258        int err = -EFAULT;
 259
 260        spin_lock_bh(&vvs->rx_lock);
 261        while (total < len && !list_empty(&vvs->rx_queue)) {
 262                pkt = list_first_entry(&vvs->rx_queue,
 263                                       struct virtio_vsock_pkt, list);
 264
 265                bytes = len - total;
 266                if (bytes > pkt->len - pkt->off)
 267                        bytes = pkt->len - pkt->off;
 268
 269                /* sk_lock is held by caller so no one else can dequeue.
 270                 * Unlock rx_lock since memcpy_to_msg() may sleep.
 271                 */
 272                spin_unlock_bh(&vvs->rx_lock);
 273
 274                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
 275                if (err)
 276                        goto out;
 277
 278                spin_lock_bh(&vvs->rx_lock);
 279
 280                total += bytes;
 281                pkt->off += bytes;
 282                if (pkt->off == pkt->len) {
 283                        virtio_transport_dec_rx_pkt(vvs, pkt);
 284                        list_del(&pkt->list);
 285                        virtio_transport_free_pkt(pkt);
 286                }
 287        }
 288        spin_unlock_bh(&vvs->rx_lock);
 289
 290        /* Send a credit pkt to peer */
 291        virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
 292                                            NULL);
 293
 294        return total;
 295
 296out:
 297        if (total)
 298                err = total;
 299        return err;
 300}
 301
 302ssize_t
 303virtio_transport_stream_dequeue(struct vsock_sock *vsk,
 304                                struct msghdr *msg,
 305                                size_t len, int flags)
 306{
 307        if (flags & MSG_PEEK)
 308                return -EOPNOTSUPP;
 309
 310        return virtio_transport_stream_do_dequeue(vsk, msg, len);
 311}
 312EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
 313
 314int
 315virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
 316                               struct msghdr *msg,
 317                               size_t len, int flags)
 318{
 319        return -EOPNOTSUPP;
 320}
 321EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
 322
 323s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
 324{
 325        struct virtio_vsock_sock *vvs = vsk->trans;
 326        s64 bytes;
 327
 328        spin_lock_bh(&vvs->rx_lock);
 329        bytes = vvs->rx_bytes;
 330        spin_unlock_bh(&vvs->rx_lock);
 331
 332        return bytes;
 333}
 334EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
 335
 336static s64 virtio_transport_has_space(struct vsock_sock *vsk)
 337{
 338        struct virtio_vsock_sock *vvs = vsk->trans;
 339        s64 bytes;
 340
 341        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
 342        if (bytes < 0)
 343                bytes = 0;
 344
 345        return bytes;
 346}
 347
 348s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
 349{
 350        struct virtio_vsock_sock *vvs = vsk->trans;
 351        s64 bytes;
 352
 353        spin_lock_bh(&vvs->tx_lock);
 354        bytes = virtio_transport_has_space(vsk);
 355        spin_unlock_bh(&vvs->tx_lock);
 356
 357        return bytes;
 358}
 359EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
 360
 361int virtio_transport_do_socket_init(struct vsock_sock *vsk,
 362                                    struct vsock_sock *psk)
 363{
 364        struct virtio_vsock_sock *vvs;
 365
 366        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
 367        if (!vvs)
 368                return -ENOMEM;
 369
 370        vsk->trans = vvs;
 371        vvs->vsk = vsk;
 372        if (psk) {
 373                struct virtio_vsock_sock *ptrans = psk->trans;
 374
 375                vvs->buf_size   = ptrans->buf_size;
 376                vvs->buf_size_min = ptrans->buf_size_min;
 377                vvs->buf_size_max = ptrans->buf_size_max;
 378                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
 379        } else {
 380                vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
 381                vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
 382                vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
 383        }
 384
 385        vvs->buf_alloc = vvs->buf_size;
 386
 387        spin_lock_init(&vvs->rx_lock);
 388        spin_lock_init(&vvs->tx_lock);
 389        INIT_LIST_HEAD(&vvs->rx_queue);
 390
 391        return 0;
 392}
 393EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
 394
 395u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
 396{
 397        struct virtio_vsock_sock *vvs = vsk->trans;
 398
 399        return vvs->buf_size;
 400}
 401EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
 402
 403u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
 404{
 405        struct virtio_vsock_sock *vvs = vsk->trans;
 406
 407        return vvs->buf_size_min;
 408}
 409EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
 410
 411u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
 412{
 413        struct virtio_vsock_sock *vvs = vsk->trans;
 414
 415        return vvs->buf_size_max;
 416}
 417EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
 418
 419void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
 420{
 421        struct virtio_vsock_sock *vvs = vsk->trans;
 422
 423        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 424                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 425        if (val < vvs->buf_size_min)
 426                vvs->buf_size_min = val;
 427        if (val > vvs->buf_size_max)
 428                vvs->buf_size_max = val;
 429        vvs->buf_size = val;
 430        vvs->buf_alloc = val;
 431}
 432EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
 433
 434void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
 435{
 436        struct virtio_vsock_sock *vvs = vsk->trans;
 437
 438        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 439                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 440        if (val > vvs->buf_size)
 441                vvs->buf_size = val;
 442        vvs->buf_size_min = val;
 443}
 444EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
 445
 446void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
 447{
 448        struct virtio_vsock_sock *vvs = vsk->trans;
 449
 450        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 451                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 452        if (val < vvs->buf_size)
 453                vvs->buf_size = val;
 454        vvs->buf_size_max = val;
 455}
 456EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
 457
 458int
 459virtio_transport_notify_poll_in(struct vsock_sock *vsk,
 460                                size_t target,
 461                                bool *data_ready_now)
 462{
 463        if (vsock_stream_has_data(vsk))
 464                *data_ready_now = true;
 465        else
 466                *data_ready_now = false;
 467
 468        return 0;
 469}
 470EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
 471
 472int
 473virtio_transport_notify_poll_out(struct vsock_sock *vsk,
 474                                 size_t target,
 475                                 bool *space_avail_now)
 476{
 477        s64 free_space;
 478
 479        free_space = vsock_stream_has_space(vsk);
 480        if (free_space > 0)
 481                *space_avail_now = true;
 482        else if (free_space == 0)
 483                *space_avail_now = false;
 484
 485        return 0;
 486}
 487EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
 488
 489int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
 490        size_t target, struct vsock_transport_recv_notify_data *data)
 491{
 492        return 0;
 493}
 494EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
 495
 496int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
 497        size_t target, struct vsock_transport_recv_notify_data *data)
 498{
 499        return 0;
 500}
 501EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
 502
 503int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
 504        size_t target, struct vsock_transport_recv_notify_data *data)
 505{
 506        return 0;
 507}
 508EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
 509
 510int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
 511        size_t target, ssize_t copied, bool data_read,
 512        struct vsock_transport_recv_notify_data *data)
 513{
 514        return 0;
 515}
 516EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
 517
 518int virtio_transport_notify_send_init(struct vsock_sock *vsk,
 519        struct vsock_transport_send_notify_data *data)
 520{
 521        return 0;
 522}
 523EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
 524
 525int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
 526        struct vsock_transport_send_notify_data *data)
 527{
 528        return 0;
 529}
 530EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
 531
 532int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
 533        struct vsock_transport_send_notify_data *data)
 534{
 535        return 0;
 536}
 537EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
 538
 539int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
 540        ssize_t written, struct vsock_transport_send_notify_data *data)
 541{
 542        return 0;
 543}
 544EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
 545
 546u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
 547{
 548        struct virtio_vsock_sock *vvs = vsk->trans;
 549
 550        return vvs->buf_size;
 551}
 552EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
 553
 554bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
 555{
 556        return true;
 557}
 558EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
 559
 560bool virtio_transport_stream_allow(u32 cid, u32 port)
 561{
 562        return true;
 563}
 564EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
 565
 566int virtio_transport_dgram_bind(struct vsock_sock *vsk,
 567                                struct sockaddr_vm *addr)
 568{
 569        return -EOPNOTSUPP;
 570}
 571EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
 572
 573bool virtio_transport_dgram_allow(u32 cid, u32 port)
 574{
 575        return false;
 576}
 577EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
 578
 579int virtio_transport_connect(struct vsock_sock *vsk)
 580{
 581        struct virtio_vsock_pkt_info info = {
 582                .op = VIRTIO_VSOCK_OP_REQUEST,
 583                .type = VIRTIO_VSOCK_TYPE_STREAM,
 584                .vsk = vsk,
 585        };
 586
 587        return virtio_transport_send_pkt_info(vsk, &info);
 588}
 589EXPORT_SYMBOL_GPL(virtio_transport_connect);
 590
 591int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
 592{
 593        struct virtio_vsock_pkt_info info = {
 594                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
 595                .type = VIRTIO_VSOCK_TYPE_STREAM,
 596                .flags = (mode & RCV_SHUTDOWN ?
 597                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
 598                         (mode & SEND_SHUTDOWN ?
 599                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
 600                .vsk = vsk,
 601        };
 602
 603        return virtio_transport_send_pkt_info(vsk, &info);
 604}
 605EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
 606
 607int
 608virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
 609                               struct sockaddr_vm *remote_addr,
 610                               struct msghdr *msg,
 611                               size_t dgram_len)
 612{
 613        return -EOPNOTSUPP;
 614}
 615EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
 616
 617ssize_t
 618virtio_transport_stream_enqueue(struct vsock_sock *vsk,
 619                                struct msghdr *msg,
 620                                size_t len)
 621{
 622        struct virtio_vsock_pkt_info info = {
 623                .op = VIRTIO_VSOCK_OP_RW,
 624                .type = VIRTIO_VSOCK_TYPE_STREAM,
 625                .msg = msg,
 626                .pkt_len = len,
 627                .vsk = vsk,
 628        };
 629
 630        return virtio_transport_send_pkt_info(vsk, &info);
 631}
 632EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
 633
 634void virtio_transport_destruct(struct vsock_sock *vsk)
 635{
 636        struct virtio_vsock_sock *vvs = vsk->trans;
 637
 638        kfree(vvs);
 639}
 640EXPORT_SYMBOL_GPL(virtio_transport_destruct);
 641
 642static int virtio_transport_reset(struct vsock_sock *vsk,
 643                                  struct virtio_vsock_pkt *pkt)
 644{
 645        struct virtio_vsock_pkt_info info = {
 646                .op = VIRTIO_VSOCK_OP_RST,
 647                .type = VIRTIO_VSOCK_TYPE_STREAM,
 648                .reply = !!pkt,
 649                .vsk = vsk,
 650        };
 651
 652        /* Send RST only if the original pkt is not a RST pkt */
 653        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 654                return 0;
 655
 656        return virtio_transport_send_pkt_info(vsk, &info);
 657}
 658
 659/* Normally packets are associated with a socket.  There may be no socket if an
 660 * attempt was made to connect to a socket that does not exist.
 661 */
 662static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
 663{
 664        const struct virtio_transport *t;
 665        struct virtio_vsock_pkt *reply;
 666        struct virtio_vsock_pkt_info info = {
 667                .op = VIRTIO_VSOCK_OP_RST,
 668                .type = le16_to_cpu(pkt->hdr.type),
 669                .reply = true,
 670        };
 671
 672        /* Send RST only if the original pkt is not a RST pkt */
 673        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 674                return 0;
 675
 676        reply = virtio_transport_alloc_pkt(&info, 0,
 677                                           le64_to_cpu(pkt->hdr.dst_cid),
 678                                           le32_to_cpu(pkt->hdr.dst_port),
 679                                           le64_to_cpu(pkt->hdr.src_cid),
 680                                           le32_to_cpu(pkt->hdr.src_port));
 681        if (!reply)
 682                return -ENOMEM;
 683
 684        t = virtio_transport_get_ops();
 685        if (!t) {
 686                virtio_transport_free_pkt(reply);
 687                return -ENOTCONN;
 688        }
 689
 690        return t->send_pkt(reply);
 691}
 692
 693static void virtio_transport_wait_close(struct sock *sk, long timeout)
 694{
 695        if (timeout) {
 696                DEFINE_WAIT_FUNC(wait, woken_wake_function);
 697
 698                add_wait_queue(sk_sleep(sk), &wait);
 699
 700                do {
 701                        if (sk_wait_event(sk, &timeout,
 702                                          sock_flag(sk, SOCK_DONE), &wait))
 703                                break;
 704                } while (!signal_pending(current) && timeout);
 705
 706                remove_wait_queue(sk_sleep(sk), &wait);
 707        }
 708}
 709
 710static void virtio_transport_do_close(struct vsock_sock *vsk,
 711                                      bool cancel_timeout)
 712{
 713        struct sock *sk = sk_vsock(vsk);
 714
 715        sock_set_flag(sk, SOCK_DONE);
 716        vsk->peer_shutdown = SHUTDOWN_MASK;
 717        if (vsock_stream_has_data(vsk) <= 0)
 718                sk->sk_state = TCP_CLOSING;
 719        sk->sk_state_change(sk);
 720
 721        if (vsk->close_work_scheduled &&
 722            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
 723                vsk->close_work_scheduled = false;
 724
 725                vsock_remove_sock(vsk);
 726
 727                /* Release refcnt obtained when we scheduled the timeout */
 728                sock_put(sk);
 729        }
 730}
 731
 732static void virtio_transport_close_timeout(struct work_struct *work)
 733{
 734        struct vsock_sock *vsk =
 735                container_of(work, struct vsock_sock, close_work.work);
 736        struct sock *sk = sk_vsock(vsk);
 737
 738        sock_hold(sk);
 739        lock_sock(sk);
 740
 741        if (!sock_flag(sk, SOCK_DONE)) {
 742                (void)virtio_transport_reset(vsk, NULL);
 743
 744                virtio_transport_do_close(vsk, false);
 745        }
 746
 747        vsk->close_work_scheduled = false;
 748
 749        release_sock(sk);
 750        sock_put(sk);
 751}
 752
 753/* User context, vsk->sk is locked */
 754static bool virtio_transport_close(struct vsock_sock *vsk)
 755{
 756        struct sock *sk = &vsk->sk;
 757
 758        if (!(sk->sk_state == TCP_ESTABLISHED ||
 759              sk->sk_state == TCP_CLOSING))
 760                return true;
 761
 762        /* Already received SHUTDOWN from peer, reply with RST */
 763        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
 764                (void)virtio_transport_reset(vsk, NULL);
 765                return true;
 766        }
 767
 768        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
 769                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
 770
 771        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
 772                virtio_transport_wait_close(sk, sk->sk_lingertime);
 773
 774        if (sock_flag(sk, SOCK_DONE)) {
 775                return true;
 776        }
 777
 778        sock_hold(sk);
 779        INIT_DELAYED_WORK(&vsk->close_work,
 780                          virtio_transport_close_timeout);
 781        vsk->close_work_scheduled = true;
 782        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
 783        return false;
 784}
 785
 786void virtio_transport_release(struct vsock_sock *vsk)
 787{
 788        struct virtio_vsock_sock *vvs = vsk->trans;
 789        struct virtio_vsock_pkt *pkt, *tmp;
 790        struct sock *sk = &vsk->sk;
 791        bool remove_sock = true;
 792
 793        lock_sock(sk);
 794        if (sk->sk_type == SOCK_STREAM)
 795                remove_sock = virtio_transport_close(vsk);
 796
 797        list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
 798                list_del(&pkt->list);
 799                virtio_transport_free_pkt(pkt);
 800        }
 801        release_sock(sk);
 802
 803        if (remove_sock)
 804                vsock_remove_sock(vsk);
 805}
 806EXPORT_SYMBOL_GPL(virtio_transport_release);
 807
 808static int
 809virtio_transport_recv_connecting(struct sock *sk,
 810                                 struct virtio_vsock_pkt *pkt)
 811{
 812        struct vsock_sock *vsk = vsock_sk(sk);
 813        int err;
 814        int skerr;
 815
 816        switch (le16_to_cpu(pkt->hdr.op)) {
 817        case VIRTIO_VSOCK_OP_RESPONSE:
 818                sk->sk_state = TCP_ESTABLISHED;
 819                sk->sk_socket->state = SS_CONNECTED;
 820                vsock_insert_connected(vsk);
 821                sk->sk_state_change(sk);
 822                break;
 823        case VIRTIO_VSOCK_OP_INVALID:
 824                break;
 825        case VIRTIO_VSOCK_OP_RST:
 826                skerr = ECONNRESET;
 827                err = 0;
 828                goto destroy;
 829        default:
 830                skerr = EPROTO;
 831                err = -EINVAL;
 832                goto destroy;
 833        }
 834        return 0;
 835
 836destroy:
 837        virtio_transport_reset(vsk, pkt);
 838        sk->sk_state = TCP_CLOSE;
 839        sk->sk_err = skerr;
 840        sk->sk_error_report(sk);
 841        return err;
 842}
 843
 844static int
 845virtio_transport_recv_connected(struct sock *sk,
 846                                struct virtio_vsock_pkt *pkt)
 847{
 848        struct vsock_sock *vsk = vsock_sk(sk);
 849        struct virtio_vsock_sock *vvs = vsk->trans;
 850        int err = 0;
 851
 852        switch (le16_to_cpu(pkt->hdr.op)) {
 853        case VIRTIO_VSOCK_OP_RW:
 854                pkt->len = le32_to_cpu(pkt->hdr.len);
 855                pkt->off = 0;
 856
 857                spin_lock_bh(&vvs->rx_lock);
 858                virtio_transport_inc_rx_pkt(vvs, pkt);
 859                list_add_tail(&pkt->list, &vvs->rx_queue);
 860                spin_unlock_bh(&vvs->rx_lock);
 861
 862                sk->sk_data_ready(sk);
 863                return err;
 864        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
 865                sk->sk_write_space(sk);
 866                break;
 867        case VIRTIO_VSOCK_OP_SHUTDOWN:
 868                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
 869                        vsk->peer_shutdown |= RCV_SHUTDOWN;
 870                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
 871                        vsk->peer_shutdown |= SEND_SHUTDOWN;
 872                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
 873                    vsock_stream_has_data(vsk) <= 0) {
 874                        sock_set_flag(sk, SOCK_DONE);
 875                        sk->sk_state = TCP_CLOSING;
 876                }
 877                if (le32_to_cpu(pkt->hdr.flags))
 878                        sk->sk_state_change(sk);
 879                break;
 880        case VIRTIO_VSOCK_OP_RST:
 881                virtio_transport_do_close(vsk, true);
 882                break;
 883        default:
 884                err = -EINVAL;
 885                break;
 886        }
 887
 888        virtio_transport_free_pkt(pkt);
 889        return err;
 890}
 891
 892static void
 893virtio_transport_recv_disconnecting(struct sock *sk,
 894                                    struct virtio_vsock_pkt *pkt)
 895{
 896        struct vsock_sock *vsk = vsock_sk(sk);
 897
 898        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 899                virtio_transport_do_close(vsk, true);
 900}
 901
 902static int
 903virtio_transport_send_response(struct vsock_sock *vsk,
 904                               struct virtio_vsock_pkt *pkt)
 905{
 906        struct virtio_vsock_pkt_info info = {
 907                .op = VIRTIO_VSOCK_OP_RESPONSE,
 908                .type = VIRTIO_VSOCK_TYPE_STREAM,
 909                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
 910                .remote_port = le32_to_cpu(pkt->hdr.src_port),
 911                .reply = true,
 912                .vsk = vsk,
 913        };
 914
 915        return virtio_transport_send_pkt_info(vsk, &info);
 916}
 917
 918/* Handle server socket */
 919static int
 920virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
 921{
 922        struct vsock_sock *vsk = vsock_sk(sk);
 923        struct vsock_sock *vchild;
 924        struct sock *child;
 925
 926        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
 927                virtio_transport_reset(vsk, pkt);
 928                return -EINVAL;
 929        }
 930
 931        if (sk_acceptq_is_full(sk)) {
 932                virtio_transport_reset(vsk, pkt);
 933                return -ENOMEM;
 934        }
 935
 936        child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
 937                               sk->sk_type, 0);
 938        if (!child) {
 939                virtio_transport_reset(vsk, pkt);
 940                return -ENOMEM;
 941        }
 942
 943        sk->sk_ack_backlog++;
 944
 945        lock_sock_nested(child, SINGLE_DEPTH_NESTING);
 946
 947        child->sk_state = TCP_ESTABLISHED;
 948
 949        vchild = vsock_sk(child);
 950        vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
 951                        le32_to_cpu(pkt->hdr.dst_port));
 952        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
 953                        le32_to_cpu(pkt->hdr.src_port));
 954
 955        vsock_insert_connected(vchild);
 956        vsock_enqueue_accept(sk, child);
 957        virtio_transport_send_response(vchild, pkt);
 958
 959        release_sock(child);
 960
 961        sk->sk_data_ready(sk);
 962        return 0;
 963}
 964
 965static bool virtio_transport_space_update(struct sock *sk,
 966                                          struct virtio_vsock_pkt *pkt)
 967{
 968        struct vsock_sock *vsk = vsock_sk(sk);
 969        struct virtio_vsock_sock *vvs = vsk->trans;
 970        bool space_available;
 971
 972        /* buf_alloc and fwd_cnt is always included in the hdr */
 973        spin_lock_bh(&vvs->tx_lock);
 974        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
 975        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
 976        space_available = virtio_transport_has_space(vsk);
 977        spin_unlock_bh(&vvs->tx_lock);
 978        return space_available;
 979}
 980
 981/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 982 * lock.
 983 */
 984void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
 985{
 986        struct sockaddr_vm src, dst;
 987        struct vsock_sock *vsk;
 988        struct sock *sk;
 989        bool space_available;
 990
 991        vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
 992                        le32_to_cpu(pkt->hdr.src_port));
 993        vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
 994                        le32_to_cpu(pkt->hdr.dst_port));
 995
 996        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
 997                                        dst.svm_cid, dst.svm_port,
 998                                        le32_to_cpu(pkt->hdr.len),
 999                                        le16_to_cpu(pkt->hdr.type),
1000                                        le16_to_cpu(pkt->hdr.op),
1001                                        le32_to_cpu(pkt->hdr.flags),
1002                                        le32_to_cpu(pkt->hdr.buf_alloc),
1003                                        le32_to_cpu(pkt->hdr.fwd_cnt));
1004
1005        if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
1006                (void)virtio_transport_reset_no_sock(pkt);
1007                goto free_pkt;
1008        }
1009
1010        /* The socket must be in connected or bound table
1011         * otherwise send reset back
1012         */
1013        sk = vsock_find_connected_socket(&src, &dst);
1014        if (!sk) {
1015                sk = vsock_find_bound_socket(&dst);
1016                if (!sk) {
1017                        (void)virtio_transport_reset_no_sock(pkt);
1018                        goto free_pkt;
1019                }
1020        }
1021
1022        vsk = vsock_sk(sk);
1023
1024        space_available = virtio_transport_space_update(sk, pkt);
1025
1026        lock_sock(sk);
1027
1028        /* Update CID in case it has changed after a transport reset event */
1029        vsk->local_addr.svm_cid = dst.svm_cid;
1030
1031        if (space_available)
1032                sk->sk_write_space(sk);
1033
1034        switch (sk->sk_state) {
1035        case TCP_LISTEN:
1036                virtio_transport_recv_listen(sk, pkt);
1037                virtio_transport_free_pkt(pkt);
1038                break;
1039        case TCP_SYN_SENT:
1040                virtio_transport_recv_connecting(sk, pkt);
1041                virtio_transport_free_pkt(pkt);
1042                break;
1043        case TCP_ESTABLISHED:
1044                virtio_transport_recv_connected(sk, pkt);
1045                break;
1046        case TCP_CLOSING:
1047                virtio_transport_recv_disconnecting(sk, pkt);
1048                virtio_transport_free_pkt(pkt);
1049                break;
1050        default:
1051                virtio_transport_free_pkt(pkt);
1052                break;
1053        }
1054        release_sock(sk);
1055
1056        /* Release refcnt obtained when we fetched this socket out of the
1057         * bound or connected list.
1058         */
1059        sock_put(sk);
1060        return;
1061
1062free_pkt:
1063        virtio_transport_free_pkt(pkt);
1064}
1065EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1066
1067void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1068{
1069        kfree(pkt->buf);
1070        kfree(pkt);
1071}
1072EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1073
1074MODULE_LICENSE("GPL v2");
1075MODULE_AUTHOR("Asias He");
1076MODULE_DESCRIPTION("common code for virtio vsock");
1077