linux/net/vmw_vsock/virtio_transport.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * virtio transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * Some of the code is taken from Gerd Hoffmann <kraxel@redhat.com>'s
 * early virtio-vsock proof-of-concept bits.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/atomic.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <linux/mutex.h>
#include <net/af_vsock.h>

static struct workqueue_struct *virtio_vsock_workqueue;
static struct virtio_vsock *the_virtio_vsock;
static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */

struct virtio_vsock {
        struct virtio_device *vdev;
        struct virtqueue *vqs[VSOCK_VQ_MAX];

        /* Virtqueue processing is deferred to a workqueue */
        struct work_struct tx_work;
        struct work_struct rx_work;
        struct work_struct event_work;

        /* The following fields are protected by tx_lock.  vqs[VSOCK_VQ_TX]
         * must be accessed with tx_lock held.
         */
        struct mutex tx_lock;
        bool tx_run;

        struct work_struct send_pkt_work;
        spinlock_t send_pkt_list_lock;
        struct list_head send_pkt_list;

        struct work_struct loopback_work;
        spinlock_t loopback_list_lock; /* protects loopback_list */
        struct list_head loopback_list;

        atomic_t queued_replies;

        /* The following fields are protected by rx_lock.  vqs[VSOCK_VQ_RX]
         * must be accessed with rx_lock held.
         */
        struct mutex rx_lock;
        bool rx_run;
        int rx_buf_nr;
        int rx_buf_max_nr;

        /* The following fields are protected by event_lock.
         * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
         */
        struct mutex event_lock;
        bool event_run;
        struct virtio_vsock_event event_list[8];

        u32 guest_cid;
};

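/* Return the CID assigned to this guest, or VMADDR_CID_ANY if no
 * virtio-vsock device is currently registered.
 */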
static u32 virtio_transport_get_local_cid(void)
{
        struct virtio_vsock *vsock;
        u32 ret;

        rcu_read_lock();
        vsock = rcu_dereference(the_virtio_vsock);
        if (!vsock) {
                ret = VMADDR_CID_ANY;
                goto out_rcu;
        }

        ret = vsock->guest_cid;
out_rcu:
        rcu_read_unlock();
        return ret;
}

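/* Deliver packets queued on the loopback list to the local receive path.
 * Takes rx_lock so it is serialized with the rx worker and device removal.
 */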
static void virtio_transport_loopback_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, loopback_work);
        LIST_HEAD(pkts);

        spin_lock_bh(&vsock->loopback_list_lock);
        list_splice_init(&vsock->loopback_list, &pkts);
        spin_unlock_bh(&vsock->loopback_list_lock);

        mutex_lock(&vsock->rx_lock);

        if (!vsock->rx_run)
                goto out;

        while (!list_empty(&pkts)) {
                struct virtio_vsock_pkt *pkt;

                pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);

                virtio_transport_recv_pkt(pkt);
        }
out:
        mutex_unlock(&vsock->rx_lock);
}

static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,
                                              struct virtio_vsock_pkt *pkt)
{
        int len = pkt->len;

        spin_lock_bh(&vsock->loopback_list_lock);
        list_add_tail(&pkt->list, &vsock->loopback_list);
        spin_unlock_bh(&vsock->loopback_list_lock);

        queue_work(virtio_vsock_workqueue, &vsock->loopback_work);

        return len;
}

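/* Drain send_pkt_list into the TX virtqueue.  Packets that do not fit are
 * put back on the list and retried once the device has used earlier
 * buffers (see virtio_transport_tx_work).
 */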
static void
virtio_transport_send_pkt_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, send_pkt_work);
        struct virtqueue *vq;
        bool added = false;
        bool restart_rx = false;

        mutex_lock(&vsock->tx_lock);

        if (!vsock->tx_run)
                goto out;

        vq = vsock->vqs[VSOCK_VQ_TX];

        for (;;) {
                struct virtio_vsock_pkt *pkt;
                struct scatterlist hdr, buf, *sgs[2];
                int ret, in_sg = 0, out_sg = 0;
                bool reply;

                spin_lock_bh(&vsock->send_pkt_list_lock);
                if (list_empty(&vsock->send_pkt_list)) {
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del_init(&pkt->list);
                spin_unlock_bh(&vsock->send_pkt_list_lock);

                virtio_transport_deliver_tap_pkt(pkt);

                reply = pkt->reply;

                sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
                sgs[out_sg++] = &hdr;
                if (pkt->buf) {
                        sg_init_one(&buf, pkt->buf, pkt->len);
                        sgs[out_sg++] = &buf;
                }

                ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
                /* Usually this means that there is no more space available in
                 * the vq
                 */
                if (ret < 0) {
                        spin_lock_bh(&vsock->send_pkt_list_lock);
                        list_add(&pkt->list, &vsock->send_pkt_list);
                        spin_unlock_bh(&vsock->send_pkt_list_lock);
                        break;
                }

                if (reply) {
                        struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
                        int val;

                        val = atomic_dec_return(&vsock->queued_replies);

                        /* Do we now have resources to resume rx processing? */
                        if (val + 1 == virtqueue_get_vring_size(rx_vq))
                                restart_rx = true;
                }

                added = true;
        }

        if (added)
                virtqueue_kick(vq);

out:
        mutex_unlock(&vsock->tx_lock);

        if (restart_rx)
                queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

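/* Queue a packet for transmission.  Packets addressed to our own CID are
 * looped back locally; everything else is handed to send_pkt_work.
 * Returns the packet length on success or a negative errno.
 */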
static int
virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock *vsock;
        int len = pkt->len;

        rcu_read_lock();
        vsock = rcu_dereference(the_virtio_vsock);
        if (!vsock) {
                virtio_transport_free_pkt(pkt);
                len = -ENODEV;
                goto out_rcu;
        }

        if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
                len = virtio_transport_send_pkt_loopback(vsock, pkt);
                goto out_rcu;
        }

        if (pkt->reply)
                atomic_inc(&vsock->queued_replies);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_add_tail(&pkt->list, &vsock->send_pkt_list);
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);

out_rcu:
        rcu_read_unlock();
        return len;
}

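/* Remove all packets queued for @vsk that have not been sent yet.  If any
 * of them were replies, rx processing may have been throttled on
 * queued_replies, so restart the rx worker when dropping back below the
 * vring size.
 */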
static int
virtio_transport_cancel_pkt(struct vsock_sock *vsk)
{
        struct virtio_vsock *vsock;
        struct virtio_vsock_pkt *pkt, *n;
        int cnt = 0, ret;
        LIST_HEAD(freeme);

        rcu_read_lock();
        vsock = rcu_dereference(the_virtio_vsock);
        if (!vsock) {
                ret = -ENODEV;
                goto out_rcu;
        }

        spin_lock_bh(&vsock->send_pkt_list_lock);
        list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
                if (pkt->vsk != vsk)
                        continue;
                list_move(&pkt->list, &freeme);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        list_for_each_entry_safe(pkt, n, &freeme, list) {
                if (pkt->reply)
                        cnt++;
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }

        if (cnt) {
                struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
                int new_cnt;

                new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
                if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) &&
                    new_cnt < virtqueue_get_vring_size(rx_vq))
                        queue_work(virtio_vsock_workqueue, &vsock->rx_work);
        }

        ret = 0;

out_rcu:
        rcu_read_unlock();
        return ret;
}

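/* Fill the RX virtqueue with newly allocated receive buffers until an
 * allocation fails or the virtqueue has no free slots left.
 */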
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
{
        int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
        struct virtio_vsock_pkt *pkt;
        struct scatterlist hdr, buf, *sgs[2];
        struct virtqueue *vq;
        int ret;

        vq = vsock->vqs[VSOCK_VQ_RX];

        do {
                pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
                if (!pkt)
                        break;

                pkt->buf = kmalloc(buf_len, GFP_KERNEL);
                if (!pkt->buf) {
                        virtio_transport_free_pkt(pkt);
                        break;
                }

                pkt->buf_len = buf_len;
                pkt->len = buf_len;

                sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
                sgs[0] = &hdr;

                sg_init_one(&buf, pkt->buf, buf_len);
                sgs[1] = &buf;
                ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
                if (ret) {
                        virtio_transport_free_pkt(pkt);
                        break;
                }
                vsock->rx_buf_nr++;
        } while (vq->num_free);
        if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
                vsock->rx_buf_max_nr = vsock->rx_buf_nr;
        virtqueue_kick(vq);
}

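/* Reclaim buffers the device has finished with on the TX virtqueue, then
 * let send_pkt_work retry any packets that did not fit earlier.
 */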
static void virtio_transport_tx_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, tx_work);
        struct virtqueue *vq;
        bool added = false;

        vq = vsock->vqs[VSOCK_VQ_TX];
        mutex_lock(&vsock->tx_lock);

        if (!vsock->tx_run)
                goto out;

        do {
                struct virtio_vsock_pkt *pkt;
                unsigned int len;

                virtqueue_disable_cb(vq);
                while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
                        virtio_transport_free_pkt(pkt);
                        added = true;
                }
        } while (!virtqueue_enable_cb(vq));

out:
        mutex_unlock(&vsock->tx_lock);

        if (added)
                queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}

/* Is there space left for replies to rx packets? */
static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
{
        struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
        int val;

        smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
        val = atomic_read(&vsock->queued_replies);

        return val < virtqueue_get_vring_size(vq);
}

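/* Pop filled buffers off the RX virtqueue and hand the packets to the
 * core.  Refill the virtqueue once the number of buffers drops below half
 * of the maximum seen so far.
 */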
static void virtio_transport_rx_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, rx_work);
        struct virtqueue *vq;

        vq = vsock->vqs[VSOCK_VQ_RX];

        mutex_lock(&vsock->rx_lock);

        if (!vsock->rx_run)
                goto out;

        do {
                virtqueue_disable_cb(vq);
                for (;;) {
                        struct virtio_vsock_pkt *pkt;
                        unsigned int len;

                        if (!virtio_transport_more_replies(vsock)) {
                                /* Stop rx until the device processes already
                                 * pending replies.  Leave rx virtqueue
                                 * callbacks disabled.
                                 */
                                goto out;
                        }

                        pkt = virtqueue_get_buf(vq, &len);
                        if (!pkt)
                                break;

                        vsock->rx_buf_nr--;

                        /* Drop short/long packets */
                        if (unlikely(len < sizeof(pkt->hdr) ||
                                     len > sizeof(pkt->hdr) + pkt->len)) {
                                virtio_transport_free_pkt(pkt);
                                continue;
                        }

                        pkt->len = len - sizeof(pkt->hdr);
                        virtio_transport_deliver_tap_pkt(pkt);
                        virtio_transport_recv_pkt(pkt);
                }
        } while (!virtqueue_enable_cb(vq));

out:
        if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
                virtio_vsock_rx_fill(vsock);
        mutex_unlock(&vsock->rx_lock);
}

/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
                                       struct virtio_vsock_event *event)
{
        struct scatterlist sg;
        struct virtqueue *vq;

        vq = vsock->vqs[VSOCK_VQ_EVENT];

        sg_init_one(&sg, event, sizeof(*event));

        return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
}

/* event_lock must be held */
static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
{
        size_t i;

        for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
                struct virtio_vsock_event *event = &vsock->event_list[i];

                virtio_vsock_event_fill_one(vsock, event);
        }

        virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
}

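/* Mark a socket as reset by the peer; applied to every connected socket
 * on a transport reset event and on device removal.
 */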
static void virtio_vsock_reset_sock(struct sock *sk)
{
        lock_sock(sk);
        sk->sk_state = TCP_CLOSE;
        sk->sk_err = ECONNRESET;
        sk->sk_error_report(sk);
        release_sock(sk);
}

static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
{
        struct virtio_device *vdev = vsock->vdev;
        __le64 guest_cid;

        vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
                          &guest_cid, sizeof(guest_cid));
        vsock->guest_cid = le64_to_cpu(guest_cid);
}

/* event_lock must be held */
static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
                                      struct virtio_vsock_event *event)
{
        switch (le32_to_cpu(event->id)) {
        case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
                virtio_vsock_update_guest_cid(vsock);
                vsock_for_each_connected_socket(virtio_vsock_reset_sock);
                break;
        }
}

static void virtio_transport_event_work(struct work_struct *work)
{
        struct virtio_vsock *vsock =
                container_of(work, struct virtio_vsock, event_work);
        struct virtqueue *vq;

        vq = vsock->vqs[VSOCK_VQ_EVENT];

        mutex_lock(&vsock->event_lock);

        if (!vsock->event_run)
                goto out;

        do {
                struct virtio_vsock_event *event;
                unsigned int len;

                virtqueue_disable_cb(vq);
                while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
                        if (len == sizeof(*event))
                                virtio_vsock_event_handle(vsock, event);

                        virtio_vsock_event_fill_one(vsock, event);
                }
        } while (!virtqueue_enable_cb(vq));

        virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
out:
        mutex_unlock(&vsock->event_lock);
}

static void virtio_vsock_event_done(struct virtqueue *vq)
{
        struct virtio_vsock *vsock = vq->vdev->priv;

        if (!vsock)
                return;
        queue_work(virtio_vsock_workqueue, &vsock->event_work);
}

static void virtio_vsock_tx_done(struct virtqueue *vq)
{
        struct virtio_vsock *vsock = vq->vdev->priv;

        if (!vsock)
                return;
        queue_work(virtio_vsock_workqueue, &vsock->tx_work);
}

static void virtio_vsock_rx_done(struct virtqueue *vq)
{
        struct virtio_vsock *vsock = vq->vdev->priv;

        if (!vsock)
                return;
        queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

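/* vsock transport callbacks; apart from get_local_cid, cancel_pkt and
 * send_pkt, these are the common virtio transport implementations.
 */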
static struct virtio_transport virtio_transport = {
        .transport = {
                .get_local_cid            = virtio_transport_get_local_cid,

                .init                     = virtio_transport_do_socket_init,
                .destruct                 = virtio_transport_destruct,
                .release                  = virtio_transport_release,
                .connect                  = virtio_transport_connect,
                .shutdown                 = virtio_transport_shutdown,
                .cancel_pkt               = virtio_transport_cancel_pkt,

                .dgram_bind               = virtio_transport_dgram_bind,
                .dgram_dequeue            = virtio_transport_dgram_dequeue,
                .dgram_enqueue            = virtio_transport_dgram_enqueue,
                .dgram_allow              = virtio_transport_dgram_allow,

                .stream_dequeue           = virtio_transport_stream_dequeue,
                .stream_enqueue           = virtio_transport_stream_enqueue,
                .stream_has_data          = virtio_transport_stream_has_data,
                .stream_has_space         = virtio_transport_stream_has_space,
                .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
                .stream_is_active         = virtio_transport_stream_is_active,
                .stream_allow             = virtio_transport_stream_allow,

                .notify_poll_in           = virtio_transport_notify_poll_in,
                .notify_poll_out          = virtio_transport_notify_poll_out,
                .notify_recv_init         = virtio_transport_notify_recv_init,
                .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
                .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
                .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
                .notify_send_init         = virtio_transport_notify_send_init,
                .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
                .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
                .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

                .set_buffer_size          = virtio_transport_set_buffer_size,
                .set_min_buffer_size      = virtio_transport_set_min_buffer_size,
                .set_max_buffer_size      = virtio_transport_set_max_buffer_size,
                .get_buffer_size          = virtio_transport_get_buffer_size,
                .get_min_buffer_size      = virtio_transport_get_min_buffer_size,
                .get_max_buffer_size      = virtio_transport_get_max_buffer_size,
        },

        .send_pkt = virtio_transport_send_pkt,
};

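/* Set up the virtqueues, buffers and workers for the single supported
 * virtio-vsock device, then publish it via RCU.
 */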
static int virtio_vsock_probe(struct virtio_device *vdev)
{
        vq_callback_t *callbacks[] = {
                virtio_vsock_rx_done,
                virtio_vsock_tx_done,
                virtio_vsock_event_done,
        };
        static const char * const names[] = {
                "rx",
                "tx",
                "event",
        };
        struct virtio_vsock *vsock = NULL;
        int ret;

        ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
        if (ret)
                return ret;

        /* Only one virtio-vsock device per guest is supported */
        if (rcu_dereference_protected(the_virtio_vsock,
                                lockdep_is_held(&the_virtio_vsock_mutex))) {
                ret = -EBUSY;
                goto out;
        }

        vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
        if (!vsock) {
                ret = -ENOMEM;
                goto out;
        }

        vsock->vdev = vdev;

        ret = virtio_find_vqs(vsock->vdev, VSOCK_VQ_MAX,
                              vsock->vqs, callbacks, names,
                              NULL);
        if (ret < 0)
                goto out;

        virtio_vsock_update_guest_cid(vsock);

        vsock->rx_buf_nr = 0;
        vsock->rx_buf_max_nr = 0;
        atomic_set(&vsock->queued_replies, 0);

        mutex_init(&vsock->tx_lock);
        mutex_init(&vsock->rx_lock);
        mutex_init(&vsock->event_lock);
        spin_lock_init(&vsock->send_pkt_list_lock);
        INIT_LIST_HEAD(&vsock->send_pkt_list);
        spin_lock_init(&vsock->loopback_list_lock);
        INIT_LIST_HEAD(&vsock->loopback_list);
        INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
        INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
        INIT_WORK(&vsock->event_work, virtio_transport_event_work);
        INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
        INIT_WORK(&vsock->loopback_work, virtio_transport_loopback_work);

        mutex_lock(&vsock->tx_lock);
        vsock->tx_run = true;
        mutex_unlock(&vsock->tx_lock);

        mutex_lock(&vsock->rx_lock);
        virtio_vsock_rx_fill(vsock);
        vsock->rx_run = true;
        mutex_unlock(&vsock->rx_lock);

        mutex_lock(&vsock->event_lock);
        virtio_vsock_event_fill(vsock);
        vsock->event_run = true;
        mutex_unlock(&vsock->event_lock);

        vdev->priv = vsock;
        rcu_assign_pointer(the_virtio_vsock, vsock);

        mutex_unlock(&the_virtio_vsock_mutex);
        return 0;

out:
        kfree(vsock);
        mutex_unlock(&the_virtio_vsock_mutex);
        return ret;
}

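/* Tear down in roughly the reverse order of probe: unpublish the device,
 * stop the workers, reset the device and free any packets still queued.
 */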
static void virtio_vsock_remove(struct virtio_device *vdev)
{
        struct virtio_vsock *vsock = vdev->priv;
        struct virtio_vsock_pkt *pkt;

        mutex_lock(&the_virtio_vsock_mutex);

        vdev->priv = NULL;
        rcu_assign_pointer(the_virtio_vsock, NULL);
        synchronize_rcu();

        /* Reset all connected sockets when the device disappears */
        vsock_for_each_connected_socket(virtio_vsock_reset_sock);

        /* Stop all work handlers to make sure no one is accessing the device,
         * so we can safely call vdev->config->reset().
         */
        mutex_lock(&vsock->rx_lock);
        vsock->rx_run = false;
        mutex_unlock(&vsock->rx_lock);

        mutex_lock(&vsock->tx_lock);
        vsock->tx_run = false;
        mutex_unlock(&vsock->tx_lock);

        mutex_lock(&vsock->event_lock);
        vsock->event_run = false;
        mutex_unlock(&vsock->event_lock);

        /* Flush all device writes and interrupts, device will not use any
         * more buffers.
         */
        vdev->config->reset(vdev);

        mutex_lock(&vsock->rx_lock);
        while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
                virtio_transport_free_pkt(pkt);
        mutex_unlock(&vsock->rx_lock);

        mutex_lock(&vsock->tx_lock);
        while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
                virtio_transport_free_pkt(pkt);
        mutex_unlock(&vsock->tx_lock);

        spin_lock_bh(&vsock->send_pkt_list_lock);
        while (!list_empty(&vsock->send_pkt_list)) {
                pkt = list_first_entry(&vsock->send_pkt_list,
                                       struct virtio_vsock_pkt, list);
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->send_pkt_list_lock);

        spin_lock_bh(&vsock->loopback_list_lock);
        while (!list_empty(&vsock->loopback_list)) {
                pkt = list_first_entry(&vsock->loopback_list,
                                       struct virtio_vsock_pkt, list);
                list_del(&pkt->list);
                virtio_transport_free_pkt(pkt);
        }
        spin_unlock_bh(&vsock->loopback_list_lock);

        /* Delete virtqueues and flush outstanding callbacks if any */
        vdev->config->del_vqs(vdev);

        /* Other work items may have been queued before 'config->del_vqs()',
         * so flush all of them before freeing the vsock object to avoid a
         * use-after-free.
         */
        flush_work(&vsock->loopback_work);
        flush_work(&vsock->rx_work);
        flush_work(&vsock->tx_work);
        flush_work(&vsock->event_work);
        flush_work(&vsock->send_pkt_work);

        mutex_unlock(&the_virtio_vsock_mutex);

        kfree(vsock);
}

static struct virtio_device_id id_table[] = {
        { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
        { 0 },
};

static unsigned int features[] = {
};

static struct virtio_driver virtio_vsock_driver = {
        .feature_table = features,
        .feature_table_size = ARRAY_SIZE(features),
        .driver.name = KBUILD_MODNAME,
        .driver.owner = THIS_MODULE,
        .id_table = id_table,
        .probe = virtio_vsock_probe,
        .remove = virtio_vsock_remove,
};

static int __init virtio_vsock_init(void)
{
        int ret;

        virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
        if (!virtio_vsock_workqueue)
                return -ENOMEM;

        ret = vsock_core_init(&virtio_transport.transport);
        if (ret)
                goto out_wq;

        ret = register_virtio_driver(&virtio_vsock_driver);
        if (ret)
                goto out_vci;

        return 0;

out_vci:
        vsock_core_exit();
out_wq:
        destroy_workqueue(virtio_vsock_workqueue);
        return ret;
}

static void __exit virtio_vsock_exit(void)
{
        unregister_virtio_driver(&virtio_vsock_driver);
        vsock_core_exit();
        destroy_workqueue(virtio_vsock_workqueue);
}

module_init(virtio_vsock_init);
module_exit(virtio_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("virtio transport for vsock");
MODULE_DEVICE_TABLE(virtio, id_table);