linux/drivers/vhost/vsock.c
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

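/* CID 2 is VMADDR_CID_HOST, the well-known address of the host itself */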
#define VHOST_VSOCK_DEFAULT_HOST_CID    2

enum {
        VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
        struct vhost_dev dev;
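        /* Indexed by VSOCK_VQ_TX/VSOCK_VQ_RX.  The names follow the guest's
         * point of view, so the host receives on VSOCK_VQ_TX and transmits
         * on VSOCK_VQ_RX.
         */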
        struct vhost_virtqueue vqs[2];

        /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
        struct hlist_node hash;

        struct vhost_work send_pkt_work;
        spinlock_t send_pkt_list_lock;
        struct list_head send_pkt_list; /* host->guest pending packets */

        atomic_t queued_replies;

        u32 guest_cid;
};

static u32 vhost_transport_get_local_cid(void)
{
        return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
        struct vhost_vsock *vsock;

        hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
                u32 other_cid = vsock->guest_cid;

                /* Skip instances that have no CID yet */
                if (other_cid == 0)
                        continue;

                if (other_cid == guest_cid)
                        return vsock;
        }

        return NULL;
}

static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                            struct vhost_virtqueue *vq)
{
        struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
        bool added = false;
        bool restart_tx = false;

        mutex_lock(&vq->mutex);

        if (!vq->private_data)
                goto out;

        /* Avoid further vmexits, we're already processing the virtqueue */
        vhost_disable_notify(&vsock->dev, vq);

        for (;;) {
                struct virtio_vsock_pkt *pkt;
                struct iov_iter iov_iter;
                unsigned out, in;
                size_t nbytes;
                size_t len;
                int head;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        vhost_enable_notify(&vsock->dev, vq);
                        break;
                }

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                spin_unlock_bh(&vsock->send_pkt_list_lock);

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

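                /* head == vq->num means no buffer was available */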
                if (head == vq->num) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);

                        /* We cannot finish yet if more buffers snuck in while
                         * re-enabling notify.
                         */
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                if (out) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Expected 0 output buffers, got %u\n", out);
                        break;
                }

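                /* Copy the header, then the payload, into the guest's
                 * writable ("in") buffers.
                 */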
                len = iov_length(&vq->iov[out], in);
                iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);

                nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
                if (nbytes != sizeof(pkt->hdr)) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt hdr\n");
                        break;
                }

                nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
                if (nbytes != pkt->len) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt buf\n");
                        break;
                }

                vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
                added = true;

                if (pkt->reply) {
                        int val;

                        val = atomic_dec_return(&vsock->queued_replies);

                        /* Do we have resources to resume tx processing? */
                        if (val + 1 == tx_vq->num)
                                restart_tx = true;
                }

                /* Deliver to monitoring devices all correctly transmitted
                 * packets.
                 */
                virtio_transport_deliver_tap_pkt(pkt);

                virtio_transport_free_pkt(pkt);
        }
        if (added)
                vhost_signal(&vsock->dev, vq);

out:
        mutex_unlock(&vq->mutex);

        if (restart_tx)
                vhost_poll_queue(&tx_vq->poll);
}

static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
        struct vhost_virtqueue *vq;
        struct vhost_vsock *vsock;

        vsock = container_of(work, struct vhost_vsock, send_pkt_work);
        vq = &vsock->vqs[VSOCK_VQ_RX];

        vhost_transport_do_send_pkt(vsock, vq);
}

static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
        struct vhost_vsock *vsock;
        int len = pkt->len;

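        /* Pin the destination instance; it may be released concurrently */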
        rcu_read_lock();

        /* Find the vhost_vsock according to the guest context ID */
        vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
        if (!vsock) {
                rcu_read_unlock();
                virtio_transport_free_pkt(pkt);
                return -ENODEV;
        }

        if (pkt->reply)
                atomic_inc(&vsock->queued_replies);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_add_tail(&pkt->list, &vsock->send_pkt_list);
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

        rcu_read_unlock();
        return len;
}

static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
        struct vhost_vsock *vsock;
        struct virtio_vsock_pkt *pkt, *n;
        int cnt = 0;
        int ret = -ENODEV;
        LIST_HEAD(freeme);

        rcu_read_lock();

        /* Find the vhost_vsock according to the guest context ID */
        vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
        if (!vsock)
                goto out;

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
                if (pkt->vsk != vsk)
                        continue;
                list_move(&pkt->list, &freeme);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        list_for_each_entry_safe(pkt, n, &freeme, list) {
                if (pkt->reply)
                        cnt++;
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }

        if (cnt) {
                struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
                int new_cnt;

                new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
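                /* If the reply count just dropped back below the vring size,
                 * tx processing was throttled and must be kicked to resume.
                 */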
                if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
                        vhost_poll_queue(&tx_vq->poll);
        }

        ret = 0;
out:
        rcu_read_unlock();
        return ret;
}

static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
                      unsigned int out, unsigned int in)
{
        struct virtio_vsock_pkt *pkt;
        struct iov_iter iov_iter;
        size_t nbytes;
        size_t len;

        if (in != 0) {
                vq_err(vq, "Expected 0 input buffers, got %u\n", in);
                return NULL;
        }

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        len = iov_length(vq->iov, out);
        iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

        nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
        if (nbytes != sizeof(pkt->hdr)) {
                vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
                       sizeof(pkt->hdr), nbytes);
                kfree(pkt);
                return NULL;
        }

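        /* Only stream packets carry a payload; for other types pkt->len
         * stays 0 from kzalloc().
         */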
        if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
                pkt->len = le32_to_cpu(pkt->hdr.len);

        /* No payload */
        if (!pkt->len)
                return pkt;

        /* The pkt is too big */
        if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
                kfree(pkt);
                return NULL;
        }

        pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
        if (!pkt->buf) {
                kfree(pkt);
                return NULL;
        }

        nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
        if (nbytes != pkt->len) {
                vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
                       pkt->len, nbytes);
                virtio_transport_free_pkt(pkt);
                return NULL;
        }

        return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
        struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
        int val;

        smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
        val = atomic_read(&vsock->queued_replies);

        return val < vq->num;
}

static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);
        struct virtio_vsock_pkt *pkt;
        int head;
        unsigned int out, in;
        bool added = false;

        mutex_lock(&vq->mutex);

        if (!vq->private_data)
                goto out;

        vhost_disable_notify(&vsock->dev, vq);
        for (;;) {
                u32 len;

                if (!vhost_vsock_more_replies(vsock)) {
                        /* Stop tx until the device processes already
                         * pending replies.  Leave tx virtqueue
                         * callbacks disabled.
                         */
                        goto no_more_replies;
                }

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0)
                        break;

                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                pkt = vhost_vsock_alloc_pkt(vq, out, in);
                if (!pkt) {
                        vq_err(vq, "Faulted on pkt\n");
                        continue;
                }

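                /* virtio_transport_recv_pkt() may free the packet, so save
                 * its length for the used ring update below.
                 */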
                len = pkt->len;

                /* Deliver to monitoring devices all received packets */
                virtio_transport_deliver_tap_pkt(pkt);

                /* Only accept correctly addressed packets */
                if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid)
                        virtio_transport_recv_pkt(pkt);
                else
                        virtio_transport_free_pkt(pkt);

                vhost_add_used(vq, head, sizeof(pkt->hdr) + len);
                added = true;
        }

no_more_replies:
        if (added)
                vhost_signal(&vsock->dev, vq);

out:
        mutex_unlock(&vq->mutex);
}

static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);

        vhost_transport_do_send_pkt(vsock, vq);
}

static int vhost_vsock_start(struct vhost_vsock *vsock)
{
        struct vhost_virtqueue *vq;
        size_t i;
        int ret;

        mutex_lock(&vsock->dev.mutex);

        ret = vhost_dev_check_owner(&vsock->dev);
        if (ret)
                goto err;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);

                if (!vhost_vq_access_ok(vq)) {
                        ret = -EFAULT;
                        goto err_vq;
                }

                if (!vq->private_data) {
                        vq->private_data = vsock;
                        ret = vhost_vq_init_access(vq);
                        if (ret)
                                goto err_vq;
                }

                mutex_unlock(&vq->mutex);
        }

        mutex_unlock(&vsock->dev.mutex);
        return 0;

err_vq:
        vq->private_data = NULL;
        mutex_unlock(&vq->mutex);

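        /* Roll back: clear private_data on every vq, including those that
         * were already started above.
         */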
        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);
                vq->private_data = NULL;
                mutex_unlock(&vq->mutex);
        }
err:
        mutex_unlock(&vsock->dev.mutex);
        return ret;
}

static int vhost_vsock_stop(struct vhost_vsock *vsock)
{
        size_t i;
        int ret;

        mutex_lock(&vsock->dev.mutex);

        ret = vhost_dev_check_owner(&vsock->dev);
        if (ret)
                goto err;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                struct vhost_virtqueue *vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);
                vq->private_data = NULL;
                mutex_unlock(&vq->mutex);
        }

err:
        mutex_unlock(&vsock->dev.mutex);
        return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
        kvfree(vsock);
}

static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
        struct vhost_virtqueue **vqs;
        struct vhost_vsock *vsock;
        int ret;

        /* This struct is large and allocation could fail, fall back to vmalloc
         * if there is no other way.
         */
        vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        if (!vsock)
                return -ENOMEM;

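        /* vhost_dev_init() takes an array of virtqueue pointers */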
        vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                ret = -ENOMEM;
                goto out;
        }

        vsock->guest_cid = 0; /* no CID assigned yet */

        atomic_set(&vsock->queued_replies, 0);

        vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
        vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
        vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
        vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

        vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs), UIO_MAXIOV);

        file->private_data = vsock;
        spin_lock_init(&vsock->send_pkt_list_lock);
        INIT_LIST_HEAD(&vsock->send_pkt_list);
        vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
        return 0;

out:
        vhost_vsock_free(vsock);
        return ret;
}

static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
        int i;

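        /* Wait until all queued virtqueue work and the send worker finish */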
        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
                if (vsock->vqs[i].handle_kick)
                        vhost_poll_flush(&vsock->vqs[i].poll);
        vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

static void vhost_vsock_reset_orphans(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        /* vmci_transport.c doesn't take sk_lock here either.  At least we're
         * under vsock_table_lock so the sock cannot disappear while we're
         * executing.
         */

        /* If the peer is still valid, no need to reset connection */
        if (vhost_vsock_get(vsk->remote_addr.svm_cid))
                return;

        /* If the close timeout is pending, let it expire.  This avoids races
         * with the timeout callback.
         */
        if (vsk->close_work_scheduled)
                return;

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        sk->sk_state = SS_UNCONNECTED;
        sk->sk_err = ECONNRESET;
        sk->sk_error_report(sk);
}

static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
        struct vhost_vsock *vsock = file->private_data;

        mutex_lock(&vhost_vsock_mutex);
        if (vsock->guest_cid)
                hash_del_rcu(&vsock->hash);
        mutex_unlock(&vhost_vsock_mutex);

        /* Wait for other CPUs to finish using vsock */
        synchronize_rcu();

        /* Iterating over all connections for all CIDs to find orphans is
         * inefficient.  Room for improvement here.
         */
        vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

        vhost_vsock_stop(vsock);
        vhost_vsock_flush(vsock);
        vhost_dev_stop(&vsock->dev);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        while (!list_empty(&vsock->send_pkt_list)) {
                struct virtio_vsock_pkt *pkt;

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        vhost_dev_cleanup(&vsock->dev);
        kfree(vsock->dev.vqs);
        vhost_vsock_free(vsock);
        return 0;
}

static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
        struct vhost_vsock *other;

        /* Refuse reserved CIDs */
        if (guest_cid <= VMADDR_CID_HOST ||
            guest_cid == U32_MAX)
                return -EINVAL;

        /* 64-bit CIDs are not yet supported */
        if (guest_cid > U32_MAX)
                return -EINVAL;

        /* Refuse if CID is already in use */
        mutex_lock(&vhost_vsock_mutex);
        other = vhost_vsock_get(guest_cid);
        if (other && other != vsock) {
                mutex_unlock(&vhost_vsock_mutex);
                return -EADDRINUSE;
        }

        if (vsock->guest_cid)
                hash_del_rcu(&vsock->hash);

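        /* Publish the new CID; RCU readers in vhost_vsock_get() will see it
         * once hash_add_rcu() completes.
         */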
        vsock->guest_cid = guest_cid;
        hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
        mutex_unlock(&vhost_vsock_mutex);

        return 0;
}

static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
        struct vhost_virtqueue *vq;
        int i;

        if (features & ~VHOST_VSOCK_FEATURES)
                return -EOPNOTSUPP;

        mutex_lock(&vsock->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&vsock->dev)) {
                mutex_unlock(&vsock->dev.mutex);
                return -EFAULT;
        }

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];
                mutex_lock(&vq->mutex);
                vq->acked_features = features;
                mutex_unlock(&vq->mutex);
        }
        mutex_unlock(&vsock->dev.mutex);
        return 0;
}

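/* A typical userspace lifecycle, sketched for orientation (the exact
 * sequence is up to the VMM, e.g. QEMU):
 *
 *   fd = open("/dev/vhost-vsock", O_RDWR);
 *   ioctl(fd, VHOST_SET_OWNER, NULL);
 *   ioctl(fd, VHOST_GET_FEATURES, &features);
 *   ioctl(fd, VHOST_SET_FEATURES, &features);
 *   ...VHOST_SET_MEM_TABLE, VHOST_SET_VRING_* for both virtqueues...
 *   ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &guest_cid);  // u64, > VMADDR_CID_HOST
 *   ioctl(fd, VHOST_VSOCK_SET_RUNNING, &start);        // int, 1 = start
 */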
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
                                  unsigned long arg)
{
        struct vhost_vsock *vsock = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 guest_cid;
        u64 features;
        int start;
        int r;

        switch (ioctl) {
        case VHOST_VSOCK_SET_GUEST_CID:
                if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
                        return -EFAULT;
                return vhost_vsock_set_cid(vsock, guest_cid);
        case VHOST_VSOCK_SET_RUNNING:
                if (copy_from_user(&start, argp, sizeof(start)))
                        return -EFAULT;
                if (start)
                        return vhost_vsock_start(vsock);
                else
                        return vhost_vsock_stop(vsock);
        case VHOST_GET_FEATURES:
                features = VHOST_VSOCK_FEATURES;
                if (copy_to_user(argp, &features, sizeof(features)))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, argp, sizeof(features)))
                        return -EFAULT;
                return vhost_vsock_set_features(vsock, features);
        default:
                mutex_lock(&vsock->dev.mutex);
                r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
                else
                        vhost_vsock_flush(vsock);
                mutex_unlock(&vsock->dev.mutex);
                return r;
        }
}

#ifdef CONFIG_COMPAT
static long vhost_vsock_dev_compat_ioctl(struct file *f, unsigned int ioctl,
                                         unsigned long arg)
{
        return vhost_vsock_dev_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_vsock_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vsock_dev_open,
        .release        = vhost_vsock_dev_release,
        .llseek         = noop_llseek,
        .unlocked_ioctl = vhost_vsock_dev_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = vhost_vsock_dev_compat_ioctl,
#endif
};

static struct miscdevice vhost_vsock_misc = {
        .minor = VHOST_VSOCK_MINOR,
        .name = "vhost-vsock",
        .fops = &vhost_vsock_fops,
};

static struct virtio_transport vhost_transport = {
        .transport = {
                .get_local_cid            = vhost_transport_get_local_cid,

                .init                     = virtio_transport_do_socket_init,
                .destruct                 = virtio_transport_destruct,
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = vhost_transport_cancel_pkt,

                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
                .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_allow              = virtio_transport_dgram_allow,

                .stream_enqueue           = virtio_transport_stream_enqueue,
                .stream_dequeue           = virtio_transport_stream_dequeue,
                .stream_has_data          = virtio_transport_stream_has_data,
                .stream_has_space         = virtio_transport_stream_has_space,
                .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
                .stream_is_active         = virtio_transport_stream_is_active,
                .stream_allow             = virtio_transport_stream_allow,

                .notify_poll_in           = virtio_transport_notify_poll_in,
                .notify_poll_out          = virtio_transport_notify_poll_out,
                .notify_recv_init         = virtio_transport_notify_recv_init,
                .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
                .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
                .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
                .notify_send_init         = virtio_transport_notify_send_init,
                .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
                .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
                .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

                .set_buffer_size          = virtio_transport_set_buffer_size,
                .set_min_buffer_size      = virtio_transport_set_min_buffer_size,
                .set_max_buffer_size      = virtio_transport_set_max_buffer_size,
                .get_buffer_size          = virtio_transport_get_buffer_size,
                .get_min_buffer_size      = virtio_transport_get_min_buffer_size,
                .get_max_buffer_size      = virtio_transport_get_max_buffer_size,
        },

        .send_pkt = vhost_transport_send_pkt,
};

static int __init vhost_vsock_init(void)
{
        int ret;

        ret = vsock_core_init(&vhost_transport.transport);
        if (ret < 0)
                return ret;

        ret = misc_register(&vhost_vsock_misc);
        if (ret)
                vsock_core_exit(); /* don't leave the transport registered */
        return ret;
}

static void __exit vhost_vsock_exit(void)
{
        misc_deregister(&vhost_vsock_misc);
        vsock_core_exit();
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");