linux/net/vmw_vsock/virtio_transport_common.c
<<
>>
Prefs
   1/*
   2 * common code for virtio vsock
   3 *
   4 * Copyright (C) 2013-2015 Red Hat, Inc.
   5 * Author: Asias He <asias@redhat.com>
   6 *         Stefan Hajnoczi <stefanha@redhat.com>
   7 *
   8 * This work is licensed under the terms of the GNU GPL, version 2.
   9 */
  10#include <linux/spinlock.h>
  11#include <linux/module.h>
  12#include <linux/ctype.h>
  13#include <linux/list.h>
  14#include <linux/virtio.h>
  15#include <linux/virtio_ids.h>
  16#include <linux/virtio_config.h>
  17#include <linux/virtio_vsock.h>
  18
  19#include <net/sock.h>
  20#include <net/af_vsock.h>
  21
  22#define CREATE_TRACE_POINTS
  23#include <trace/events/vsock_virtio_transport_common.h>
  24
  25/* How long to wait for graceful shutdown of a connection */
  26#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
  27
  28static const struct virtio_transport *virtio_transport_get_ops(void)
  29{
  30        const struct vsock_transport *t = vsock_core_get_transport();
  31
  32        return container_of(t, struct virtio_transport, transport);
  33}
  34
  35struct virtio_vsock_pkt *
  36virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
  37                           size_t len,
  38                           u32 src_cid,
  39                           u32 src_port,
  40                           u32 dst_cid,
  41                           u32 dst_port)
  42{
  43        struct virtio_vsock_pkt *pkt;
  44        int err;
  45
  46        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
  47        if (!pkt)
  48                return NULL;
  49
  50        pkt->hdr.type           = cpu_to_le16(info->type);
  51        pkt->hdr.op             = cpu_to_le16(info->op);
  52        pkt->hdr.src_cid        = cpu_to_le64(src_cid);
  53        pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
  54        pkt->hdr.src_port       = cpu_to_le32(src_port);
  55        pkt->hdr.dst_port       = cpu_to_le32(dst_port);
  56        pkt->hdr.flags          = cpu_to_le32(info->flags);
  57        pkt->len                = len;
  58        pkt->hdr.len            = cpu_to_le32(len);
  59        pkt->reply              = info->reply;
  60
  61        if (info->msg && len > 0) {
  62                pkt->buf = kmalloc(len, GFP_KERNEL);
  63                if (!pkt->buf)
  64                        goto out_pkt;
  65                err = memcpy_from_msg(pkt->buf, info->msg, len);
  66                if (err)
  67                        goto out;
  68        }
  69
  70        trace_virtio_transport_alloc_pkt(src_cid, src_port,
  71                                         dst_cid, dst_port,
  72                                         len,
  73                                         info->type,
  74                                         info->op,
  75                                         info->flags);
  76
  77        return pkt;
  78
  79out:
  80        kfree(pkt->buf);
  81out_pkt:
  82        kfree(pkt);
  83        return NULL;
  84}
  85EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
  86
  87static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
  88                                          struct virtio_vsock_pkt_info *info)
  89{
  90        u32 src_cid, src_port, dst_cid, dst_port;
  91        struct virtio_vsock_sock *vvs;
  92        struct virtio_vsock_pkt *pkt;
  93        u32 pkt_len = info->pkt_len;
  94
  95        src_cid = vm_sockets_get_local_cid();
  96        src_port = vsk->local_addr.svm_port;
  97        if (!info->remote_cid) {
  98                dst_cid = vsk->remote_addr.svm_cid;
  99                dst_port = vsk->remote_addr.svm_port;
 100        } else {
 101                dst_cid = info->remote_cid;
 102                dst_port = info->remote_port;
 103        }
 104
 105        vvs = vsk->trans;
 106
 107        /* we can send less than pkt_len bytes */
 108        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
 109                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
 110
 111        /* virtio_transport_get_credit might return less than pkt_len credit */
 112        pkt_len = virtio_transport_get_credit(vvs, pkt_len);
 113
 114        /* Do not send zero length OP_RW pkt */
 115        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
 116                return pkt_len;
 117
 118        pkt = virtio_transport_alloc_pkt(info, pkt_len,
 119                                         src_cid, src_port,
 120                                         dst_cid, dst_port);
 121        if (!pkt) {
 122                virtio_transport_put_credit(vvs, pkt_len);
 123                return -ENOMEM;
 124        }
 125
 126        virtio_transport_inc_tx_pkt(vvs, pkt);
 127
 128        return virtio_transport_get_ops()->send_pkt(pkt);
 129}
 130
 131static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
 132                                        struct virtio_vsock_pkt *pkt)
 133{
 134        vvs->rx_bytes += pkt->len;
 135}
 136
 137static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
 138                                        struct virtio_vsock_pkt *pkt)
 139{
 140        vvs->rx_bytes -= pkt->len;
 141        vvs->fwd_cnt += pkt->len;
 142}
 143
 144void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
 145{
 146        spin_lock_bh(&vvs->tx_lock);
 147        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
 148        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
 149        spin_unlock_bh(&vvs->tx_lock);
 150}
 151EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
 152
 153u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
 154{
 155        u32 ret;
 156
 157        spin_lock_bh(&vvs->tx_lock);
 158        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
 159        if (ret > credit)
 160                ret = credit;
 161        vvs->tx_cnt += ret;
 162        spin_unlock_bh(&vvs->tx_lock);
 163
 164        return ret;
 165}
 166EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
 167
 168void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
 169{
 170        spin_lock_bh(&vvs->tx_lock);
 171        vvs->tx_cnt -= credit;
 172        spin_unlock_bh(&vvs->tx_lock);
 173}
 174EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
 175
 176static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
 177                                               int type,
 178                                               struct virtio_vsock_hdr *hdr)
 179{
 180        struct virtio_vsock_pkt_info info = {
 181                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
 182                .type = type,
 183        };
 184
 185        return virtio_transport_send_pkt_info(vsk, &info);
 186}
 187
 188static ssize_t
 189virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
 190                                   struct msghdr *msg,
 191                                   size_t len)
 192{
 193        struct virtio_vsock_sock *vvs = vsk->trans;
 194        struct virtio_vsock_pkt *pkt;
 195        size_t bytes, total = 0;
 196        int err = -EFAULT;
 197
 198        spin_lock_bh(&vvs->rx_lock);
 199        while (total < len && !list_empty(&vvs->rx_queue)) {
 200                pkt = list_first_entry(&vvs->rx_queue,
 201                                       struct virtio_vsock_pkt, list);
 202
 203                bytes = len - total;
 204                if (bytes > pkt->len - pkt->off)
 205                        bytes = pkt->len - pkt->off;
 206
 207                /* sk_lock is held by caller so no one else can dequeue.
 208                 * Unlock rx_lock since memcpy_to_msg() may sleep.
 209                 */
 210                spin_unlock_bh(&vvs->rx_lock);
 211
 212                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
 213                if (err)
 214                        goto out;
 215
 216                spin_lock_bh(&vvs->rx_lock);
 217
 218                total += bytes;
 219                pkt->off += bytes;
 220                if (pkt->off == pkt->len) {
 221                        virtio_transport_dec_rx_pkt(vvs, pkt);
 222                        list_del(&pkt->list);
 223                        virtio_transport_free_pkt(pkt);
 224                }
 225        }
 226        spin_unlock_bh(&vvs->rx_lock);
 227
 228        /* Send a credit pkt to peer */
 229        virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
 230                                            NULL);
 231
 232        return total;
 233
 234out:
 235        if (total)
 236                err = total;
 237        return err;
 238}
 239
 240ssize_t
 241virtio_transport_stream_dequeue(struct vsock_sock *vsk,
 242                                struct msghdr *msg,
 243                                size_t len, int flags)
 244{
 245        if (flags & MSG_PEEK)
 246                return -EOPNOTSUPP;
 247
 248        return virtio_transport_stream_do_dequeue(vsk, msg, len);
 249}
 250EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
 251
 252int
 253virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
 254                               struct msghdr *msg,
 255                               size_t len, int flags)
 256{
 257        return -EOPNOTSUPP;
 258}
 259EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
 260
 261s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
 262{
 263        struct virtio_vsock_sock *vvs = vsk->trans;
 264        s64 bytes;
 265
 266        spin_lock_bh(&vvs->rx_lock);
 267        bytes = vvs->rx_bytes;
 268        spin_unlock_bh(&vvs->rx_lock);
 269
 270        return bytes;
 271}
 272EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
 273
 274static s64 virtio_transport_has_space(struct vsock_sock *vsk)
 275{
 276        struct virtio_vsock_sock *vvs = vsk->trans;
 277        s64 bytes;
 278
 279        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
 280        if (bytes < 0)
 281                bytes = 0;
 282
 283        return bytes;
 284}
 285
 286s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
 287{
 288        struct virtio_vsock_sock *vvs = vsk->trans;
 289        s64 bytes;
 290
 291        spin_lock_bh(&vvs->tx_lock);
 292        bytes = virtio_transport_has_space(vsk);
 293        spin_unlock_bh(&vvs->tx_lock);
 294
 295        return bytes;
 296}
 297EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
 298
 299int virtio_transport_do_socket_init(struct vsock_sock *vsk,
 300                                    struct vsock_sock *psk)
 301{
 302        struct virtio_vsock_sock *vvs;
 303
 304        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
 305        if (!vvs)
 306                return -ENOMEM;
 307
 308        vsk->trans = vvs;
 309        vvs->vsk = vsk;
 310        if (psk) {
 311                struct virtio_vsock_sock *ptrans = psk->trans;
 312
 313                vvs->buf_size   = ptrans->buf_size;
 314                vvs->buf_size_min = ptrans->buf_size_min;
 315                vvs->buf_size_max = ptrans->buf_size_max;
 316                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
 317        } else {
 318                vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
 319                vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
 320                vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
 321        }
 322
 323        vvs->buf_alloc = vvs->buf_size;
 324
 325        spin_lock_init(&vvs->rx_lock);
 326        spin_lock_init(&vvs->tx_lock);
 327        INIT_LIST_HEAD(&vvs->rx_queue);
 328
 329        return 0;
 330}
 331EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
 332
 333u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
 334{
 335        struct virtio_vsock_sock *vvs = vsk->trans;
 336
 337        return vvs->buf_size;
 338}
 339EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
 340
 341u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
 342{
 343        struct virtio_vsock_sock *vvs = vsk->trans;
 344
 345        return vvs->buf_size_min;
 346}
 347EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
 348
 349u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
 350{
 351        struct virtio_vsock_sock *vvs = vsk->trans;
 352
 353        return vvs->buf_size_max;
 354}
 355EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
 356
 357void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
 358{
 359        struct virtio_vsock_sock *vvs = vsk->trans;
 360
 361        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 362                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 363        if (val < vvs->buf_size_min)
 364                vvs->buf_size_min = val;
 365        if (val > vvs->buf_size_max)
 366                vvs->buf_size_max = val;
 367        vvs->buf_size = val;
 368        vvs->buf_alloc = val;
 369}
 370EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
 371
 372void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
 373{
 374        struct virtio_vsock_sock *vvs = vsk->trans;
 375
 376        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 377                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 378        if (val > vvs->buf_size)
 379                vvs->buf_size = val;
 380        vvs->buf_size_min = val;
 381}
 382EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
 383
 384void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
 385{
 386        struct virtio_vsock_sock *vvs = vsk->trans;
 387
 388        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
 389                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
 390        if (val < vvs->buf_size)
 391                vvs->buf_size = val;
 392        vvs->buf_size_max = val;
 393}
 394EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
 395
 396int
 397virtio_transport_notify_poll_in(struct vsock_sock *vsk,
 398                                size_t target,
 399                                bool *data_ready_now)
 400{
 401        if (vsock_stream_has_data(vsk))
 402                *data_ready_now = true;
 403        else
 404                *data_ready_now = false;
 405
 406        return 0;
 407}
 408EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
 409
 410int
 411virtio_transport_notify_poll_out(struct vsock_sock *vsk,
 412                                 size_t target,
 413                                 bool *space_avail_now)
 414{
 415        s64 free_space;
 416
 417        free_space = vsock_stream_has_space(vsk);
 418        if (free_space > 0)
 419                *space_avail_now = true;
 420        else if (free_space == 0)
 421                *space_avail_now = false;
 422
 423        return 0;
 424}
 425EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
 426
 427int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
 428        size_t target, struct vsock_transport_recv_notify_data *data)
 429{
 430        return 0;
 431}
 432EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
 433
 434int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
 435        size_t target, struct vsock_transport_recv_notify_data *data)
 436{
 437        return 0;
 438}
 439EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
 440
 441int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
 442        size_t target, struct vsock_transport_recv_notify_data *data)
 443{
 444        return 0;
 445}
 446EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
 447
 448int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
 449        size_t target, ssize_t copied, bool data_read,
 450        struct vsock_transport_recv_notify_data *data)
 451{
 452        return 0;
 453}
 454EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
 455
 456int virtio_transport_notify_send_init(struct vsock_sock *vsk,
 457        struct vsock_transport_send_notify_data *data)
 458{
 459        return 0;
 460}
 461EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
 462
 463int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
 464        struct vsock_transport_send_notify_data *data)
 465{
 466        return 0;
 467}
 468EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
 469
 470int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
 471        struct vsock_transport_send_notify_data *data)
 472{
 473        return 0;
 474}
 475EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
 476
 477int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
 478        ssize_t written, struct vsock_transport_send_notify_data *data)
 479{
 480        return 0;
 481}
 482EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
 483
 484u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
 485{
 486        struct virtio_vsock_sock *vvs = vsk->trans;
 487
 488        return vvs->buf_size;
 489}
 490EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
 491
 492bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
 493{
 494        return true;
 495}
 496EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
 497
 498bool virtio_transport_stream_allow(u32 cid, u32 port)
 499{
 500        return true;
 501}
 502EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
 503
 504int virtio_transport_dgram_bind(struct vsock_sock *vsk,
 505                                struct sockaddr_vm *addr)
 506{
 507        return -EOPNOTSUPP;
 508}
 509EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
 510
 511bool virtio_transport_dgram_allow(u32 cid, u32 port)
 512{
 513        return false;
 514}
 515EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
 516
 517int virtio_transport_connect(struct vsock_sock *vsk)
 518{
 519        struct virtio_vsock_pkt_info info = {
 520                .op = VIRTIO_VSOCK_OP_REQUEST,
 521                .type = VIRTIO_VSOCK_TYPE_STREAM,
 522        };
 523
 524        return virtio_transport_send_pkt_info(vsk, &info);
 525}
 526EXPORT_SYMBOL_GPL(virtio_transport_connect);
 527
 528int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
 529{
 530        struct virtio_vsock_pkt_info info = {
 531                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
 532                .type = VIRTIO_VSOCK_TYPE_STREAM,
 533                .flags = (mode & RCV_SHUTDOWN ?
 534                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
 535                         (mode & SEND_SHUTDOWN ?
 536                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
 537        };
 538
 539        return virtio_transport_send_pkt_info(vsk, &info);
 540}
 541EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
 542
 543int
 544virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
 545                               struct sockaddr_vm *remote_addr,
 546                               struct msghdr *msg,
 547                               size_t dgram_len)
 548{
 549        return -EOPNOTSUPP;
 550}
 551EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
 552
 553ssize_t
 554virtio_transport_stream_enqueue(struct vsock_sock *vsk,
 555                                struct msghdr *msg,
 556                                size_t len)
 557{
 558        struct virtio_vsock_pkt_info info = {
 559                .op = VIRTIO_VSOCK_OP_RW,
 560                .type = VIRTIO_VSOCK_TYPE_STREAM,
 561                .msg = msg,
 562                .pkt_len = len,
 563        };
 564
 565        return virtio_transport_send_pkt_info(vsk, &info);
 566}
 567EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
 568
 569void virtio_transport_destruct(struct vsock_sock *vsk)
 570{
 571        struct virtio_vsock_sock *vvs = vsk->trans;
 572
 573        kfree(vvs);
 574}
 575EXPORT_SYMBOL_GPL(virtio_transport_destruct);
 576
 577static int virtio_transport_reset(struct vsock_sock *vsk,
 578                                  struct virtio_vsock_pkt *pkt)
 579{
 580        struct virtio_vsock_pkt_info info = {
 581                .op = VIRTIO_VSOCK_OP_RST,
 582                .type = VIRTIO_VSOCK_TYPE_STREAM,
 583                .reply = !!pkt,
 584        };
 585
 586        /* Send RST only if the original pkt is not a RST pkt */
 587        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 588                return 0;
 589
 590        return virtio_transport_send_pkt_info(vsk, &info);
 591}
 592
 593/* Normally packets are associated with a socket.  There may be no socket if an
 594 * attempt was made to connect to a socket that does not exist.
 595 */
 596static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
 597{
 598        struct virtio_vsock_pkt_info info = {
 599                .op = VIRTIO_VSOCK_OP_RST,
 600                .type = le16_to_cpu(pkt->hdr.type),
 601                .reply = true,
 602        };
 603
 604        /* Send RST only if the original pkt is not a RST pkt */
 605        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 606                return 0;
 607
 608        pkt = virtio_transport_alloc_pkt(&info, 0,
 609                                         le32_to_cpu(pkt->hdr.dst_cid),
 610                                         le32_to_cpu(pkt->hdr.dst_port),
 611                                         le32_to_cpu(pkt->hdr.src_cid),
 612                                         le32_to_cpu(pkt->hdr.src_port));
 613        if (!pkt)
 614                return -ENOMEM;
 615
 616        return virtio_transport_get_ops()->send_pkt(pkt);
 617}
 618
 619static void virtio_transport_wait_close(struct sock *sk, long timeout)
 620{
 621        if (timeout) {
 622                DEFINE_WAIT(wait);
 623
 624                do {
 625                        prepare_to_wait(sk_sleep(sk), &wait,
 626                                        TASK_INTERRUPTIBLE);
 627                        if (sk_wait_event(sk, &timeout,
 628                                          sock_flag(sk, SOCK_DONE)))
 629                                break;
 630                } while (!signal_pending(current) && timeout);
 631
 632                finish_wait(sk_sleep(sk), &wait);
 633        }
 634}
 635
 636static void virtio_transport_do_close(struct vsock_sock *vsk,
 637                                      bool cancel_timeout)
 638{
 639        struct sock *sk = sk_vsock(vsk);
 640
 641        sock_set_flag(sk, SOCK_DONE);
 642        vsk->peer_shutdown = SHUTDOWN_MASK;
 643        if (vsock_stream_has_data(vsk) <= 0)
 644                sk->sk_state = SS_DISCONNECTING;
 645        sk->sk_state_change(sk);
 646
 647        if (vsk->close_work_scheduled &&
 648            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
 649                vsk->close_work_scheduled = false;
 650
 651                vsock_remove_sock(vsk);
 652
 653                /* Release refcnt obtained when we scheduled the timeout */
 654                sock_put(sk);
 655        }
 656}
 657
 658static void virtio_transport_close_timeout(struct work_struct *work)
 659{
 660        struct vsock_sock *vsk =
 661                container_of(work, struct vsock_sock, close_work.work);
 662        struct sock *sk = sk_vsock(vsk);
 663
 664        sock_hold(sk);
 665        lock_sock(sk);
 666
 667        if (!sock_flag(sk, SOCK_DONE)) {
 668                (void)virtio_transport_reset(vsk, NULL);
 669
 670                virtio_transport_do_close(vsk, false);
 671        }
 672
 673        vsk->close_work_scheduled = false;
 674
 675        release_sock(sk);
 676        sock_put(sk);
 677}
 678
 679/* User context, vsk->sk is locked */
 680static bool virtio_transport_close(struct vsock_sock *vsk)
 681{
 682        struct sock *sk = &vsk->sk;
 683
 684        if (!(sk->sk_state == SS_CONNECTED ||
 685              sk->sk_state == SS_DISCONNECTING))
 686                return true;
 687
 688        /* Already received SHUTDOWN from peer, reply with RST */
 689        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
 690                (void)virtio_transport_reset(vsk, NULL);
 691                return true;
 692        }
 693
 694        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
 695                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
 696
 697        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
 698                virtio_transport_wait_close(sk, sk->sk_lingertime);
 699
 700        if (sock_flag(sk, SOCK_DONE)) {
 701                return true;
 702        }
 703
 704        sock_hold(sk);
 705        INIT_DELAYED_WORK(&vsk->close_work,
 706                          virtio_transport_close_timeout);
 707        vsk->close_work_scheduled = true;
 708        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
 709        return false;
 710}
 711
 712void virtio_transport_release(struct vsock_sock *vsk)
 713{
 714        struct sock *sk = &vsk->sk;
 715        bool remove_sock = true;
 716
 717        lock_sock(sk);
 718        if (sk->sk_type == SOCK_STREAM)
 719                remove_sock = virtio_transport_close(vsk);
 720        release_sock(sk);
 721
 722        if (remove_sock)
 723                vsock_remove_sock(vsk);
 724}
 725EXPORT_SYMBOL_GPL(virtio_transport_release);
 726
 727static int
 728virtio_transport_recv_connecting(struct sock *sk,
 729                                 struct virtio_vsock_pkt *pkt)
 730{
 731        struct vsock_sock *vsk = vsock_sk(sk);
 732        int err;
 733        int skerr;
 734
 735        switch (le16_to_cpu(pkt->hdr.op)) {
 736        case VIRTIO_VSOCK_OP_RESPONSE:
 737                sk->sk_state = SS_CONNECTED;
 738                sk->sk_socket->state = SS_CONNECTED;
 739                vsock_insert_connected(vsk);
 740                sk->sk_state_change(sk);
 741                break;
 742        case VIRTIO_VSOCK_OP_INVALID:
 743                break;
 744        case VIRTIO_VSOCK_OP_RST:
 745                skerr = ECONNRESET;
 746                err = 0;
 747                goto destroy;
 748        default:
 749                skerr = EPROTO;
 750                err = -EINVAL;
 751                goto destroy;
 752        }
 753        return 0;
 754
 755destroy:
 756        virtio_transport_reset(vsk, pkt);
 757        sk->sk_state = SS_UNCONNECTED;
 758        sk->sk_err = skerr;
 759        sk->sk_error_report(sk);
 760        return err;
 761}
 762
 763static int
 764virtio_transport_recv_connected(struct sock *sk,
 765                                struct virtio_vsock_pkt *pkt)
 766{
 767        struct vsock_sock *vsk = vsock_sk(sk);
 768        struct virtio_vsock_sock *vvs = vsk->trans;
 769        int err = 0;
 770
 771        switch (le16_to_cpu(pkt->hdr.op)) {
 772        case VIRTIO_VSOCK_OP_RW:
 773                pkt->len = le32_to_cpu(pkt->hdr.len);
 774                pkt->off = 0;
 775
 776                spin_lock_bh(&vvs->rx_lock);
 777                virtio_transport_inc_rx_pkt(vvs, pkt);
 778                list_add_tail(&pkt->list, &vvs->rx_queue);
 779                spin_unlock_bh(&vvs->rx_lock);
 780
 781                sk->sk_data_ready(sk);
 782                return err;
 783        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
 784                sk->sk_write_space(sk);
 785                break;
 786        case VIRTIO_VSOCK_OP_SHUTDOWN:
 787                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
 788                        vsk->peer_shutdown |= RCV_SHUTDOWN;
 789                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
 790                        vsk->peer_shutdown |= SEND_SHUTDOWN;
 791                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
 792                    vsock_stream_has_data(vsk) <= 0)
 793                        sk->sk_state = SS_DISCONNECTING;
 794                if (le32_to_cpu(pkt->hdr.flags))
 795                        sk->sk_state_change(sk);
 796                break;
 797        case VIRTIO_VSOCK_OP_RST:
 798                virtio_transport_do_close(vsk, true);
 799                break;
 800        default:
 801                err = -EINVAL;
 802                break;
 803        }
 804
 805        virtio_transport_free_pkt(pkt);
 806        return err;
 807}
 808
 809static void
 810virtio_transport_recv_disconnecting(struct sock *sk,
 811                                    struct virtio_vsock_pkt *pkt)
 812{
 813        struct vsock_sock *vsk = vsock_sk(sk);
 814
 815        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
 816                virtio_transport_do_close(vsk, true);
 817}
 818
 819static int
 820virtio_transport_send_response(struct vsock_sock *vsk,
 821                               struct virtio_vsock_pkt *pkt)
 822{
 823        struct virtio_vsock_pkt_info info = {
 824                .op = VIRTIO_VSOCK_OP_RESPONSE,
 825                .type = VIRTIO_VSOCK_TYPE_STREAM,
 826                .remote_cid = le32_to_cpu(pkt->hdr.src_cid),
 827                .remote_port = le32_to_cpu(pkt->hdr.src_port),
 828                .reply = true,
 829        };
 830
 831        return virtio_transport_send_pkt_info(vsk, &info);
 832}
 833
 834/* Handle server socket */
 835static int
 836virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
 837{
 838        struct vsock_sock *vsk = vsock_sk(sk);
 839        struct vsock_sock *vchild;
 840        struct sock *child;
 841
 842        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
 843                virtio_transport_reset(vsk, pkt);
 844                return -EINVAL;
 845        }
 846
 847        if (sk_acceptq_is_full(sk)) {
 848                virtio_transport_reset(vsk, pkt);
 849                return -ENOMEM;
 850        }
 851
 852        child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
 853                               sk->sk_type, 0);
 854        if (!child) {
 855                virtio_transport_reset(vsk, pkt);
 856                return -ENOMEM;
 857        }
 858
 859        sk->sk_ack_backlog++;
 860
 861        lock_sock_nested(child, SINGLE_DEPTH_NESTING);
 862
 863        child->sk_state = SS_CONNECTED;
 864
 865        vchild = vsock_sk(child);
 866        vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid),
 867                        le32_to_cpu(pkt->hdr.dst_port));
 868        vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid),
 869                        le32_to_cpu(pkt->hdr.src_port));
 870
 871        vsock_insert_connected(vchild);
 872        vsock_enqueue_accept(sk, child);
 873        virtio_transport_send_response(vchild, pkt);
 874
 875        release_sock(child);
 876
 877        sk->sk_data_ready(sk);
 878        return 0;
 879}
 880
 881static bool virtio_transport_space_update(struct sock *sk,
 882                                          struct virtio_vsock_pkt *pkt)
 883{
 884        struct vsock_sock *vsk = vsock_sk(sk);
 885        struct virtio_vsock_sock *vvs = vsk->trans;
 886        bool space_available;
 887
 888        /* buf_alloc and fwd_cnt is always included in the hdr */
 889        spin_lock_bh(&vvs->tx_lock);
 890        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
 891        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
 892        space_available = virtio_transport_has_space(vsk);
 893        spin_unlock_bh(&vvs->tx_lock);
 894        return space_available;
 895}
 896
 897/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 898 * lock.
 899 */
 900void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
 901{
 902        struct sockaddr_vm src, dst;
 903        struct vsock_sock *vsk;
 904        struct sock *sk;
 905        bool space_available;
 906
 907        vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid),
 908                        le32_to_cpu(pkt->hdr.src_port));
 909        vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid),
 910                        le32_to_cpu(pkt->hdr.dst_port));
 911
 912        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
 913                                        dst.svm_cid, dst.svm_port,
 914                                        le32_to_cpu(pkt->hdr.len),
 915                                        le16_to_cpu(pkt->hdr.type),
 916                                        le16_to_cpu(pkt->hdr.op),
 917                                        le32_to_cpu(pkt->hdr.flags),
 918                                        le32_to_cpu(pkt->hdr.buf_alloc),
 919                                        le32_to_cpu(pkt->hdr.fwd_cnt));
 920
 921        if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
 922                (void)virtio_transport_reset_no_sock(pkt);
 923                goto free_pkt;
 924        }
 925
 926        /* The socket must be in connected or bound table
 927         * otherwise send reset back
 928         */
 929        sk = vsock_find_connected_socket(&src, &dst);
 930        if (!sk) {
 931                sk = vsock_find_bound_socket(&dst);
 932                if (!sk) {
 933                        (void)virtio_transport_reset_no_sock(pkt);
 934                        goto free_pkt;
 935                }
 936        }
 937
 938        vsk = vsock_sk(sk);
 939
 940        space_available = virtio_transport_space_update(sk, pkt);
 941
 942        lock_sock(sk);
 943
 944        /* Update CID in case it has changed after a transport reset event */
 945        vsk->local_addr.svm_cid = dst.svm_cid;
 946
 947        if (space_available)
 948                sk->sk_write_space(sk);
 949
 950        switch (sk->sk_state) {
 951        case VSOCK_SS_LISTEN:
 952                virtio_transport_recv_listen(sk, pkt);
 953                virtio_transport_free_pkt(pkt);
 954                break;
 955        case SS_CONNECTING:
 956                virtio_transport_recv_connecting(sk, pkt);
 957                virtio_transport_free_pkt(pkt);
 958                break;
 959        case SS_CONNECTED:
 960                virtio_transport_recv_connected(sk, pkt);
 961                break;
 962        case SS_DISCONNECTING:
 963                virtio_transport_recv_disconnecting(sk, pkt);
 964                virtio_transport_free_pkt(pkt);
 965                break;
 966        default:
 967                virtio_transport_free_pkt(pkt);
 968                break;
 969        }
 970        release_sock(sk);
 971
 972        /* Release refcnt obtained when we fetched this socket out of the
 973         * bound or connected list.
 974         */
 975        sock_put(sk);
 976        return;
 977
 978free_pkt:
 979        virtio_transport_free_pkt(pkt);
 980}
 981EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
 982
 983void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
 984{
 985        kfree(pkt->buf);
 986        kfree(pkt);
 987}
 988EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
 989
 990MODULE_LICENSE("GPL v2");
 991MODULE_AUTHOR("Asias He");
 992MODULE_DESCRIPTION("common code for virtio vsock");
 993