linux/net/vmw_vsock/virtio_transport.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * virtio transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * Some of the code is taken from Gerd Hoffmann <kraxel@redhat.com>'s
 * early virtio-vsock proof-of-concept bits.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/atomic.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <linux/mutex.h>
#include <net/af_vsock.h>

static struct workqueue_struct *virtio_vsock_workqueue;
static struct virtio_vsock *the_virtio_vsock;
static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */

struct virtio_vsock {
        struct virtio_device *vdev;
        struct virtqueue *vqs[VSOCK_VQ_MAX];

        /* Virtqueue processing is deferred to a workqueue */
        struct work_struct tx_work;
        struct work_struct rx_work;
        struct work_struct event_work;

        /* The following fields are protected by tx_lock.  vqs[VSOCK_VQ_TX]
         * must be accessed with tx_lock held.
         */
        struct mutex tx_lock;
        bool tx_run;

        struct work_struct send_pkt_work;
        spinlock_t send_pkt_list_lock;
        struct list_head send_pkt_list;

        struct work_struct loopback_work;
        spinlock_t loopback_list_lock; /* protects loopback_list */
        struct list_head loopback_list;

        atomic_t queued_replies;

        /* The following fields are protected by rx_lock.  vqs[VSOCK_VQ_RX]
         * must be accessed with rx_lock held.
         */
        struct mutex rx_lock;
        bool rx_run;
        int rx_buf_nr;
        int rx_buf_max_nr;

        /* The following fields are protected by event_lock.
         * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
         */
        struct mutex event_lock;
        bool event_run;
        struct virtio_vsock_event event_list[8];

        u32 guest_cid;
};

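/* Return the CID the host assigned to this guest, or VMADDR_CID_ANY if no
 * virtio-vsock device is currently bound.
 */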
static u32 virtio_transport_get_local_cid(void)
{
        struct virtio_vsock *vsock;
        u32 ret;

        rcu_read_lock();
        vsock = rcu_dereference(the_virtio_vsock);
        if (!vsock) {
                ret = VMADDR_CID_ANY;
                goto out_rcu;
        }

        ret = vsock->guest_cid;
out_rcu:
        rcu_read_unlock();
        return ret;
}

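/* Deliver packets addressed to our own CID: splice them off the loopback
 * list and feed them into the common receive path, as if they had arrived
 * on the rx virtqueue.
 */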
static void virtio_transport_loopback_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, loopback_work);
        LIST_HEAD(pkts);

        spin_lock_bh(&vsock->loopback_list_lock);
        list_splice_init(&vsock->loopback_list, &pkts);
        spin_unlock_bh(&vsock->loopback_list_lock);

        mutex_lock(&vsock->rx_lock);

        if (!vsock->rx_run)
                goto out;

        while (!list_empty(&pkts)) {
                struct virtio_vsock_pkt *pkt;

                pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);

                virtio_transport_recv_pkt(pkt);
        }
out:
        mutex_unlock(&vsock->rx_lock);
}

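/* Queue a packet destined for the guest's own CID on the loopback list and
 * kick the loopback worker, bypassing the device entirely.
 */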
static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,
                                              struct virtio_vsock_pkt *pkt)
{
        int len = pkt->len;

        spin_lock_bh(&vsock->loopback_list_lock);
        list_add_tail(&pkt->list, &vsock->loopback_list);
        spin_unlock_bh(&vsock->loopback_list_lock);

        queue_work(virtio_vsock_workqueue, &vsock->loopback_work);

        return len;
}

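/* Worker that moves packets from send_pkt_list into the tx virtqueue.  A
 * packet that does not fit is put back on the list and retried once the
 * device has completed earlier buffers.
 */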
static void
virtio_transport_send_pkt_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, send_pkt_work);
        struct virtqueue *vq;
        bool added = false;
        bool restart_rx = false;

        mutex_lock(&vsock->tx_lock);

        if (!vsock->tx_run)
                goto out;

        vq = vsock->vqs[VSOCK_VQ_TX];

        for (;;) {
                struct virtio_vsock_pkt *pkt;
                struct scatterlist hdr, buf, *sgs[2];
                int ret, in_sg = 0, out_sg = 0;
                bool reply;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                spin_unlock_bh(&vsock->send_pkt_list_lock);

                virtio_transport_deliver_tap_pkt(pkt);

                reply = pkt->reply;

                sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
                sgs[out_sg++] = &hdr;
                if (pkt->buf) {
                        sg_init_one(&buf, pkt->buf, pkt->len);
                        sgs[out_sg++] = &buf;
                }

                ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
                /* Usually this means that there is no more space available in
                 * the vq
                 */
                if (ret < 0) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

                if (reply) {
                        struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
                        int val;

                        val = atomic_dec_return(&vsock->queued_replies);

                        /* Do we now have resources to resume rx processing? */
                        if (val + 1 == virtqueue_get_vring_size(rx_vq))
                                restart_rx = true;
                }

                added = true;
        }

        if (added)
                virtqueue_kick(vq);

out:
        mutex_unlock(&vsock->tx_lock);

        if (restart_rx)
                queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

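/* Queue a packet for transmission (or for loopback if it is addressed to
 * our own CID) and return its length, or -ENODEV if the device is gone.
 */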
static int
virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock *vsock;
        int len = pkt->len;

        rcu_read_lock();
        vsock = rcu_dereference(the_virtio_vsock);
        if (!vsock) {
                virtio_transport_free_pkt(pkt);
                len = -ENODEV;
                goto out_rcu;
        }

        if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
                len = virtio_transport_send_pkt_loopback(vsock, pkt);
                goto out_rcu;
        }

        if (pkt->reply)
                atomic_inc(&vsock->queued_replies);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_add_tail(&pkt->list, &vsock->send_pkt_list);
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);

out_rcu:
        rcu_read_unlock();
        return len;
}

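/* Drop all not-yet-transmitted packets queued for this socket, adjusting
 * the reply accounting and restarting rx processing if it was throttled.
 */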
static int
virtio_transport_cancel_pkt(struct vsock_sock *vsk)
{
        struct virtio_vsock *vsock;
        struct virtio_vsock_pkt *pkt, *n;
        int cnt = 0, ret;
        LIST_HEAD(freeme);

        rcu_read_lock();
        vsock = rcu_dereference(the_virtio_vsock);
        if (!vsock) {
                ret = -ENODEV;
                goto out_rcu;
        }

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
                if (pkt->vsk != vsk)
                        continue;
                list_move(&pkt->list, &freeme);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        list_for_each_entry_safe(pkt, n, &freeme, list) {
                if (pkt->reply)
                        cnt++;
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }

        if (cnt) {
                struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
                int new_cnt;

                new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
                if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) &&
                    new_cnt < virtqueue_get_vring_size(rx_vq))
                        queue_work(virtio_vsock_workqueue, &vsock->rx_work);
        }

        ret = 0;

out_rcu:
        rcu_read_unlock();
        return ret;
}

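/* Fill the rx virtqueue with fresh header+data buffers until it is full,
 * tracking the high-water mark in rx_buf_max_nr.
 */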
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
{
        int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
        struct virtio_vsock_pkt *pkt;
        struct scatterlist hdr, buf, *sgs[2];
        struct virtqueue *vq;
        int ret;

        vq = vsock->vqs[VSOCK_VQ_RX];

        do {
                pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
                if (!pkt)
                        break;

                pkt->buf = kmalloc(buf_len, GFP_KERNEL);
                if (!pkt->buf) {
                        virtio_transport_free_pkt(pkt);
                        break;
                }

                pkt->len = buf_len;

                sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
                sgs[0] = &hdr;

                sg_init_one(&buf, pkt->buf, buf_len);
                sgs[1] = &buf;
                ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
                if (ret) {
                        virtio_transport_free_pkt(pkt);
                        break;
                }
                vsock->rx_buf_nr++;
        } while (vq->num_free);
        if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
                vsock->rx_buf_max_nr = vsock->rx_buf_nr;
        virtqueue_kick(vq);
}

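/* Reclaim and free buffers the device has finished transmitting; if any
 * were reclaimed, kick the send worker to retry queued packets.
 */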
static void virtio_transport_tx_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, tx_work);
        struct virtqueue *vq;
        bool added = false;

        vq = vsock->vqs[VSOCK_VQ_TX];
        mutex_lock(&vsock->tx_lock);

        if (!vsock->tx_run)
                goto out;

        do {
                struct virtio_vsock_pkt *pkt;
                unsigned int len;

                virtqueue_disable_cb(vq);
                while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
                        virtio_transport_free_pkt(pkt);
                        added = true;
                }
        } while (!virtqueue_enable_cb(vq));

out:
        mutex_unlock(&vsock->tx_lock);

        if (added)
                queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}

/* Is there space left for replies to rx packets? */
static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
{
        struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
        int val;

        smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
        val = atomic_read(&vsock->queued_replies);

        return val < virtqueue_get_vring_size(vq);
}

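/* Process received packets, throttling while no space is left for replies
 * and refilling the rx virtqueue when it runs low.
 */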
static void virtio_transport_rx_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, rx_work);
        struct virtqueue *vq;

        vq = vsock->vqs[VSOCK_VQ_RX];

        mutex_lock(&vsock->rx_lock);

        if (!vsock->rx_run)
                goto out;

        do {
                virtqueue_disable_cb(vq);
                for (;;) {
                        struct virtio_vsock_pkt *pkt;
                        unsigned int len;

                        if (!virtio_transport_more_replies(vsock)) {
                                /* Stop rx until the device processes already
                                 * pending replies.  Leave rx virtqueue
                                 * callbacks disabled.
                                 */
                                goto out;
                        }

                        pkt = virtqueue_get_buf(vq, &len);
                        if (!pkt)
                                break;

                        vsock->rx_buf_nr--;

                        /* Drop short/long packets */
                        if (unlikely(len < sizeof(pkt->hdr) ||
                                     len > sizeof(pkt->hdr) + pkt->len)) {
                                virtio_transport_free_pkt(pkt);
                                continue;
                        }

                        pkt->len = len - sizeof(pkt->hdr);
                        virtio_transport_deliver_tap_pkt(pkt);
                        virtio_transport_recv_pkt(pkt);
                }
        } while (!virtqueue_enable_cb(vq));

out:
        if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
                virtio_vsock_rx_fill(vsock);
        mutex_unlock(&vsock->rx_lock);
}

/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
                                       struct virtio_vsock_event *event)
{
        struct scatterlist sg;
        struct virtqueue *vq;

        vq = vsock->vqs[VSOCK_VQ_EVENT];

        sg_init_one(&sg, event, sizeof(*event));

        return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
}

/* event_lock must be held */
static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
{
        size_t i;

        for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
                struct virtio_vsock_event *event = &vsock->event_list[i];

                virtio_vsock_event_fill_one(vsock, event);
        }

        virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
}

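/* Force a connected socket into TCP_CLOSE with ECONNRESET, e.g. on a
 * transport reset event or device removal.
 */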
static void virtio_vsock_reset_sock(struct sock *sk)
{
        lock_sock(sk);
        sk->sk_state = TCP_CLOSE;
        sk->sk_err = ECONNRESET;
        sk->sk_error_report(sk);
        release_sock(sk);
}

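/* Re-read our CID from the device config space; it can change when the
 * guest is migrated to another host.
 */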
static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
{
        struct virtio_device *vdev = vsock->vdev;
        __le64 guest_cid;

        vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
                          &guest_cid, sizeof(guest_cid));
        vsock->guest_cid = le64_to_cpu(guest_cid);
}

/* event_lock must be held */
static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
                                      struct virtio_vsock_event *event)
{
        switch (le32_to_cpu(event->id)) {
        case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
                virtio_vsock_update_guest_cid(vsock);
                vsock_for_each_connected_socket(virtio_vsock_reset_sock);
                break;
        }
}

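/* Drain the event virtqueue, handling each event and giving its buffer
 * back to the device.
 */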
static void virtio_transport_event_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, event_work);
        struct virtqueue *vq;

        vq = vsock->vqs[VSOCK_VQ_EVENT];

        mutex_lock(&vsock->event_lock);

        if (!vsock->event_run)
                goto out;

        do {
                struct virtio_vsock_event *event;
                unsigned int len;

                virtqueue_disable_cb(vq);
                while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
                        if (len == sizeof(*event))
                                virtio_vsock_event_handle(vsock, event);

                        virtio_vsock_event_fill_one(vsock, event);
                }
        } while (!virtqueue_enable_cb(vq));

        virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
out:
        mutex_unlock(&vsock->event_lock);
}

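/* Virtqueue callbacks run in atomic context, so they only schedule the
 * corresponding workers.
 */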
static void virtio_vsock_event_done(struct virtqueue *vq)
{
        struct virtio_vsock *vsock = vq->vdev->priv;

        if (!vsock)
                return;
        queue_work(virtio_vsock_workqueue, &vsock->event_work);
}

static void virtio_vsock_tx_done(struct virtqueue *vq)
{
        struct virtio_vsock *vsock = vq->vdev->priv;

        if (!vsock)
                return;
        queue_work(virtio_vsock_workqueue, &vsock->tx_work);
}

static void virtio_vsock_rx_done(struct virtqueue *vq)
{
        struct virtio_vsock *vsock = vq->vdev->priv;

        if (!vsock)
                return;
        queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

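/* Transport hooks registered with the vsock core.  Most operations are
 * implemented by the virtio_transport_* helpers shared with vhost-vsock.
 */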
static struct virtio_transport virtio_transport = {
        .transport = {
                .get_local_cid            = virtio_transport_get_local_cid,

                .init                     = virtio_transport_do_socket_init,
                .destruct                 = virtio_transport_destruct,
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = virtio_transport_cancel_pkt,

                .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_allow              = virtio_transport_dgram_allow,

                .stream_dequeue           = virtio_transport_stream_dequeue,
                .stream_enqueue           = virtio_transport_stream_enqueue,
                .stream_has_data          = virtio_transport_stream_has_data,
                .stream_has_space         = virtio_transport_stream_has_space,
                .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
                .stream_is_active         = virtio_transport_stream_is_active,
                .stream_allow             = virtio_transport_stream_allow,

                .notify_poll_in           = virtio_transport_notify_poll_in,
                .notify_poll_out          = virtio_transport_notify_poll_out,
                .notify_recv_init         = virtio_transport_notify_recv_init,
                .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
                .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
                .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
                .notify_send_init         = virtio_transport_notify_send_init,
                .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
                .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
                .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

                .set_buffer_size          = virtio_transport_set_buffer_size,
                .set_min_buffer_size      = virtio_transport_set_min_buffer_size,
                .set_max_buffer_size      = virtio_transport_set_max_buffer_size,
                .get_buffer_size          = virtio_transport_get_buffer_size,
                .get_min_buffer_size      = virtio_transport_get_min_buffer_size,
                .get_max_buffer_size      = virtio_transport_get_max_buffer_size,
        },

        .send_pkt = virtio_transport_send_pkt,
};

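/* Set up virtqueues, workers and buffers for a newly discovered device,
 * then publish it through the_virtio_vsock.
 */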
static int virtio_vsock_probe(struct virtio_device *vdev)
{
        vq_callback_t *callbacks[] = {
                virtio_vsock_rx_done,
                virtio_vsock_tx_done,
                virtio_vsock_event_done,
        };
        static const char * const names[] = {
                "rx",
                "tx",
                "event",
        };
        struct virtio_vsock *vsock = NULL;
        int ret;

        ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
        if (ret)
                return ret;

        /* Only one virtio-vsock device per guest is supported */
        if (rcu_dereference_protected(the_virtio_vsock,
                                lockdep_is_held(&the_virtio_vsock_mutex))) {
                ret = -EBUSY;
                goto out;
        }

        vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
        if (!vsock) {
                ret = -ENOMEM;
                goto out;
        }

        vsock->vdev = vdev;

        ret = virtio_find_vqs(vsock->vdev, VSOCK_VQ_MAX,
                              vsock->vqs, callbacks, names,
                              NULL);
        if (ret < 0)
                goto out;

        virtio_vsock_update_guest_cid(vsock);

        vsock->rx_buf_nr = 0;
        vsock->rx_buf_max_nr = 0;
        atomic_set(&vsock->queued_replies, 0);

        mutex_init(&vsock->tx_lock);
        mutex_init(&vsock->rx_lock);
        mutex_init(&vsock->event_lock);
        spin_lock_init(&vsock->send_pkt_list_lock);
        INIT_LIST_HEAD(&vsock->send_pkt_list);
        spin_lock_init(&vsock->loopback_list_lock);
        INIT_LIST_HEAD(&vsock->loopback_list);
        INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
        INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
        INIT_WORK(&vsock->event_work, virtio_transport_event_work);
        INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
        INIT_WORK(&vsock->loopback_work, virtio_transport_loopback_work);

        mutex_lock(&vsock->tx_lock);
        vsock->tx_run = true;
        mutex_unlock(&vsock->tx_lock);

        mutex_lock(&vsock->rx_lock);
        virtio_vsock_rx_fill(vsock);
        vsock->rx_run = true;
        mutex_unlock(&vsock->rx_lock);

        mutex_lock(&vsock->event_lock);
        virtio_vsock_event_fill(vsock);
        vsock->event_run = true;
        mutex_unlock(&vsock->event_lock);

        vdev->priv = vsock;
        rcu_assign_pointer(the_virtio_vsock, vsock);

        mutex_unlock(&the_virtio_vsock_mutex);
        return 0;

out:
        kfree(vsock);
        mutex_unlock(&the_virtio_vsock_mutex);
        return ret;
}

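/* Tear down the device: unpublish it, reset connected sockets, stop the
 * workers, reset the device and free every outstanding packet before
 * deleting the virtqueues.
 */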
static void virtio_vsock_remove(struct virtio_device *vdev)
{
        struct virtio_vsock *vsock = vdev->priv;
        struct virtio_vsock_pkt *pkt;

        mutex_lock(&the_virtio_vsock_mutex);

        vdev->priv = NULL;
        rcu_assign_pointer(the_virtio_vsock, NULL);
        synchronize_rcu();

        /* Reset all connected sockets when the device disappears */
        vsock_for_each_connected_socket(virtio_vsock_reset_sock);

        /* Stop all work handlers to make sure no one is accessing the device,
         * so we can safely call vdev->config->reset().
         */
        mutex_lock(&vsock->rx_lock);
        vsock->rx_run = false;
        mutex_unlock(&vsock->rx_lock);

        mutex_lock(&vsock->tx_lock);
        vsock->tx_run = false;
        mutex_unlock(&vsock->tx_lock);

        mutex_lock(&vsock->event_lock);
        vsock->event_run = false;
        mutex_unlock(&vsock->event_lock);

        /* Flush all device writes and interrupts, device will not use any
         * more buffers.
         */
        vdev->config->reset(vdev);

        mutex_lock(&vsock->rx_lock);
        while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
                virtio_transport_free_pkt(pkt);
        mutex_unlock(&vsock->rx_lock);

        mutex_lock(&vsock->tx_lock);
        while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
                virtio_transport_free_pkt(pkt);
        mutex_unlock(&vsock->tx_lock);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        while (!list_empty(&vsock->send_pkt_list)) {
                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        spin_lock_bh(&vsock->loopback_list_lock);
        while (!list_empty(&vsock->loopback_list)) {
                pkt = list_first_entry(&vsock->loopback_list,
                                       struct virtio_vsock_pkt, list);
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->loopback_list_lock);

        /* Delete virtqueues and flush outstanding callbacks if any */
        vdev->config->del_vqs(vdev);

        /* Other work items can be queued before 'config->del_vqs()', so flush
         * them all before freeing the vsock object to avoid a use after free.
         */
        flush_work(&vsock->loopback_work);
        flush_work(&vsock->rx_work);
        flush_work(&vsock->tx_work);
        flush_work(&vsock->event_work);
        flush_work(&vsock->send_pkt_work);

        mutex_unlock(&the_virtio_vsock_mutex);

        kfree(vsock);
}

static struct virtio_device_id id_table[] = {
        { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
        { 0 },
};

static unsigned int features[] = {
};

static struct virtio_driver virtio_vsock_driver = {
        .feature_table = features,
        .feature_table_size = ARRAY_SIZE(features),
        .driver.name = KBUILD_MODNAME,
        .driver.owner = THIS_MODULE,
        .id_table = id_table,
        .probe = virtio_vsock_probe,
        .remove = virtio_vsock_remove,
};

static int __init virtio_vsock_init(void)
{
        int ret;

        virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
        if (!virtio_vsock_workqueue)
                return -ENOMEM;

        ret = vsock_core_init(&virtio_transport.transport);
        if (ret)
                goto out_wq;

        ret = register_virtio_driver(&virtio_vsock_driver);
        if (ret)
                goto out_vci;

        return 0;

out_vci:
        vsock_core_exit();
out_wq:
        destroy_workqueue(virtio_vsock_workqueue);
        return ret;
}

static void __exit virtio_vsock_exit(void)
{
        unregister_virtio_driver(&virtio_vsock_driver);
        vsock_core_exit();
        destroy_workqueue(virtio_vsock_workqueue);
}

module_init(virtio_vsock_init);
module_exit(virtio_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("virtio transport for vsock");
MODULE_DEVICE_TABLE(virtio, id_table);