linux/net/vmw_vsock/virtio_transport.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * virtio transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * Some of the code is taken from Gerd Hoffmann <kraxel@redhat.com>'s
 * early virtio-vsock proof-of-concept bits.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/atomic.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <linux/mutex.h>
#include <net/af_vsock.h>

static struct workqueue_struct *virtio_vsock_workqueue;
static struct virtio_vsock __rcu *the_virtio_vsock;
static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */

struct virtio_vsock {
	struct virtio_device *vdev;
	struct virtqueue *vqs[VSOCK_VQ_MAX];

	/* Virtqueue processing is deferred to a workqueue */
	struct work_struct tx_work;
	struct work_struct rx_work;
	struct work_struct event_work;

	/* The following fields are protected by tx_lock.  vqs[VSOCK_VQ_TX]
	 * must be accessed with tx_lock held.
	 */
	struct mutex tx_lock;
	bool tx_run;

	struct work_struct send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;

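	/* Number of reply packets awaiting transmission.  rx processing is
	 * throttled once this reaches the rx vring size and resumes when
	 * replies drain (see virtio_transport_more_replies()).
	 */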
	atomic_t queued_replies;

	/* The following fields are protected by rx_lock.  vqs[VSOCK_VQ_RX]
	 * must be accessed with rx_lock held.
	 */
	struct mutex rx_lock;
	bool rx_run;
	int rx_buf_nr;
	int rx_buf_max_nr;

	/* The following fields are protected by event_lock.
	 * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
	 */
	struct mutex event_lock;
	bool event_run;
	struct virtio_vsock_event event_list[8];

	u32 guest_cid;
	bool seqpacket_allow;
};

static u32 virtio_transport_get_local_cid(void)
{
	struct virtio_vsock *vsock;
	u32 ret;

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (!vsock) {
		ret = VMADDR_CID_ANY;
		goto out_rcu;
	}

	ret = vsock->guest_cid;
out_rcu:
	rcu_read_unlock();
	return ret;
}

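/* Worker that moves packets from send_pkt_list into the TX virtqueue.
 * virtio_transport_tx_work re-queues this work once the device has
 * consumed buffers, so a full vq only pauses transmission.
 */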
static void
virtio_transport_send_pkt_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, send_pkt_work);
	struct virtqueue *vq;
	bool added = false;
	bool restart_rx = false;

	mutex_lock(&vsock->tx_lock);

	if (!vsock->tx_run)
		goto out;

	vq = vsock->vqs[VSOCK_VQ_TX];

	for (;;) {
		struct virtio_vsock_pkt *pkt;
		struct scatterlist hdr, buf, *sgs[2];
		int ret, in_sg = 0, out_sg = 0;
		bool reply;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		virtio_transport_deliver_tap_pkt(pkt);

		reply = pkt->reply;

		sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
		sgs[out_sg++] = &hdr;
		if (pkt->buf) {
			sg_init_one(&buf, pkt->buf, pkt->len);
			sgs[out_sg++] = &buf;
		}

		ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
		/* Usually this means that there is no more space available in
		 * the vq
		 */
		if (ret < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (reply) {
			struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
			int val;

			val = atomic_dec_return(&vsock->queued_replies);

			/* Do we now have resources to resume rx processing? */
			if (val + 1 == virtqueue_get_vring_size(rx_vq))
				restart_rx = true;
		}

		added = true;
	}

	if (added)
		virtqueue_kick(vq);

out:
	mutex_unlock(&vsock->tx_lock);

	if (restart_rx)
		queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

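/* Queue a packet on send_pkt_list and kick the worker.  Returns the
 * packet length on success, or -ENODEV if no device is bound or the
 * packet is addressed to our own CID.
 */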
static int
virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock *vsock;
	int len = pkt->len;

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (!vsock) {
		virtio_transport_free_pkt(pkt);
		len = -ENODEV;
		goto out_rcu;
	}

	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
		virtio_transport_free_pkt(pkt);
		len = -ENODEV;
		goto out_rcu;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);

out_rcu:
	rcu_read_unlock();
	return len;
}

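/* Drop a socket's packets that are still waiting in send_pkt_list and
 * credit back any reply slots they held; restart rx processing if that
 * crossed the throttling threshold.
 */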
static int
virtio_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct virtio_vsock *vsock;
	struct virtio_vsock_pkt *pkt, *n;
	int cnt = 0, ret;
	LIST_HEAD(freeme);

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (!vsock) {
		ret = -ENODEV;
		goto out_rcu;
	}

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
		if (pkt->vsk != vsk)
			continue;
		list_move(&pkt->list, &freeme);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	list_for_each_entry_safe(pkt, n, &freeme, list) {
		if (pkt->reply)
			cnt++;
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	if (cnt) {
		struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) &&
		    new_cnt < virtqueue_get_vring_size(rx_vq))
			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
	}

	ret = 0;

out_rcu:
	rcu_read_unlock();
	return ret;
}

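/* Refill the RX virtqueue with freshly allocated buffers until the vq
 * is full or an allocation fails.  rx_lock must be held.
 */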
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
{
	int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
	struct virtio_vsock_pkt *pkt;
	struct scatterlist hdr, buf, *sgs[2];
	struct virtqueue *vq;
	int ret;

	vq = vsock->vqs[VSOCK_VQ_RX];

	do {
		pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
		if (!pkt)
			break;

		pkt->buf = kmalloc(buf_len, GFP_KERNEL);
		if (!pkt->buf) {
			virtio_transport_free_pkt(pkt);
			break;
		}

		pkt->buf_len = buf_len;
		pkt->len = buf_len;

		sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
		sgs[0] = &hdr;

		sg_init_one(&buf, pkt->buf, buf_len);
		sgs[1] = &buf;
		ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
		if (ret) {
			virtio_transport_free_pkt(pkt);
			break;
		}
		vsock->rx_buf_nr++;
	} while (vq->num_free);
	if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
		vsock->rx_buf_max_nr = vsock->rx_buf_nr;
	virtqueue_kick(vq);
}

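/* Reclaim TX buffers the device has consumed and free their packets.
 * Freed vq space may allow more transmissions, so re-queue
 * send_pkt_work.
 */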
static void virtio_transport_tx_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, tx_work);
	struct virtqueue *vq;
	bool added = false;

	vq = vsock->vqs[VSOCK_VQ_TX];
	mutex_lock(&vsock->tx_lock);

	if (!vsock->tx_run)
		goto out;

	do {
		struct virtio_vsock_pkt *pkt;
		unsigned int len;

		virtqueue_disable_cb(vq);
		while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
			virtio_transport_free_pkt(pkt);
			added = true;
		}
	} while (!virtqueue_enable_cb(vq));

out:
	mutex_unlock(&vsock->tx_lock);

	if (added)
		queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}

/* Is there space left for replies to rx packets? */
static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
{
	struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < virtqueue_get_vring_size(vq);
}

/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
				       struct virtio_vsock_event *event)
{
	struct scatterlist sg;
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_EVENT];

	sg_init_one(&sg, event, sizeof(*event));

	return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
}

/* event_lock must be held */
static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
{
	size_t i;

	for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
		struct virtio_vsock_event *event = &vsock->event_list[i];

		virtio_vsock_event_fill_one(vsock, event);
	}

	virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
}

static void virtio_vsock_reset_sock(struct sock *sk)
{
	/* vmci_transport.c doesn't take sk_lock here either.  At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	sk->sk_state = TCP_CLOSE;
	sk->sk_err = ECONNRESET;
	sk_error_report(sk);
}

static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
{
	struct virtio_device *vdev = vsock->vdev;
	__le64 guest_cid;

	vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
			  &guest_cid, sizeof(guest_cid));
	vsock->guest_cid = le64_to_cpu(guest_cid);
}

/* event_lock must be held */
static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
				      struct virtio_vsock_event *event)
{
	switch (le32_to_cpu(event->id)) {
	case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
		virtio_vsock_update_guest_cid(vsock);
		vsock_for_each_connected_socket(virtio_vsock_reset_sock);
		break;
	}
}

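/* Drain the event virtqueue: handle each event and immediately re-queue
 * its buffer so the device always has event buffers available.
 */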
static void virtio_transport_event_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, event_work);
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_EVENT];

	mutex_lock(&vsock->event_lock);

	if (!vsock->event_run)
		goto out;

	do {
		struct virtio_vsock_event *event;
		unsigned int len;

		virtqueue_disable_cb(vq);
		while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
			if (len == sizeof(*event))
				virtio_vsock_event_handle(vsock, event);

			virtio_vsock_event_fill_one(vsock, event);
		}
	} while (!virtqueue_enable_cb(vq));

	virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
out:
	mutex_unlock(&vsock->event_lock);
}

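/* Virtqueue callbacks run in interrupt context, so they only schedule
 * the matching work item; all real processing happens in the workers.
 */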
static void virtio_vsock_event_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->event_work);
}

static void virtio_vsock_tx_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->tx_work);
}

static void virtio_vsock_rx_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

static bool virtio_transport_seqpacket_allow(u32 remote_cid);

static struct virtio_transport virtio_transport = {
	.transport = {
		.module                   = THIS_MODULE,

		.get_local_cid            = virtio_transport_get_local_cid,

		.init                     = virtio_transport_do_socket_init,
		.destruct                 = virtio_transport_destruct,
		.release                  = virtio_transport_release,
		.connect                  = virtio_transport_connect,
		.shutdown                 = virtio_transport_shutdown,
		.cancel_pkt               = virtio_transport_cancel_pkt,

		.dgram_bind               = virtio_transport_dgram_bind,
		.dgram_dequeue            = virtio_transport_dgram_dequeue,
		.dgram_enqueue            = virtio_transport_dgram_enqueue,
		.dgram_allow              = virtio_transport_dgram_allow,

		.stream_dequeue           = virtio_transport_stream_dequeue,
		.stream_enqueue           = virtio_transport_stream_enqueue,
		.stream_has_data          = virtio_transport_stream_has_data,
		.stream_has_space         = virtio_transport_stream_has_space,
		.stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
		.stream_is_active         = virtio_transport_stream_is_active,
		.stream_allow             = virtio_transport_stream_allow,

		.seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
		.seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
		.seqpacket_allow          = virtio_transport_seqpacket_allow,
		.seqpacket_has_data       = virtio_transport_seqpacket_has_data,

		.notify_poll_in           = virtio_transport_notify_poll_in,
		.notify_poll_out          = virtio_transport_notify_poll_out,
		.notify_recv_init         = virtio_transport_notify_recv_init,
		.notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init         = virtio_transport_notify_send_init,
		.notify_send_pre_block    = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size       = virtio_transport_notify_buffer_size,
	},

	.send_pkt = virtio_transport_send_pkt,
};

static bool virtio_transport_seqpacket_allow(u32 remote_cid)
{
	struct virtio_vsock *vsock;
	bool seqpacket_allow;

	seqpacket_allow = false;
	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (vsock)
		seqpacket_allow = vsock->seqpacket_allow;
	rcu_read_unlock();

	return seqpacket_allow;
}

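/* Pull received packets off the RX virtqueue and pass them to the core.
 * Refill the vq once fewer than half of rx_buf_max_nr buffers remain.
 */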
static void virtio_transport_rx_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, rx_work);
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_RX];

	mutex_lock(&vsock->rx_lock);

	if (!vsock->rx_run)
		goto out;

	do {
		virtqueue_disable_cb(vq);
		for (;;) {
			struct virtio_vsock_pkt *pkt;
			unsigned int len;

			if (!virtio_transport_more_replies(vsock)) {
				/* Stop rx until the device processes already
				 * pending replies.  Leave rx virtqueue
				 * callbacks disabled.
				 */
				goto out;
			}

			pkt = virtqueue_get_buf(vq, &len);
			if (!pkt)
				break;

			vsock->rx_buf_nr--;

			/* Drop short/long packets */
			if (unlikely(len < sizeof(pkt->hdr) ||
				     len > sizeof(pkt->hdr) + pkt->len)) {
				virtio_transport_free_pkt(pkt);
				continue;
			}

			pkt->len = len - sizeof(pkt->hdr);
			virtio_transport_deliver_tap_pkt(pkt);
			virtio_transport_recv_pkt(&virtio_transport, pkt);
		}
	} while (!virtqueue_enable_cb(vq));

out:
	if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
		virtio_vsock_rx_fill(vsock);
	mutex_unlock(&vsock->rx_lock);
}

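/* Allocate the per-device state, find the three virtqueues, prime the
 * rx and event buffers, and only then publish the device through
 * the_virtio_vsock so other CPUs can observe it via RCU.
 */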
static int virtio_vsock_probe(struct virtio_device *vdev)
{
	vq_callback_t *callbacks[] = {
		virtio_vsock_rx_done,
		virtio_vsock_tx_done,
		virtio_vsock_event_done,
	};
	static const char * const names[] = {
		"rx",
		"tx",
		"event",
	};
	struct virtio_vsock *vsock = NULL;
	int ret;

	ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
	if (ret)
		return ret;

	/* Only one virtio-vsock device per guest is supported */
	if (rcu_dereference_protected(the_virtio_vsock,
				lockdep_is_held(&the_virtio_vsock_mutex))) {
		ret = -EBUSY;
		goto out;
	}

	vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
	if (!vsock) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->vdev = vdev;

	ret = virtio_find_vqs(vsock->vdev, VSOCK_VQ_MAX,
			      vsock->vqs, callbacks, names,
			      NULL);
	if (ret < 0)
		goto out;

	virtio_vsock_update_guest_cid(vsock);

	vsock->rx_buf_nr = 0;
	vsock->rx_buf_max_nr = 0;
	atomic_set(&vsock->queued_replies, 0);

	mutex_init(&vsock->tx_lock);
	mutex_init(&vsock->rx_lock);
	mutex_init(&vsock->event_lock);
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
	INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
	INIT_WORK(&vsock->event_work, virtio_transport_event_work);
	INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);

	mutex_lock(&vsock->tx_lock);
	vsock->tx_run = true;
	mutex_unlock(&vsock->tx_lock);

	mutex_lock(&vsock->rx_lock);
	virtio_vsock_rx_fill(vsock);
	vsock->rx_run = true;
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->event_lock);
	virtio_vsock_event_fill(vsock);
	vsock->event_run = true;
	mutex_unlock(&vsock->event_lock);

	if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET))
		vsock->seqpacket_allow = true;

	vdev->priv = vsock;
	rcu_assign_pointer(the_virtio_vsock, vsock);

	mutex_unlock(&the_virtio_vsock_mutex);

	return 0;

out:
	kfree(vsock);
	mutex_unlock(&the_virtio_vsock_mutex);
	return ret;
}

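/* Teardown mirrors probe in reverse: unpublish the device, wait for RCU
 * readers, stop the workers, reset the device, then reclaim every
 * outstanding buffer before freeing the state.
 */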
static void virtio_vsock_remove(struct virtio_device *vdev)
{
	struct virtio_vsock *vsock = vdev->priv;
	struct virtio_vsock_pkt *pkt;

	mutex_lock(&the_virtio_vsock_mutex);

	vdev->priv = NULL;
	rcu_assign_pointer(the_virtio_vsock, NULL);
	synchronize_rcu();

	/* Reset all connected sockets when the device disappears */
	vsock_for_each_connected_socket(virtio_vsock_reset_sock);

	/* Stop all work handlers to make sure no one is accessing the device,
	 * so we can safely call vdev->config->reset().
	 */
	mutex_lock(&vsock->rx_lock);
	vsock->rx_run = false;
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->tx_lock);
	vsock->tx_run = false;
	mutex_unlock(&vsock->tx_lock);

	mutex_lock(&vsock->event_lock);
	vsock->event_run = false;
	mutex_unlock(&vsock->event_lock);

	/* Flush all device writes and interrupts; the device will not use any
	 * more buffers.
	 */
	vdev->config->reset(vdev);

	mutex_lock(&vsock->rx_lock);
	while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
		virtio_transport_free_pkt(pkt);
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->tx_lock);
	while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
		virtio_transport_free_pkt(pkt);
	mutex_unlock(&vsock->tx_lock);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	/* Delete virtqueues and flush outstanding callbacks if any */
	vdev->config->del_vqs(vdev);

	/* Other works can be queued before 'config->del_vqs()', so flush all
	 * works before freeing the vsock object to avoid use after free.
	 */
	flush_work(&vsock->rx_work);
	flush_work(&vsock->tx_work);
	flush_work(&vsock->event_work);
	flush_work(&vsock->send_pkt_work);

	mutex_unlock(&the_virtio_vsock_mutex);

	kfree(vsock);
}

static struct virtio_device_id id_table[] = {
	{ VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
	{ 0 },
};

static unsigned int features[] = {
	VIRTIO_VSOCK_F_SEQPACKET
};

static struct virtio_driver virtio_vsock_driver = {
	.feature_table = features,
	.feature_table_size = ARRAY_SIZE(features),
	.driver.name = KBUILD_MODNAME,
	.driver.owner = THIS_MODULE,
	.id_table = id_table,
	.probe = virtio_vsock_probe,
	.remove = virtio_vsock_remove,
};

static int __init virtio_vsock_init(void)
{
	int ret;

	virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
	if (!virtio_vsock_workqueue)
		return -ENOMEM;

	ret = vsock_core_register(&virtio_transport.transport,
				  VSOCK_TRANSPORT_F_G2H);
	if (ret)
		goto out_wq;

	ret = register_virtio_driver(&virtio_vsock_driver);
	if (ret)
		goto out_vci;

	return 0;

out_vci:
	vsock_core_unregister(&virtio_transport.transport);
out_wq:
	destroy_workqueue(virtio_vsock_workqueue);
	return ret;
}

static void __exit virtio_vsock_exit(void)
{
	unregister_virtio_driver(&virtio_vsock_driver);
	vsock_core_unregister(&virtio_transport.transport);
	destroy_workqueue(virtio_vsock_workqueue);
}

module_init(virtio_vsock_init);
module_exit(virtio_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("virtio transport for vsock");
MODULE_DEVICE_TABLE(virtio, id_table);