linux/drivers/vhost/vsock.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID    2
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
        VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
        struct vhost_dev dev;
        struct vhost_virtqueue vqs[2];

        /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
        struct hlist_node hash;

        struct vhost_work send_pkt_work;
        spinlock_t send_pkt_list_lock;
        struct list_head send_pkt_list; /* host->guest pending packets */

        atomic_t queued_replies;

        u32 guest_cid;
};

static u32 vhost_transport_get_local_cid(void)
{
        return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
        struct vhost_vsock *vsock;

        hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
                u32 other_cid = vsock->guest_cid;

                /* Skip instances that have no CID yet */
                if (other_cid == 0)
                        continue;

                if (other_cid == guest_cid)
                        return vsock;
        }

        return NULL;
}

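/* Move packets queued on the host->guest send_pkt_list into the buffers the
 * guest has made available on its RX virtqueue.  Packets larger than a single
 * buffer are sent in pieces across several buffers.  Once a reply packet has
 * been delivered, TX processing may be restarted if it had been throttled.
 */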
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                            struct vhost_virtqueue *vq)
{
        struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
        int pkts = 0, total_len = 0;
        bool added = false;
        bool restart_tx = false;

        mutex_lock(&vq->mutex);

        if (!vq->private_data)
                goto out;

        /* Avoid further vmexits, we're already processing the virtqueue */
        vhost_disable_notify(&vsock->dev, vq);

        do {
                struct virtio_vsock_pkt *pkt;
                struct iov_iter iov_iter;
                unsigned out, in;
                size_t nbytes;
                size_t iov_len, payload_len;
                int head;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        vhost_enable_notify(&vsock->dev, vq);
                        break;
                }

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                spin_unlock_bh(&vsock->send_pkt_list_lock);

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

                if (head == vq->num) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);

                        /* We cannot finish yet if more buffers snuck in while
                         * re-enabling notify.
                         */
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                if (out) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Expected 0 output buffers, got %u\n", out);
                        break;
                }

                iov_len = iov_length(&vq->iov[out], in);
                if (iov_len < sizeof(pkt->hdr)) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
                        break;
                }

                iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
                payload_len = pkt->len - pkt->off;

                /* If the packet is greater than the space available in the
                 * buffer, we split it using multiple buffers.
                 */
                if (payload_len > iov_len - sizeof(pkt->hdr))
                        payload_len = iov_len - sizeof(pkt->hdr);

                /* Set the correct length in the header */
                pkt->hdr.len = cpu_to_le32(payload_len);

                nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
                if (nbytes != sizeof(pkt->hdr)) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt hdr\n");
                        break;
                }

                nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
                                      &iov_iter);
                if (nbytes != payload_len) {
                        virtio_transport_free_pkt(pkt);
                        vq_err(vq, "Faulted on copying pkt buf\n");
                        break;
                }

                vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
                added = true;

                /* Deliver to monitoring devices all correctly transmitted
                 * packets.
                 */
                virtio_transport_deliver_tap_pkt(pkt);

                pkt->off += payload_len;
                total_len += payload_len;

                /* If we didn't send all the payload we can requeue the packet
                 * to send it with the next available buffer.
                 */
                if (pkt->off < pkt->len) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                } else {
                        if (pkt->reply) {
                                int val;

                                val = atomic_dec_return(&vsock->queued_replies);

                                /* Do we have resources to resume tx
                                 * processing?
                                 */
                                if (val + 1 == tx_vq->num)
                                        restart_tx = true;
                        }

                        virtio_transport_free_pkt(pkt);
                }
        } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
        if (added)
                vhost_signal(&vsock->dev, vq);

out:
        mutex_unlock(&vq->mutex);

        if (restart_tx)
                vhost_poll_queue(&tx_vq->poll);
}

static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
        struct vhost_virtqueue *vq;
        struct vhost_vsock *vsock;

        vsock = container_of(work, struct vhost_vsock, send_pkt_work);
        vq = &vsock->vqs[VSOCK_VQ_RX];

        vhost_transport_do_send_pkt(vsock, vq);
}

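/* Entry point used by the core transport to send a packet to the guest whose
 * CID matches pkt->hdr.dst_cid.  The packet is queued on send_pkt_list and the
 * send worker is kicked; returns the payload length, or -ENODEV if no such
 * guest exists.
 */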
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
        struct vhost_vsock *vsock;
        int len = pkt->len;

        rcu_read_lock();

        /* Find the vhost_vsock according to guest context id  */
        vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
        if (!vsock) {
                rcu_read_unlock();
                virtio_transport_free_pkt(pkt);
                return -ENODEV;
        }

        if (pkt->reply)
                atomic_inc(&vsock->queued_replies);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_add_tail(&pkt->list, &vsock->send_pkt_list);
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

        rcu_read_unlock();
        return len;
}

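/* Drop any queued packets that belong to @vsk.  If dropping reply packets
 * frees up space in the reply budget, re-kick the TX virtqueue so that
 * guest->host processing can resume.
 */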
static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
        struct vhost_vsock *vsock;
        struct virtio_vsock_pkt *pkt, *n;
        int cnt = 0;
        int ret = -ENODEV;
        LIST_HEAD(freeme);

        rcu_read_lock();

        /* Find the vhost_vsock according to guest context id  */
        vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
        if (!vsock)
                goto out;

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
                if (pkt->vsk != vsk)
                        continue;
                list_move(&pkt->list, &freeme);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        list_for_each_entry_safe(pkt, n, &freeme, list) {
                if (pkt->reply)
                        cnt++;
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }

        if (cnt) {
                struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
                int new_cnt;

                new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
                if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
                        vhost_poll_queue(&tx_vq->poll);
        }

        ret = 0;
out:
        rcu_read_unlock();
        return ret;
}

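/* Build a virtio_vsock_pkt from the descriptor chain the guest placed on the
 * TX virtqueue: copy the header, then validate and copy the payload, if any.
 * Returns NULL on malformed input or allocation failure.
 */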
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
                      unsigned int out, unsigned int in)
{
        struct virtio_vsock_pkt *pkt;
        struct iov_iter iov_iter;
        size_t nbytes;
        size_t len;

        if (in != 0) {
                vq_err(vq, "Expected 0 input buffers, got %u\n", in);
                return NULL;
        }

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        len = iov_length(vq->iov, out);
        iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

        nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
        if (nbytes != sizeof(pkt->hdr)) {
                vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
                       sizeof(pkt->hdr), nbytes);
                kfree(pkt);
                return NULL;
        }

        if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
                pkt->len = le32_to_cpu(pkt->hdr.len);

        /* No payload */
        if (!pkt->len)
                return pkt;

        /* The pkt is too big */
        if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
                kfree(pkt);
                return NULL;
        }

        pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
        if (!pkt->buf) {
                kfree(pkt);
                return NULL;
        }

        pkt->buf_len = pkt->len;

        nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
        if (nbytes != pkt->len) {
                vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
                       pkt->len, nbytes);
                virtio_transport_free_pkt(pkt);
                return NULL;
        }

        return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
        struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
        int val;

        smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
        val = atomic_read(&vsock->queued_replies);

        return val < vq->num;
}

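/* Guest kicked the TX virtqueue: receive guest->host packets and hand them to
 * the core transport, throttling when too many replies are still queued.
 */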
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);
        struct virtio_vsock_pkt *pkt;
        int head, pkts = 0, total_len = 0;
        unsigned int out, in;
        bool added = false;

        mutex_lock(&vq->mutex);

        if (!vq->private_data)
                goto out;

        vhost_disable_notify(&vsock->dev, vq);
        do {
                u32 len;

                if (!vhost_vsock_more_replies(vsock)) {
                        /* Stop tx until the device processes already
                         * pending replies.  Leave tx virtqueue
                         * callbacks disabled.
                         */
                        goto no_more_replies;
                }

                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0)
                        break;

                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
                                vhost_disable_notify(&vsock->dev, vq);
                                continue;
                        }
                        break;
                }

                pkt = vhost_vsock_alloc_pkt(vq, out, in);
                if (!pkt) {
                        vq_err(vq, "Faulted on pkt\n");
                        continue;
                }

                len = pkt->len;

                /* Deliver to monitoring devices all received packets */
                virtio_transport_deliver_tap_pkt(pkt);

                /* Only accept correctly addressed packets */
                if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid)
                        virtio_transport_recv_pkt(pkt);
                else
                        virtio_transport_free_pkt(pkt);

                len += sizeof(pkt->hdr);
                vhost_add_used(vq, head, len);
                total_len += len;
                added = true;
        } while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
        if (added)
                vhost_signal(&vsock->dev, vq);

out:
        mutex_unlock(&vq->mutex);
}

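/* Guest kicked the RX virtqueue: new buffers are available, so try again to
 * deliver pending host->guest packets.
 */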
static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
                                                 dev);

        vhost_transport_do_send_pkt(vsock, vq);
}

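/* VHOST_VSOCK_SET_RUNNING(1): attach this device to every virtqueue so the
 * kick handlers start processing.
 */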
static int vhost_vsock_start(struct vhost_vsock *vsock)
{
        struct vhost_virtqueue *vq;
        size_t i;
        int ret;

        mutex_lock(&vsock->dev.mutex);

        ret = vhost_dev_check_owner(&vsock->dev);
        if (ret)
                goto err;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);

                if (!vhost_vq_access_ok(vq)) {
                        ret = -EFAULT;
                        goto err_vq;
                }

                if (!vq->private_data) {
                        vq->private_data = vsock;
                        ret = vhost_vq_init_access(vq);
                        if (ret)
                                goto err_vq;
                }

                mutex_unlock(&vq->mutex);
        }

        mutex_unlock(&vsock->dev.mutex);
        return 0;

err_vq:
        vq->private_data = NULL;
        mutex_unlock(&vq->mutex);

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);
                vq->private_data = NULL;
                mutex_unlock(&vq->mutex);
        }
err:
        mutex_unlock(&vsock->dev.mutex);
        return ret;
}

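/* VHOST_VSOCK_SET_RUNNING(0): detach the virtqueues so the kick handlers
 * become no-ops.
 */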
static int vhost_vsock_stop(struct vhost_vsock *vsock)
{
        size_t i;
        int ret;

        mutex_lock(&vsock->dev.mutex);

        ret = vhost_dev_check_owner(&vsock->dev);
        if (ret)
                goto err;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                struct vhost_virtqueue *vq = &vsock->vqs[i];

                mutex_lock(&vq->mutex);
                vq->private_data = NULL;
                mutex_unlock(&vq->mutex);
        }

err:
        mutex_unlock(&vsock->dev.mutex);
        return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
        kvfree(vsock);
}

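/* open() on /dev/vhost-vsock: allocate a vhost_vsock instance and wire up the
 * virtqueue kick handlers and the send worker.
 */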
static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
        struct vhost_virtqueue **vqs;
        struct vhost_vsock *vsock;
        int ret;

        /* This struct is large and allocation could fail, fall back to vmalloc
         * if there is no other way.
         */
        vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        if (!vsock)
                return -ENOMEM;

        vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                ret = -ENOMEM;
                goto out;
        }

        vsock->guest_cid = 0; /* no CID assigned yet */

        atomic_set(&vsock->queued_replies, 0);

        vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
        vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
        vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
        vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

        vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
                       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
                       VHOST_VSOCK_WEIGHT);

        file->private_data = vsock;
        spin_lock_init(&vsock->send_pkt_list_lock);
        INIT_LIST_HEAD(&vsock->send_pkt_list);
        vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
        return 0;

out:
        vhost_vsock_free(vsock);
        return ret;
}

static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
                if (vsock->vqs[i].handle_kick)
                        vhost_poll_flush(&vsock->vqs[i].poll);
        vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

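/* Invoked for every connected socket on release; reset sockets whose peer CID
 * no longer maps to a live vhost_vsock instance.
 */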
static void vhost_vsock_reset_orphans(struct sock *sk)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        /* vmci_transport.c doesn't take sk_lock here either.  At least we're
         * under vsock_table_lock so the sock cannot disappear while we're
         * executing.
         */

        /* If the peer is still valid, no need to reset connection */
        if (vhost_vsock_get(vsk->remote_addr.svm_cid))
                return;

        /* If the close timeout is pending, let it expire.  This avoids races
         * with the timeout callback.
         */
        if (vsk->close_work_scheduled)
                return;

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        sk->sk_state = SS_UNCONNECTED;
        sk->sk_err = ECONNRESET;
        sk->sk_error_report(sk);
}

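/* close() on /dev/vhost-vsock: unhash the instance, wait for RCU readers,
 * reset orphaned sockets, stop the virtqueues and free any pending packets.
 */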
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
        struct vhost_vsock *vsock = file->private_data;

        mutex_lock(&vhost_vsock_mutex);
        if (vsock->guest_cid)
                hash_del_rcu(&vsock->hash);
        mutex_unlock(&vhost_vsock_mutex);

        /* Wait for other CPUs to finish using vsock */
        synchronize_rcu();

        /* Iterating over all connections for all CIDs to find orphans is
         * inefficient.  Room for improvement here. */
        vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

        vhost_vsock_stop(vsock);
        vhost_vsock_flush(vsock);
        vhost_dev_stop(&vsock->dev);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        while (!list_empty(&vsock->send_pkt_list)) {
                struct virtio_vsock_pkt *pkt;

                pkt = list_first_entry(&vsock->send_pkt_list,
                                struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        vhost_dev_cleanup(&vsock->dev);
        kfree(vsock->dev.vqs);
        vhost_vsock_free(vsock);
        return 0;
}

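/* VHOST_VSOCK_SET_GUEST_CID: assign the guest its context ID, rejecting
 * reserved, 64-bit and already-in-use values.
 */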
static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
        struct vhost_vsock *other;

        /* Refuse reserved CIDs */
        if (guest_cid <= VMADDR_CID_HOST ||
            guest_cid == U32_MAX)
                return -EINVAL;

        /* 64-bit CIDs are not yet supported */
        if (guest_cid > U32_MAX)
                return -EINVAL;

        /* Refuse if CID is already in use */
        mutex_lock(&vhost_vsock_mutex);
        other = vhost_vsock_get(guest_cid);
        if (other && other != vsock) {
                mutex_unlock(&vhost_vsock_mutex);
                return -EADDRINUSE;
        }

        if (vsock->guest_cid)
                hash_del_rcu(&vsock->hash);

        vsock->guest_cid = guest_cid;
        hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
        mutex_unlock(&vhost_vsock_mutex);

        return 0;
}

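/* VHOST_SET_FEATURES: accept the feature bits negotiated by userspace and
 * propagate them to every virtqueue.
 */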
static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
        struct vhost_virtqueue *vq;
        int i;

        if (features & ~VHOST_VSOCK_FEATURES)
                return -EOPNOTSUPP;

        mutex_lock(&vsock->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&vsock->dev)) {
                mutex_unlock(&vsock->dev.mutex);
                return -EFAULT;
        }

        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
                vq = &vsock->vqs[i];
                mutex_lock(&vq->mutex);
                vq->acked_features = features;
                mutex_unlock(&vq->mutex);
        }
        mutex_unlock(&vsock->dev.mutex);
        return 0;
}

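/* ioctl() dispatch for /dev/vhost-vsock; anything not handled here falls
 * through to the generic vhost device and vring ioctls.
 */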
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
                                  unsigned long arg)
{
        struct vhost_vsock *vsock = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 guest_cid;
        u64 features;
        int start;
        int r;

        switch (ioctl) {
        case VHOST_VSOCK_SET_GUEST_CID:
                if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
                        return -EFAULT;
                return vhost_vsock_set_cid(vsock, guest_cid);
        case VHOST_VSOCK_SET_RUNNING:
                if (copy_from_user(&start, argp, sizeof(start)))
                        return -EFAULT;
                if (start)
                        return vhost_vsock_start(vsock);
                else
                        return vhost_vsock_stop(vsock);
        case VHOST_GET_FEATURES:
                features = VHOST_VSOCK_FEATURES;
                if (copy_to_user(argp, &features, sizeof(features)))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, argp, sizeof(features)))
                        return -EFAULT;
                return vhost_vsock_set_features(vsock, features);
        default:
                mutex_lock(&vsock->dev.mutex);
                r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
                else
                        vhost_vsock_flush(vsock);
                mutex_unlock(&vsock->dev.mutex);
                return r;
        }
}

#ifdef CONFIG_COMPAT
static long vhost_vsock_dev_compat_ioctl(struct file *f, unsigned int ioctl,
                                         unsigned long arg)
{
        return vhost_vsock_dev_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_vsock_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vsock_dev_open,
        .release        = vhost_vsock_dev_release,
        .llseek         = noop_llseek,
        .unlocked_ioctl = vhost_vsock_dev_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = vhost_vsock_dev_compat_ioctl,
#endif
};

static struct miscdevice vhost_vsock_misc = {
        .minor = VHOST_VSOCK_MINOR,
        .name = "vhost-vsock",
        .fops = &vhost_vsock_fops,
};

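/* Transport ops registered with the AF_VSOCK core.  Most callbacks are the
 * shared virtio transport helpers; only packet delivery is vhost specific.
 */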
static struct virtio_transport vhost_transport = {
        .transport = {
                .get_local_cid            = vhost_transport_get_local_cid,

                .init                     = virtio_transport_do_socket_init,
                .destruct                 = virtio_transport_destruct,
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = vhost_transport_cancel_pkt,

                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
                .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_allow              = virtio_transport_dgram_allow,

                .stream_enqueue           = virtio_transport_stream_enqueue,
                .stream_dequeue           = virtio_transport_stream_dequeue,
                .stream_has_data          = virtio_transport_stream_has_data,
                .stream_has_space         = virtio_transport_stream_has_space,
                .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
                .stream_is_active         = virtio_transport_stream_is_active,
                .stream_allow             = virtio_transport_stream_allow,

                .notify_poll_in           = virtio_transport_notify_poll_in,
                .notify_poll_out          = virtio_transport_notify_poll_out,
                .notify_recv_init         = virtio_transport_notify_recv_init,
                .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
                .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
                .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
                .notify_send_init         = virtio_transport_notify_send_init,
                .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
                .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
                .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

                .set_buffer_size          = virtio_transport_set_buffer_size,
                .set_min_buffer_size      = virtio_transport_set_min_buffer_size,
                .set_max_buffer_size      = virtio_transport_set_max_buffer_size,
                .get_buffer_size          = virtio_transport_get_buffer_size,
                .get_min_buffer_size      = virtio_transport_get_min_buffer_size,
                .get_max_buffer_size      = virtio_transport_get_max_buffer_size,
        },

        .send_pkt = vhost_transport_send_pkt,
};

static int __init vhost_vsock_init(void)
{
        int ret;

        ret = vsock_core_init(&vhost_transport.transport);
        if (ret < 0)
                return ret;
        return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
        misc_deregister(&vhost_vsock_misc);
        vsock_core_exit();
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock ");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");