linux/drivers/vhost/vhost.c
   1/* Copyright (C) 2009 Red Hat, Inc.
   2 * Copyright (C) 2006 Rusty Russell IBM Corporation
   3 *
   4 * Author: Michael S. Tsirkin <mst@redhat.com>
   5 *
   6 * Inspiration, some code, and most witty comments come from
   7 * Documentation/virtual/lguest/lguest.c, by Rusty Russell
   8 *
   9 * This work is licensed under the terms of the GNU GPL, version 2.
  10 *
  11 * Generic code for virtio server in host kernel.
  12 */
  13
  14#include <linux/eventfd.h>
  15#include <linux/vhost.h>
  16#include <linux/socket.h> /* memcpy_fromiovec */
  17#include <linux/mm.h>
  18#include <linux/mmu_context.h>
  19#include <linux/miscdevice.h>
  20#include <linux/mutex.h>
  21#include <linux/poll.h>
  22#include <linux/file.h>
  23#include <linux/highmem.h>
  24#include <linux/slab.h>
  25#include <linux/vmalloc.h>
  26#include <linux/kthread.h>
  27#include <linux/cgroup.h>
  28#include <linux/module.h>
  29#include <linux/sort.h>
  30#include <linux/interval_tree_generic.h>
  31#include <linux/nospec.h>
  32
  33#include "vhost.h"
  34
  35static ushort max_mem_regions = 64;
  36module_param(max_mem_regions, ushort, 0444);
  37MODULE_PARM_DESC(max_mem_regions,
  38        "Maximum number of memory regions in memory map. (default: 64)");
  39static int max_iotlb_entries = 2048;
  40module_param(max_iotlb_entries, int, 0444);
  41MODULE_PARM_DESC(max_iotlb_entries,
  42        "Maximum number of iotlb entries. (default: 2048)");
  43
  44enum {
  45        VHOST_MEMORY_F_LOG = 0x1,
  46};
  47
  48#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
  49#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
  50
  51INTERVAL_TREE_DEFINE(struct vhost_umem_node,
  52                     rb, __u64, __subtree_last,
  53                     START, LAST, static inline, vhost_umem_interval_tree);
  54
  55#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
  56static void vhost_vq_reset_user_be(struct vhost_virtqueue *vq)
  57{
  58        vq->user_be = !virtio_legacy_is_little_endian();
  59}
  60
  61static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
  62{
  63        struct vhost_vring_state s;
  64
  65        if (vq->private_data)
  66                return -EBUSY;
  67
  68        if (copy_from_user(&s, argp, sizeof(s)))
  69                return -EFAULT;
  70
  71        if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
  72            s.num != VHOST_VRING_BIG_ENDIAN)
  73                return -EINVAL;
  74
  75        vq->user_be = s.num;
  76
  77        return 0;
  78}
  79
  80static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
  81                                   int __user *argp)
  82{
  83        struct vhost_vring_state s = {
  84                .index = idx,
  85                .num = vq->user_be
  86        };
  87
  88        if (copy_to_user(argp, &s, sizeof(s)))
  89                return -EFAULT;
  90
  91        return 0;
  92}
  93
  94static void vhost_init_is_le(struct vhost_virtqueue *vq)
  95{
  96        /* Note for legacy virtio: user_be is initialized at reset time
  97         * according to the host endianness. If userspace does not set an
  98         * explicit endianness, the default behavior is native endian, as
  99         * expected by legacy virtio.
 100         */
 101        vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
 102}
 103#else
 104static void vhost_vq_reset_user_be(struct vhost_virtqueue *vq)
 105{
 106}
 107
 108static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
 109{
 110        return -ENOIOCTLCMD;
 111}
 112
 113static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
 114                                   int __user *argp)
 115{
 116        return -ENOIOCTLCMD;
 117}
 118
 119static void vhost_init_is_le(struct vhost_virtqueue *vq)
 120{
 121        if (vhost_has_feature(vq, VIRTIO_F_VERSION_1))
 122                vq->is_le = true;
 123}
 124#endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
 125
 126
 127struct vhost_flush_struct {
 128        struct vhost_work work;
 129        struct completion wait_event;
 130};
 131
 132static void vhost_flush_work(struct vhost_work *work)
 133{
 134        struct vhost_flush_struct *s;
 135
 136        s = container_of(work, struct vhost_flush_struct, work);
 137        complete(&s->wait_event);
 138}
 139
 140static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
 141                            poll_table *pt)
 142{
 143        struct vhost_poll *poll;
 144
 145        poll = container_of(pt, struct vhost_poll, table);
 146        poll->wqh = wqh;
 147        add_wait_queue(wqh, &poll->wait);
 148}
 149
 150static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
 151                             void *key)
 152{
 153        struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
 154
 155        if (!((unsigned long)key & poll->mask))
 156                return 0;
 157
 158        vhost_poll_queue(poll);
 159        return 0;
 160}
 161
 162void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
 163{
 164        clear_bit(VHOST_WORK_QUEUED, &work->flags);
 165        work->fn = fn;
 166        init_waitqueue_head(&work->done);
 167}
 168EXPORT_SYMBOL_GPL(vhost_work_init);
 169
 170/* Init poll structure */
 171void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
 172                     unsigned long mask, struct vhost_dev *dev)
 173{
 174        init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
 175        init_poll_funcptr(&poll->table, vhost_poll_func);
 176        poll->mask = mask;
 177        poll->dev = dev;
 178        poll->wqh = NULL;
 179
 180        vhost_work_init(&poll->work, fn);
 181}
 182EXPORT_SYMBOL_GPL(vhost_poll_init);
 183
 184/* Start polling a file. We add ourselves to file's wait queue. The caller must
 185 * keep a reference to a file until after vhost_poll_stop is called. */
 186int vhost_poll_start(struct vhost_poll *poll, struct file *file)
 187{
 188        unsigned long mask;
 189        int ret = 0;
 190
 191        if (poll->wqh)
 192                return 0;
 193
 194        mask = file->f_op->poll(file, &poll->table);
 195        if (mask)
 196                vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask);
 197        if (mask & POLLERR) {
 198                vhost_poll_stop(poll);
 199                ret = -EINVAL;
 200        }
 201
 202        return ret;
 203}
 204EXPORT_SYMBOL_GPL(vhost_poll_start);
 205
 206/* Stop polling a file. After this function returns, it becomes safe to drop the
 207 * file reference. You must also flush afterwards. */
 208void vhost_poll_stop(struct vhost_poll *poll)
 209{
 210        if (poll->wqh) {
 211                remove_wait_queue(poll->wqh, &poll->wait);
 212                poll->wqh = NULL;
 213        }
 214}
 215EXPORT_SYMBOL_GPL(vhost_poll_stop);
 216
 217void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
 218{
 219        struct vhost_flush_struct flush;
 220
 221        if (dev->worker) {
 222                init_completion(&flush.wait_event);
 223                vhost_work_init(&flush.work, vhost_flush_work);
 224
 225                vhost_work_queue(dev, &flush.work);
 226                wait_for_completion(&flush.wait_event);
 227        }
 228}
 229EXPORT_SYMBOL_GPL(vhost_work_flush);
 230
 231/* Flush any work that has been scheduled. When calling this, don't hold any
 232 * locks that are also used by the callback. */
 233void vhost_poll_flush(struct vhost_poll *poll)
 234{
 235        vhost_work_flush(poll->dev, &poll->work);
 236}
 237EXPORT_SYMBOL_GPL(vhost_poll_flush);
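
/* Editor's note (not part of the original file): a minimal, hypothetical
 * sketch of the vhost_poll lifecycle documented above. The my_backend_*
 * names and the POLLIN mask choice are illustrative assumptions; only the
 * vhost_poll_* calls and their ordering rules (keep the file reference until
 * vhost_poll_stop(), flush afterwards) come from this file.
 */
static void my_backend_handle_event(struct vhost_work *work)
{
        /* runs on the device's worker thread when the file becomes readable */
}

static int my_backend_attach(struct vhost_dev *dev, struct vhost_poll *poll,
                             struct file *file)
{
        vhost_poll_init(poll, my_backend_handle_event, POLLIN, dev);
        return vhost_poll_start(poll, file);    /* caller keeps the file ref */
}

static void my_backend_detach(struct vhost_poll *poll)
{
        vhost_poll_stop(poll);  /* now it is safe to drop the file reference */
        vhost_poll_flush(poll); /* wait for any queued work to finish */
}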
 238
 239void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
 240{
 241        if (!dev->worker)
 242                return;
 243
 244        if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
 245                /* We can only add the work to the list after we're
 246                 * sure it was not in the list.
 247                 */
 248                smp_mb();
 249                llist_add(&work->node, &dev->work_list);
 250                wake_up_process(dev->worker);
 251        }
 252}
 253EXPORT_SYMBOL_GPL(vhost_work_queue);
 254
 255/* A lockless hint for busy polling code to exit the loop */
 256bool vhost_has_work(struct vhost_dev *dev)
 257{
 258        return !llist_empty(&dev->work_list);
 259}
 260EXPORT_SYMBOL_GPL(vhost_has_work);
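
/* Editor's note (not part of the original file): a hedged sketch of using the
 * work-queue primitives above directly, outside of struct vhost_poll. The
 * my_backend_* names are assumptions; vhost_work_init(), vhost_work_queue(),
 * vhost_work_flush() and vhost_has_work() are the helpers defined above.
 */
static void my_backend_do_work(struct vhost_work *work)
{
        /* executed once on dev->worker; VHOST_WORK_QUEUED is already cleared */
}

static void my_backend_run_sync(struct vhost_dev *dev, struct vhost_work *work)
{
        vhost_work_init(work, my_backend_do_work);
        vhost_work_queue(dev, work);    /* dropped silently if no worker yet */
        vhost_work_flush(dev, work);    /* worker runs items in FIFO order, so
                                         * this returns only after
                                         * my_backend_do_work() has completed */
        /* busy-polling code may use vhost_has_work(dev) as an exit hint */
}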
 261
 262void vhost_poll_queue(struct vhost_poll *poll)
 263{
 264        vhost_work_queue(poll->dev, &poll->work);
 265}
 266EXPORT_SYMBOL_GPL(vhost_poll_queue);
 267
 268static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
 269{
 270        int j;
 271
 272        for (j = 0; j < VHOST_NUM_ADDRS; j++)
 273                vq->meta_iotlb[j] = NULL;
 274}
 275
 276static void vhost_vq_meta_reset(struct vhost_dev *d)
 277{
 278        int i;
 279
 280        for (i = 0; i < d->nvqs; ++i)
 281                __vhost_vq_meta_reset(d->vqs[i]);
 282}
 283
 284static void vhost_vq_reset(struct vhost_dev *dev,
 285                           struct vhost_virtqueue *vq)
 286{
 287        vq->num = 1;
 288        vq->desc = NULL;
 289        vq->avail = NULL;
 290        vq->used = NULL;
 291        vq->last_avail_idx = 0;
 292        vq->avail_idx = 0;
 293        vq->last_used_idx = 0;
 294        vq->signalled_used = 0;
 295        vq->signalled_used_valid = false;
 296        vq->used_flags = 0;
 297        vq->log_used = false;
 298        vq->log_addr = -1ull;
 299        vq->private_data = NULL;
 300        vq->acked_features = 0;
 301        vq->log_base = NULL;
 302        vq->error_ctx = NULL;
 303        vq->error = NULL;
 304        vq->kick = NULL;
 305        vq->call_ctx = NULL;
 306        vq->call = NULL;
 307        vq->log_ctx = NULL;
 308        vq->is_le = virtio_legacy_is_little_endian();
 309        vhost_vq_reset_user_be(vq);
 310        vq->busyloop_timeout = 0;
 311        vq->umem = NULL;
 312        vq->iotlb = NULL;
 313        __vhost_vq_meta_reset(vq);
 314}
 315
 316static int vhost_worker(void *data)
 317{
 318        struct vhost_dev *dev = data;
 319        struct vhost_work *work, *work_next;
 320        struct llist_node *node;
 321        mm_segment_t oldfs = get_fs();
 322
 323        set_fs(USER_DS);
 324        use_mm(dev->mm);
 325
 326        for (;;) {
 327                /* mb paired w/ kthread_stop */
 328                set_current_state(TASK_INTERRUPTIBLE);
 329
 330                if (kthread_should_stop()) {
 331                        __set_current_state(TASK_RUNNING);
 332                        break;
 333                }
 334
 335                node = llist_del_all(&dev->work_list);
 336                if (!node)
 337                        schedule();
 338
 339                node = llist_reverse_order(node);
 340                /* make sure flag is seen after deletion */
 341                smp_wmb();
 342                llist_for_each_entry_safe(work, work_next, node, node) {
 343                        clear_bit(VHOST_WORK_QUEUED, &work->flags);
 344                        __set_current_state(TASK_RUNNING);
 345                        work->fn(work);
 346                        if (need_resched())
 347                                schedule();
 348                }
 349        }
 350        unuse_mm(dev->mm);
 351        set_fs(oldfs);
 352        return 0;
 353}
 354
 355static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 356{
 357        kfree(vq->indirect);
 358        vq->indirect = NULL;
 359        kfree(vq->log);
 360        vq->log = NULL;
 361        kfree(vq->heads);
 362        vq->heads = NULL;
 363}
 364
 365/* Helper to allocate iovec buffers for all vqs. */
 366static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
 367{
 368        int i;
 369
 370        for (i = 0; i < dev->nvqs; ++i) {
 371                dev->vqs[i]->indirect = kmalloc(sizeof *dev->vqs[i]->indirect *
 372                                               dev->iov_limit, GFP_KERNEL);
 373                dev->vqs[i]->log = kmalloc(sizeof *dev->vqs[i]->log *
 374                                           dev->iov_limit, GFP_KERNEL);
 375                dev->vqs[i]->heads = kmalloc(sizeof *dev->vqs[i]->heads *
 376                                            dev->iov_limit, GFP_KERNEL);
 377                if (!dev->vqs[i]->indirect || !dev->vqs[i]->log ||
 378                        !dev->vqs[i]->heads)
 379                        goto err_nomem;
 380        }
 381        return 0;
 382
 383err_nomem:
 384        for (; i >= 0; --i)
 385                vhost_vq_free_iovecs(dev->vqs[i]);
 386        return -ENOMEM;
 387}
 388
 389static void vhost_dev_free_iovecs(struct vhost_dev *dev)
 390{
 391        int i;
 392
 393        for (i = 0; i < dev->nvqs; ++i)
 394                vhost_vq_free_iovecs(dev->vqs[i]);
 395}
 396
 397bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
 398                          int pkts, int total_len)
 399{
 400        struct vhost_dev *dev = vq->dev;
 401
 402        if ((dev->byte_weight && total_len >= dev->byte_weight) ||
 403            pkts >= dev->weight) {
 404                vhost_poll_queue(&vq->poll);
 405                return true;
 406        }
 407
 408        return false;
 409}
 410EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
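
/* Editor's note (not part of the original file): a hypothetical handler loop
 * showing the intended use of vhost_exceeds_weight(). Once either limit is
 * hit, the call has already requeued vq->poll, so the handler simply returns
 * and the remaining buffers are processed on a later worker iteration. The
 * descriptor handling itself is elided.
 */
static void example_handle_vq(struct vhost_virtqueue *vq)
{
        int pkts = 0, total_len = 0;

        do {
                int len = 0;    /* hypothetical: bytes handled for one buffer */

                /* ... fetch and process one descriptor chain here,
                 * breaking out when the avail ring is empty ...
                 */
                total_len += len;
        } while (!vhost_exceeds_weight(vq, ++pkts, total_len));
}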
 411
 412long vhost_dev_init(struct vhost_dev *dev,
 413                    struct vhost_virtqueue **vqs, int nvqs,
 414                    int iov_limit, int weight, int byte_weight)
 415{
 416        int i;
 417
 418        dev->vqs = vqs;
 419        dev->nvqs = nvqs;
 420        mutex_init(&dev->mutex);
 421        dev->log_ctx = NULL;
 422        dev->log_file = NULL;
 423        dev->umem = NULL;
 424        dev->iotlb = NULL;
 425        dev->mm = NULL;
 426        dev->worker = NULL;
 427        dev->iov_limit = iov_limit;
 428        dev->weight = weight;
 429        dev->byte_weight = byte_weight;
 430        init_llist_head(&dev->work_list);
 431        init_waitqueue_head(&dev->wait);
 432        INIT_LIST_HEAD(&dev->read_list);
 433        INIT_LIST_HEAD(&dev->pending_list);
 434        spin_lock_init(&dev->iotlb_lock);
 435
 436        for (i = 0; i < dev->nvqs; ++i) {
 437                dev->vqs[i]->log = NULL;
 438                dev->vqs[i]->indirect = NULL;
 439                dev->vqs[i]->heads = NULL;
 440                dev->vqs[i]->dev = dev;
 441                mutex_init(&dev->vqs[i]->mutex);
 442                vhost_vq_reset(dev, dev->vqs[i]);
 443                if (dev->vqs[i]->handle_kick)
 444                        vhost_poll_init(&dev->vqs[i]->poll,
 445                                        dev->vqs[i]->handle_kick, POLLIN, dev);
 446        }
 447
 448        return 0;
 449}
 450EXPORT_SYMBOL_GPL(vhost_dev_init);
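
/* Editor's note (not part of the original file): a hedged sketch of device
 * setup as a backend might do it. struct example_backend, the weight values
 * and the iov_limit are illustrative assumptions; real backends choose their
 * own limits and install a vq->handle_kick handler before calling
 * vhost_dev_init() so that the kick poll gets initialized above.
 */
struct example_backend {
        struct vhost_dev dev;
        struct vhost_virtqueue vq;
};

static long example_backend_setup(struct example_backend *b,
                                  struct vhost_virtqueue **vqs)
{
        vqs[0] = &b->vq;
        b->vq.handle_kick = NULL;       /* hypothetical: a real backend installs
                                         * its kick handler here */
        return vhost_dev_init(&b->dev, vqs, 1,
                              1024 /* iov_limit, e.g. UIO_MAXIOV */,
                              64 /* weight: max buffers per worker pass */,
                              64 * 65536 /* byte_weight */);
}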
 451
 452/* Caller should have device mutex */
 453long vhost_dev_check_owner(struct vhost_dev *dev)
 454{
 455        /* Are you the owner? If not, I don't think you mean to do that */
 456        return dev->mm == current->mm ? 0 : -EPERM;
 457}
 458EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
 459
 460struct vhost_attach_cgroups_struct {
 461        struct vhost_work work;
 462        struct task_struct *owner;
 463        int ret;
 464};
 465
 466static void vhost_attach_cgroups_work(struct vhost_work *work)
 467{
 468        struct vhost_attach_cgroups_struct *s;
 469
 470        s = container_of(work, struct vhost_attach_cgroups_struct, work);
 471        s->ret = cgroup_attach_task_all(s->owner, current);
 472}
 473
 474static int vhost_attach_cgroups(struct vhost_dev *dev)
 475{
 476        struct vhost_attach_cgroups_struct attach;
 477
 478        attach.owner = current;
 479        vhost_work_init(&attach.work, vhost_attach_cgroups_work);
 480        vhost_work_queue(dev, &attach.work);
 481        vhost_work_flush(dev, &attach.work);
 482        return attach.ret;
 483}
 484
 485/* Caller should have device mutex */
 486bool vhost_dev_has_owner(struct vhost_dev *dev)
 487{
 488        return dev->mm;
 489}
 490EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
 491
 492/* Caller should have device mutex */
 493long vhost_dev_set_owner(struct vhost_dev *dev)
 494{
 495        struct task_struct *worker;
 496        int err;
 497
 498        /* Is there an owner already? */
 499        if (vhost_dev_has_owner(dev)) {
 500                err = -EBUSY;
 501                goto err_mm;
 502        }
 503
 504        /* No owner, become one */
 505        dev->mm = get_task_mm(current);
 506        worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
 507        if (IS_ERR(worker)) {
 508                err = PTR_ERR(worker);
 509                goto err_worker;
 510        }
 511
 512        dev->worker = worker;
 513        wake_up_process(worker);        /* avoid contributing to loadavg */
 514
 515        err = vhost_attach_cgroups(dev);
 516        if (err)
 517                goto err_cgroup;
 518
 519        err = vhost_dev_alloc_iovecs(dev);
 520        if (err)
 521                goto err_cgroup;
 522
 523        return 0;
 524err_cgroup:
 525        kthread_stop(worker);
 526        dev->worker = NULL;
 527err_worker:
 528        if (dev->mm)
 529                mmput(dev->mm);
 530        dev->mm = NULL;
 531err_mm:
 532        return err;
 533}
 534EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
 535
 536struct vhost_umem *vhost_dev_reset_owner_prepare(void)
 537{
 538        return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL);
 539}
 540EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 541
 542/* Caller should have device mutex */
 543void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
 544{
 545        int i;
 546
 547        vhost_dev_cleanup(dev, true);
 548
 549        /* Restore memory to default empty mapping. */
 550        INIT_LIST_HEAD(&umem->umem_list);
 551        dev->umem = umem;
 552        /* We don't need VQ locks below since vhost_dev_cleanup makes sure
 553         * VQs aren't running.
 554         */
 555        for (i = 0; i < dev->nvqs; ++i)
 556                dev->vqs[i]->umem = umem;
 557}
 558EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 559
 560void vhost_dev_stop(struct vhost_dev *dev)
 561{
 562        int i;
 563
 564        for (i = 0; i < dev->nvqs; ++i) {
 565                if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
 566                        vhost_poll_stop(&dev->vqs[i]->poll);
 567                        vhost_poll_flush(&dev->vqs[i]->poll);
 568                }
 569        }
 570}
 571EXPORT_SYMBOL_GPL(vhost_dev_stop);
 572
 573static void vhost_umem_free(struct vhost_umem *umem,
 574                            struct vhost_umem_node *node)
 575{
 576        vhost_umem_interval_tree_remove(node, &umem->umem_tree);
 577        list_del(&node->link);
 578        kfree(node);
 579        umem->numem--;
 580}
 581
 582static void vhost_umem_clean(struct vhost_umem *umem)
 583{
 584        struct vhost_umem_node *node, *tmp;
 585
 586        if (!umem)
 587                return;
 588
 589        list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
 590                vhost_umem_free(umem, node);
 591
 592        kvfree(umem);
 593}
 594
 595static void vhost_clear_msg(struct vhost_dev *dev)
 596{
 597        struct vhost_msg_node *node, *n;
 598
 599        spin_lock(&dev->iotlb_lock);
 600
 601        list_for_each_entry_safe(node, n, &dev->read_list, node) {
 602                list_del(&node->node);
 603                kfree(node);
 604        }
 605
 606        list_for_each_entry_safe(node, n, &dev->pending_list, node) {
 607                list_del(&node->node);
 608                kfree(node);
 609        }
 610
 611        spin_unlock(&dev->iotlb_lock);
 612}
 613
 614/* Caller should have device mutex if and only if locked is set */
 615void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
 616{
 617        int i;
 618
 619        for (i = 0; i < dev->nvqs; ++i) {
 620                if (dev->vqs[i]->error_ctx)
 621                        eventfd_ctx_put(dev->vqs[i]->error_ctx);
 622                if (dev->vqs[i]->error)
 623                        fput(dev->vqs[i]->error);
 624                if (dev->vqs[i]->kick)
 625                        fput(dev->vqs[i]->kick);
 626                if (dev->vqs[i]->call_ctx)
 627                        eventfd_ctx_put(dev->vqs[i]->call_ctx);
 628                if (dev->vqs[i]->call)
 629                        fput(dev->vqs[i]->call);
 630                vhost_vq_reset(dev, dev->vqs[i]);
 631        }
 632        vhost_dev_free_iovecs(dev);
 633        if (dev->log_ctx)
 634                eventfd_ctx_put(dev->log_ctx);
 635        dev->log_ctx = NULL;
 636        if (dev->log_file)
 637                fput(dev->log_file);
 638        dev->log_file = NULL;
 639        /* No one will access memory at this point */
 640        vhost_umem_clean(dev->umem);
 641        dev->umem = NULL;
 642        vhost_umem_clean(dev->iotlb);
 643        dev->iotlb = NULL;
 644        vhost_clear_msg(dev);
 645        wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
 646        WARN_ON(!llist_empty(&dev->work_list));
 647        if (dev->worker) {
 648                kthread_stop(dev->worker);
 649                dev->worker = NULL;
 650        }
 651        if (dev->mm)
 652                mmput(dev->mm);
 653        dev->mm = NULL;
 654}
 655EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
 656
 657static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
 658{
 659        u64 a = addr / VHOST_PAGE_SIZE / 8;
 660
 661        /* Make sure 64 bit math will not overflow. */
 662        if (a > ULONG_MAX - (unsigned long)log_base ||
 663            a + (unsigned long)log_base > ULONG_MAX)
 664                return false;
 665
 666        return access_ok(VERIFY_WRITE, log_base + a,
 667                         (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
 668}
 669
 670static bool vhost_overflow(u64 uaddr, u64 size)
 671{
 672        /* Make sure 64 bit math will not overflow. */
 673        return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
 674}
 675
 676/* Caller should have vq mutex and device mutex. */
 677static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
 678                                int log_all)
 679{
 680        struct vhost_umem_node *node;
 681
 682        if (!umem)
 683                return false;
 684
 685        list_for_each_entry(node, &umem->umem_list, link) {
 686                unsigned long a = node->userspace_addr;
 687
 688                if (vhost_overflow(node->userspace_addr, node->size))
 689                        return false;
 690
 691
 692                if (!access_ok(VERIFY_WRITE, (void __user *)a,
 693                                    node->size))
 694                        return false;
 695                else if (log_all && !log_access_ok(log_base,
 696                                                   node->start,
 697                                                   node->size))
 698                        return false;
 699        }
 700        return true;
 701}
 702
 703static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
 704                                               u64 addr, unsigned int size,
 705                                               int type)
 706{
 707        const struct vhost_umem_node *node = vq->meta_iotlb[type];
 708
 709        if (!node)
 710                return NULL;
 711
 712        return (void *)(uintptr_t)(node->userspace_addr + addr - node->start);
 713}
 714
 715/* Can we switch to this memory table? */
 716/* Caller should have device mutex but not vq mutex */
 717static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 718                             int log_all)
 719{
 720        int i;
 721
 722        for (i = 0; i < d->nvqs; ++i) {
 723                bool ok;
 724                bool log;
 725
 726                mutex_lock(&d->vqs[i]->mutex);
 727                log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 728                /* If ring is inactive, will check when it's enabled. */
 729                if (d->vqs[i]->private_data)
 730                        ok = vq_memory_access_ok(d->vqs[i]->log_base,
 731                                                 umem, log);
 732                else
 733                        ok = true;
 734                mutex_unlock(&d->vqs[i]->mutex);
 735                if (!ok)
 736                        return false;
 737        }
 738        return true;
 739}
 740
 741static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 742                          struct iovec iov[], int iov_size, int access);
 743
 744static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
 745                              const void *from, unsigned size)
 746{
 747        int ret;
 748
 749        if (!vq->iotlb)
 750                return __copy_to_user(to, from, size);
 751        else {
  752                /* This function should be called after iotlb
  753                 * prefetch, which means we're sure that all vq
  754                 * memory can be accessed through the iotlb, so
  755                 * -EAGAIN should not happen in this case.
  756                 */
 757                struct iov_iter t;
 758                void __user *uaddr = vhost_vq_meta_fetch(vq,
 759                                     (u64)(uintptr_t)to, size,
 760                                     VHOST_ADDR_USED);
 761
 762                if (uaddr)
 763                        return __copy_to_user(uaddr, from, size);
 764
 765                ret = translate_desc(vq, (u64)(uintptr_t)to, size,
 766                                     vq->iotlb_iov, ARRAY_SIZE(vq->iotlb_iov),
 767                                     VHOST_ACCESS_WO);
 768                if (ret < 0)
 769                        goto out;
 770                iov_iter_init(&t, (const struct iovec *)vq->iotlb_iov,
 771                              ret, size, 0);
 772                ret = memcpy_toiovecend(t.iov, (unsigned char *)from, 0, size);
 773        }
 774out:
 775        return ret;
 776}
 777
 778static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
 779                                void __user *from, unsigned size)
 780{
 781        int ret;
 782
 783        if (!vq->iotlb)
 784                return __copy_from_user(to, from, size);
 785        else {
  786                /* This function should be called after iotlb
  787                 * prefetch, which means we're sure that the vq
  788                 * memory can be accessed through the iotlb, so
  789                 * -EAGAIN should not happen in this case.
  790                 */
 791                void __user *uaddr = vhost_vq_meta_fetch(vq,
 792                                     (u64)(uintptr_t)from, size,
 793                                     VHOST_ADDR_DESC);
 794                struct iov_iter f;
 795
 796                if (uaddr)
 797                        return __copy_from_user(to, uaddr, size);
 798
 799                ret = translate_desc(vq, (u64)(uintptr_t)from, size,
 800                                     vq->iotlb_iov, ARRAY_SIZE(vq->iotlb_iov),
 801                                     VHOST_ACCESS_RO);
 802                if (ret < 0) {
 803                        vq_err(vq, "IOTLB translation failure: uaddr "
 804                               "%p size 0x%llx\n", from,
 805                               (unsigned long long) size);
 806                        goto out;
 807                }
 808                iov_iter_init(&f, (const struct iovec *)vq->iotlb_iov,
 809                              ret, size, 0);
 810                ret = memcpy_fromiovecend((unsigned char *)to, f.iov, 0, size);
 811        }
 812
 813out:
 814        return ret;
 815}
 816
 817static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
 818                                          void __user *addr, unsigned int size,
 819                                          int type)
 820{
 821        int ret;
 822
 823        ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
 824                             ARRAY_SIZE(vq->iotlb_iov),
 825                             VHOST_ACCESS_RO);
 826        if (ret < 0) {
 827                vq_err(vq, "IOTLB translation failure: uaddr "
 828                        "%p size 0x%llx\n", addr,
 829                        (unsigned long long) size);
 830                return NULL;
 831        }
 832
 833        if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
 834                vq_err(vq, "Non atomic userspace memory access: uaddr "
 835                        "%p size 0x%llx\n", addr,
 836                        (unsigned long long) size);
 837                return NULL;
 838        }
 839
 840        return vq->iotlb_iov[0].iov_base;
 841}
 842
  843/* This function should be called after iotlb
  844 * prefetch, which means we're sure that the vq
  845 * memory can be accessed through the iotlb, so
  846 * -EAGAIN should not happen in this case.
  847 */
 848static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
 849                                            void *addr, unsigned int size,
 850                                            int type)
 851{
 852        void __user *uaddr = vhost_vq_meta_fetch(vq,
 853                             (u64)(uintptr_t)addr, size, type);
 854        if (uaddr)
 855                return uaddr;
 856
 857        return __vhost_get_user_slow(vq, addr, size, type);
 858}
 859
 860#define vhost_put_user(vq, x, ptr)              \
 861({ \
 862        int ret = -EFAULT; \
 863        if (!vq->iotlb) { \
 864                ret = __put_user(x, ptr); \
 865        } else { \
 866                __typeof__(ptr) to = \
 867                        (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
 868                                          sizeof(*ptr), VHOST_ADDR_USED); \
 869                if (to != NULL) \
 870                        ret = __put_user(x, to); \
 871                else \
 872                        ret = -EFAULT;  \
 873        } \
 874        ret; \
 875})
 876
 877#define vhost_get_user(vq, x, ptr, type)                \
 878({ \
 879        int ret; \
 880        if (!vq->iotlb) { \
 881                ret = __get_user(x, ptr); \
 882        } else { \
 883                __typeof__(ptr) from = \
 884                        (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
 885                                                           sizeof(*ptr), \
 886                                                           type); \
 887                if (from != NULL) \
 888                        ret = __get_user(x, from); \
 889                else \
 890                        ret = -EFAULT; \
 891        } \
 892        ret; \
 893})
 894
 895#define vhost_get_avail(vq, x, ptr) \
 896        vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)
 897
 898#define vhost_get_used(vq, x, ptr) \
 899        vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
 900
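
/* Editor's note (not part of the original file): a hedged illustration of the
 * accessor macros above, mirroring how later ring-handling code reads the
 * avail index and publishes the used index. The example_* names are not from
 * this file; the endian helpers (vhost16_to_cpu/cpu_to_vhost16) come from
 * vhost.h.
 */
static inline int example_fetch_avail_idx(struct vhost_virtqueue *vq, u16 *idx)
{
        __virtio16 avail_idx;

        /* plain __get_user() without an iotlb, meta-cached translation otherwise */
        if (vhost_get_avail(vq, avail_idx, &vq->avail->idx))
                return -EFAULT;
        *idx = vhost16_to_cpu(vq, avail_idx);
        return 0;
}

static inline int example_publish_used_idx(struct vhost_virtqueue *vq)
{
        /* targets VHOST_ADDR_USED, so the cached used-ring translation applies */
        return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
                              &vq->used->idx);
}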
 901static void vhost_dev_lock_vqs(struct vhost_dev *d)
 902{
 903        int i = 0;
 904        for (i = 0; i < d->nvqs; ++i)
 905                mutex_lock_nested(&d->vqs[i]->mutex, i);
 906}
 907
 908static void vhost_dev_unlock_vqs(struct vhost_dev *d)
 909{
 910        int i = 0;
 911        for (i = 0; i < d->nvqs; ++i)
 912                mutex_unlock(&d->vqs[i]->mutex);
 913}
 914
 915static int vhost_new_umem_range(struct vhost_umem *umem,
 916                                u64 start, u64 size, u64 end,
 917                                u64 userspace_addr, int perm)
 918{
 919        struct vhost_umem_node *tmp, *node;
 920
 921        if (!size)
 922                return -EFAULT;
 923
 924        node = kmalloc(sizeof(*node), GFP_ATOMIC);
 925        if (!node)
 926                return -ENOMEM;
 927
 928        if (umem->numem == max_iotlb_entries) {
 929                tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
 930                vhost_umem_free(umem, tmp);
 931        }
 932
 933        node->start = start;
 934        node->size = size;
 935        node->last = end;
 936        node->userspace_addr = userspace_addr;
 937        node->perm = perm;
 938        INIT_LIST_HEAD(&node->link);
 939        list_add_tail(&node->link, &umem->umem_list);
 940        vhost_umem_interval_tree_insert(node, &umem->umem_tree);
 941        umem->numem++;
 942
 943        return 0;
 944}
 945
 946static void vhost_del_umem_range(struct vhost_umem *umem,
 947                                 u64 start, u64 end)
 948{
 949        struct vhost_umem_node *node;
 950
 951        while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
 952                                                           start, end)))
 953                vhost_umem_free(umem, node);
 954}
 955
 956static void vhost_iotlb_notify_vq(struct vhost_dev *d,
 957                                  struct vhost_iotlb_msg *msg)
 958{
 959        struct vhost_msg_node *node, *n;
 960
 961        spin_lock(&d->iotlb_lock);
 962
 963        list_for_each_entry_safe(node, n, &d->pending_list, node) {
 964                struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
 965                if (msg->iova <= vq_msg->iova &&
 966                    msg->iova + msg->size - 1 > vq_msg->iova &&
 967                    vq_msg->type == VHOST_IOTLB_MISS) {
 968                        vhost_poll_queue(&node->vq->poll);
 969                        list_del(&node->node);
 970                        kfree(node);
 971                }
 972        }
 973
 974        spin_unlock(&d->iotlb_lock);
 975}
 976
 977static bool umem_access_ok(u64 uaddr, u64 size, int access)
 978{
 979        unsigned long a = uaddr;
 980
 981        /* Make sure 64 bit math will not overflow. */
 982        if (vhost_overflow(uaddr, size))
 983                return false;
 984
 985        if ((access & VHOST_ACCESS_RO) &&
 986            !access_ok(VERIFY_READ, (void __user *)a, size))
 987                return false;
 988        if ((access & VHOST_ACCESS_WO) &&
 989            !access_ok(VERIFY_WRITE, (void __user *)a, size))
 990                return false;
 991        return true;
 992}
 993
 994static int vhost_process_iotlb_msg(struct vhost_dev *dev,
 995                                   struct vhost_iotlb_msg *msg)
 996{
 997        int ret = 0;
 998
 999        vhost_dev_lock_vqs(dev);
1000        switch (msg->type) {
1001        case VHOST_IOTLB_UPDATE:
1002                if (!dev->iotlb) {
1003                        ret = -EFAULT;
1004                        break;
1005                }
1006                if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
1007                        ret = -EFAULT;
1008                        break;
1009                }
1010                vhost_vq_meta_reset(dev);
1011                if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
1012                                         msg->iova + msg->size - 1,
1013                                         msg->uaddr, msg->perm)) {
1014                        ret = -ENOMEM;
1015                        break;
1016                }
1017                vhost_iotlb_notify_vq(dev, msg);
1018                break;
1019        case VHOST_IOTLB_INVALIDATE:
1020                if (!dev->iotlb) {
1021                        ret = -EFAULT;
1022                        break;
1023                }
1024                vhost_vq_meta_reset(dev);
1025                vhost_del_umem_range(dev->iotlb, msg->iova,
1026                                     msg->iova + msg->size - 1);
1027                break;
1028        default:
1029                ret = -EINVAL;
1030                break;
1031        }
1032
1033        vhost_dev_unlock_vqs(dev);
1034        return ret;
1035}
1036
1037ssize_t vhost_chr_write_iter(struct vhost_dev *dev, const struct iovec *from)
1038{
1039        struct vhost_msg_node node;
1040        unsigned size = sizeof(struct vhost_msg);
1041        size_t ret;
1042        int err;
1043
1044        ret = memcpy_fromiovecend((unsigned char *)&node.msg, from, 0, size);
1045        if (ret)
1046                goto done;
1047
1048        switch (node.msg.type) {
1049        case VHOST_IOTLB_MSG:
1050                err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
1051                if (err)
1052                        ret = err;
1053                break;
1054        default:
1055                ret = -EINVAL;
1056                break;
1057        }
1058
1059done:
1060        return ret ? ret : size;
1061}
1062EXPORT_SYMBOL(vhost_chr_write_iter);
1063
1064unsigned int vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1065                            poll_table *wait)
1066{
1067        unsigned int mask = 0;
1068
1069        poll_wait(file, &dev->wait, wait);
1070
1071        if (!list_empty(&dev->read_list))
1072                mask |= POLLIN | POLLRDNORM;
1073
1074        return mask;
1075}
1076EXPORT_SYMBOL(vhost_chr_poll);
1077
1078ssize_t vhost_chr_read_iter(struct vhost_dev *dev, const struct iovec *to,
1079                            int noblock)
1080{
1081        DEFINE_WAIT(wait);
1082        struct vhost_msg_node *node;
1083        ssize_t ret = 0;
1084        unsigned size = sizeof(struct vhost_msg);
1085
1086        while (1) {
1087                if (!noblock)
1088                        prepare_to_wait(&dev->wait, &wait,
1089                                        TASK_INTERRUPTIBLE);
1090
1091                node = vhost_dequeue_msg(dev, &dev->read_list);
1092                if (node)
1093                        break;
1094                if (noblock) {
1095                        ret = -EAGAIN;
1096                        break;
1097                }
1098                if (signal_pending(current)) {
1099                        ret = -ERESTARTSYS;
1100                        break;
1101                }
1102                if (!dev->iotlb) {
1103                        ret = -EBADFD;
1104                        break;
1105                }
1106
1107                schedule();
1108        }
1109
1110        if (!noblock)
1111                finish_wait(&dev->wait, &wait);
1112
1113        if (node) {
1114                ret = memcpy_toiovecend(to, (unsigned char *)&node->msg,
1115                                        0, size);
1116
1117                if (ret || node->msg.type != VHOST_IOTLB_MISS) {
1118                        kfree(node);
1119                        return ret;
1120                }
1121
1122                vhost_enqueue_msg(dev, &dev->pending_list, node);
1123        }
1124
1125        return ret ? ret : size;
1126}
1127EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1128
1129static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1130{
1131        struct vhost_dev *dev = vq->dev;
1132        struct vhost_msg_node *node;
1133        struct vhost_iotlb_msg *msg;
1134
1135        node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
1136        if (!node)
1137                return -ENOMEM;
1138
1139        msg = &node->msg.iotlb;
1140        msg->type = VHOST_IOTLB_MISS;
1141        msg->iova = iova;
1142        msg->perm = access;
1143
1144        vhost_enqueue_msg(dev, &dev->read_list, node);
1145
1146        return 0;
1147}
1148
1149static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1150                         struct vring_desc __user *desc,
1151                         struct vring_avail __user *avail,
1152                         struct vring_used __user *used)
1153
1154{
1155        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1156
1157        return access_ok(VERIFY_READ, desc, num * sizeof(*desc)) &&
1158               access_ok(VERIFY_READ, avail,
1159                         sizeof(*avail) + num * sizeof(*avail->ring) + s) &&
1160               access_ok(VERIFY_WRITE, used,
1161                        sizeof *used + num * sizeof *used->ring + s);
1162}
1163
1164static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1165                                 const struct vhost_umem_node *node,
1166                                 int type)
1167{
1168        int access = (type == VHOST_ADDR_USED) ?
1169                     VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1170
1171        if (likely(node->perm & access))
1172                vq->meta_iotlb[type] = node;
1173}
1174
1175static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1176                            int access, u64 addr, u64 len, int type)
1177{
1178        const struct vhost_umem_node *node;
1179        struct vhost_umem *umem = vq->iotlb;
1180        u64 s = 0, size, orig_addr = addr;
1181
1182        if (vhost_vq_meta_fetch(vq, addr, len, type))
1183                return true;
1184
1185        while (len > s) {
1186                node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1187                                                           addr,
1188                                                           addr + len - 1);
1189                if (node == NULL || node->start > addr) {
1190                        vhost_iotlb_miss(vq, addr, access);
1191                        return false;
1192                } else if (!(node->perm & access)) {
 1193                        /* Report the possible access violation by
 1194                         * requesting another translation from userspace.
 1195                         */
1196                        return false;
1197                }
1198
1199                size = node->size - addr + node->start;
1200
1201                if (orig_addr == addr && size >= len)
1202                        vhost_vq_meta_update(vq, node, type);
1203
1204                s += size;
1205                addr += size;
1206        }
1207
1208        return true;
1209}
1210
1211int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
1212{
1213        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1214        unsigned int num = vq->num;
1215
1216        if (!vq->iotlb)
1217                return 1;
1218
1219        return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
1220                               num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
1221               iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
1222                               sizeof(*vq->avail) +
1223                               num * sizeof(*vq->avail->ring) + s,
1224                               VHOST_ADDR_AVAIL) &&
1225               iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
1226                               sizeof(*vq->used) +
1227                               num * sizeof(*vq->used->ring) + s,
1228                               VHOST_ADDR_USED);
1229}
1230EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
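
/* Editor's note (not part of the original file): sketch of the calling
 * convention a backend's kick handler might follow before touching the ring
 * (the example_* name is hypothetical). When vq_iotlb_prefetch() fails,
 * iotlb_access_ok() has normally queued a VHOST_IOTLB_MISS message already,
 * and vhost_iotlb_notify_vq() requeues vq->poll once userspace supplies the
 * missing translation, so the handler can simply return and retry later.
 */
static bool example_vq_ready_for_io(struct vhost_virtqueue *vq)
{
        return vq_iotlb_prefetch(vq) != 0;
}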
1231
1232/* Can we log writes? */
1233/* Caller should have device mutex but not vq mutex */
1234bool vhost_log_access_ok(struct vhost_dev *dev)
1235{
1236        return memory_access_ok(dev, dev->umem, 1);
1237}
1238EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1239
1240/* Verify access for write logging. */
1241/* Caller should have vq mutex and device mutex */
1242static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1243                             void __user *log_base)
1244{
1245        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1246
1247        return vq_memory_access_ok(log_base, vq->umem,
1248                                   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1249                (!vq->log_used || log_access_ok(log_base, vq->log_addr,
1250                                        sizeof *vq->used +
1251                                        vq->num * sizeof *vq->used->ring + s));
1252}
1253
1254/* Can we start vq? */
1255/* Caller should have vq mutex and device mutex */
1256bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1257{
1258        if (!vq_log_access_ok(vq, vq->log_base))
1259                return false;
1260
1261        /* Access validation occurs at prefetch time with IOTLB */
1262        if (vq->iotlb)
1263                return true;
1264
1265        return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1266}
1267EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1268
1269static struct vhost_umem *vhost_umem_alloc(void)
1270{
1271        struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL);
1272
1273        if (!umem)
1274                return NULL;
1275
1276        umem->umem_tree = RB_ROOT;
1277        umem->numem = 0;
1278        INIT_LIST_HEAD(&umem->umem_list);
1279
1280        return umem;
1281}
1282
1283static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1284{
1285        struct vhost_memory mem, *newmem;
1286        struct vhost_memory_region *region;
1287        struct vhost_umem *newumem, *oldumem;
1288        unsigned long size = offsetof(struct vhost_memory, regions);
1289        int i;
1290
1291        if (copy_from_user(&mem, m, size))
1292                return -EFAULT;
1293        if (mem.padding)
1294                return -EOPNOTSUPP;
1295        if (mem.nregions > max_mem_regions)
1296                return -E2BIG;
1297        newmem = kvzalloc(size + mem.nregions * sizeof(*m->regions), GFP_KERNEL);
1298        if (!newmem)
1299                return -ENOMEM;
1300
1301        memcpy(newmem, &mem, size);
1302        if (copy_from_user(newmem->regions, m->regions,
1303                           mem.nregions * sizeof *m->regions)) {
1304                kvfree(newmem);
1305                return -EFAULT;
1306        }
1307
1308        newumem = vhost_umem_alloc();
1309        if (!newumem) {
1310                kvfree(newmem);
1311                return -ENOMEM;
1312        }
1313
1314        for (region = newmem->regions;
1315             region < newmem->regions + mem.nregions;
1316             region++) {
1317                if (vhost_new_umem_range(newumem,
1318                                         region->guest_phys_addr,
1319                                         region->memory_size,
1320                                         region->guest_phys_addr +
1321                                         region->memory_size - 1,
1322                                         region->userspace_addr,
1323                                         VHOST_ACCESS_RW))
1324                        goto err;
1325        }
1326
1327        if (!memory_access_ok(d, newumem, 0))
1328                goto err;
1329
1330        oldumem = d->umem;
1331        d->umem = newumem;
1332
1333        /* All memory accesses are done under some VQ mutex. */
1334        for (i = 0; i < d->nvqs; ++i) {
1335                mutex_lock(&d->vqs[i]->mutex);
1336                d->vqs[i]->umem = newumem;
1337                mutex_unlock(&d->vqs[i]->mutex);
1338        }
1339
1340        kvfree(newmem);
1341        vhost_umem_clean(oldumem);
1342        return 0;
1343
1344err:
1345        vhost_umem_clean(newumem);
1346        kvfree(newmem);
1347        return -EFAULT;
1348}
1349
1350long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1351{
1352        struct file *eventfp, *filep = NULL;
1353        bool pollstart = false, pollstop = false;
1354        struct eventfd_ctx *ctx = NULL;
1355        u32 __user *idxp = argp;
1356        struct vhost_virtqueue *vq;
1357        struct vhost_vring_state s;
1358        struct vhost_vring_file f;
1359        struct vhost_vring_addr a;
1360        u32 idx;
1361        long r;
1362
1363        r = get_user(idx, idxp);
1364        if (r < 0)
1365                return r;
1366        if (idx >= d->nvqs)
1367                return -ENOBUFS;
1368
1369        idx = array_index_nospec(idx, d->nvqs);
1370        vq = d->vqs[idx];
1371
1372        mutex_lock(&vq->mutex);
1373
1374        switch (ioctl) {
1375        case VHOST_SET_VRING_NUM:
1376                /* Resizing ring with an active backend?
1377                 * You don't want to do that. */
1378                if (vq->private_data) {
1379                        r = -EBUSY;
1380                        break;
1381                }
1382                if (copy_from_user(&s, argp, sizeof s)) {
1383                        r = -EFAULT;
1384                        break;
1385                }
1386                if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
1387                        r = -EINVAL;
1388                        break;
1389                }
1390                vq->num = s.num;
1391                break;
1392        case VHOST_SET_VRING_BASE:
1393                /* Moving base with an active backend?
1394                 * You don't want to do that. */
1395                if (vq->private_data) {
1396                        r = -EBUSY;
1397                        break;
1398                }
1399                if (copy_from_user(&s, argp, sizeof s)) {
1400                        r = -EFAULT;
1401                        break;
1402                }
1403                if (s.num > 0xffff) {
1404                        r = -EINVAL;
1405                        break;
1406                }
1407                vq->last_avail_idx = s.num;
1408                /* Forget the cached index value. */
1409                vq->avail_idx = vq->last_avail_idx;
1410                break;
1411        case VHOST_GET_VRING_BASE:
1412                s.index = idx;
1413                s.num = vq->last_avail_idx;
1414                if (copy_to_user(argp, &s, sizeof s))
1415                        r = -EFAULT;
1416                break;
1417        case VHOST_SET_VRING_ADDR:
1418                if (copy_from_user(&a, argp, sizeof a)) {
1419                        r = -EFAULT;
1420                        break;
1421                }
1422                if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
1423                        r = -EOPNOTSUPP;
1424                        break;
1425                }
1426                /* For 32bit, verify that the top 32bits of the user
1427                   data are set to zero. */
1428                if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
1429                    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
1430                    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
1431                        r = -EFAULT;
1432                        break;
1433                }
1434
1435                /* Make sure it's safe to cast pointers to vring types. */
1436                BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
1437                BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
1438                if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
1439                    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
1440                    (a.log_guest_addr & (sizeof(u64) - 1))) {
1441                        r = -EINVAL;
1442                        break;
1443                }
1444
 1445                /* We only verify access here if a backend is configured.
 1446                 * If it is not, we skip the check, as the size might not
 1447                 * have been set up yet; we verify when a backend is configured. */
1448                if (vq->private_data) {
1449                        if (!vq_access_ok(vq, vq->num,
1450                                (void __user *)(unsigned long)a.desc_user_addr,
1451                                (void __user *)(unsigned long)a.avail_user_addr,
1452                                (void __user *)(unsigned long)a.used_user_addr)) {
1453                                r = -EINVAL;
1454                                break;
1455                        }
1456
1457                        /* Also validate log access for used ring if enabled. */
1458                        if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
1459                            !log_access_ok(vq->log_base, a.log_guest_addr,
1460                                           sizeof *vq->used +
1461                                           vq->num * sizeof *vq->used->ring)) {
1462                                r = -EINVAL;
1463                                break;
1464                        }
1465                }
1466
1467                vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
1468                vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
1469                vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
1470                vq->log_addr = a.log_guest_addr;
1471                vq->used = (void __user *)(unsigned long)a.used_user_addr;
1472                break;
1473        case VHOST_SET_VRING_KICK:
1474                if (copy_from_user(&f, argp, sizeof f)) {
1475                        r = -EFAULT;
1476                        break;
1477                }
1478                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
1479                if (IS_ERR(eventfp)) {
1480                        r = PTR_ERR(eventfp);
1481                        break;
1482                }
1483                if (eventfp != vq->kick) {
1484                        pollstop = (filep = vq->kick) != NULL;
1485                        pollstart = (vq->kick = eventfp) != NULL;
1486                } else
1487                        filep = eventfp;
1488                break;
1489        case VHOST_SET_VRING_CALL:
1490                if (copy_from_user(&f, argp, sizeof f)) {
1491                        r = -EFAULT;
1492                        break;
1493                }
1494                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
1495                if (IS_ERR(eventfp)) {
1496                        r = PTR_ERR(eventfp);
1497                        break;
1498                }
1499                if (eventfp != vq->call) {
1500                        filep = vq->call;
1501                        ctx = vq->call_ctx;
1502                        vq->call = eventfp;
1503                        vq->call_ctx = eventfp ?
1504                                eventfd_ctx_fileget(eventfp) : NULL;
1505                } else
1506                        filep = eventfp;
1507                break;
1508        case VHOST_SET_VRING_ERR:
1509                if (copy_from_user(&f, argp, sizeof f)) {
1510                        r = -EFAULT;
1511                        break;
1512                }
1513                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
1514                if (IS_ERR(eventfp)) {
1515                        r = PTR_ERR(eventfp);
1516                        break;
1517                }
1518                if (eventfp != vq->error) {
1519                        filep = vq->error;
1520                        vq->error = eventfp;
1521                        ctx = vq->error_ctx;
1522                        vq->error_ctx = eventfp ?
1523                                eventfd_ctx_fileget(eventfp) : NULL;
1524                } else
1525                        filep = eventfp;
1526                break;
1527        case VHOST_SET_VRING_ENDIAN:
1528                r = vhost_set_vring_endian(vq, argp);
1529                break;
1530        case VHOST_GET_VRING_ENDIAN:
1531                r = vhost_get_vring_endian(vq, idx, argp);
1532                break;
1533        case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
1534                if (copy_from_user(&s, argp, sizeof(s))) {
1535                        r = -EFAULT;
1536                        break;
1537                }
1538                vq->busyloop_timeout = s.num;
1539                break;
1540        case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
1541                s.index = idx;
1542                s.num = vq->busyloop_timeout;
1543                if (copy_to_user(argp, &s, sizeof(s)))
1544                        r = -EFAULT;
1545                break;
1546        default:
1547                r = -ENOIOCTLCMD;
1548        }
1549
1550        if (pollstop && vq->handle_kick)
1551                vhost_poll_stop(&vq->poll);
1552
1553        if (ctx)
1554                eventfd_ctx_put(ctx);
1555        if (filep)
1556                fput(filep);
1557
1558        if (pollstart && vq->handle_kick)
1559                r = vhost_poll_start(&vq->poll, vq->kick);
1560
1561        mutex_unlock(&vq->mutex);
1562
1563        if (pollstop && vq->handle_kick)
1564                vhost_poll_flush(&vq->poll);
1565        return r;
1566}
1567EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
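
/* Editor's note (not part of the original file): a hedged sketch of how a
 * backend's ioctl handler typically dispatches to the generic helpers in this
 * file, falling back from the device-level handler (defined below) to the
 * per-vring one above. The example_* name is illustrative.
 */
static long example_backend_ioctl(struct vhost_dev *dev, unsigned int ioctl,
                                  void __user *argp)
{
        long r;

        mutex_lock(&dev->mutex);
        r = vhost_dev_ioctl(dev, ioctl, argp);
        if (r == -ENOIOCTLCMD)
                r = vhost_vring_ioctl(dev, ioctl, argp);
        mutex_unlock(&dev->mutex);
        return r;
}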
1568
1569int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
1570{
1571        struct vhost_umem *niotlb, *oiotlb;
1572        int i;
1573
1574        niotlb = vhost_umem_alloc();
1575        if (!niotlb)
1576                return -ENOMEM;
1577
1578        oiotlb = d->iotlb;
1579        d->iotlb = niotlb;
1580
1581        for (i = 0; i < d->nvqs; ++i) {
1582                mutex_lock(&d->vqs[i]->mutex);
1583                d->vqs[i]->iotlb = niotlb;
1584                mutex_unlock(&d->vqs[i]->mutex);
1585        }
1586
1587        vhost_umem_clean(oiotlb);
1588
1589        return 0;
1590}
1591EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
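
/*
 * Illustrative sketch, not built with this file: a device backend
 * switching to the device IOTLB once the guest negotiates
 * VIRTIO_F_IOMMU_PLATFORM, in the style of vhost-net's SET_FEATURES
 * handling.  The function name and calling context are assumptions
 * made for the example.
 */
#if 0
static int example_enable_iommu(struct vhost_dev *dev, u64 features)
{
        if (!(features & (1ULL << VIRTIO_F_IOMMU_PLATFORM)))
                return 0;

        /* From now on descriptor addresses are IOVAs and must be
         * translated through dev->iotlb by translate_desc(). */
        return vhost_init_device_iotlb(dev, true);
}
#endif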
1592
1593/* Caller must have device mutex */
1594long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1595{
1596        struct file *eventfp, *filep = NULL;
1597        struct eventfd_ctx *ctx = NULL;
1598        u64 p;
1599        long r;
1600        int i, fd;
1601
1602        /* If you are not the owner, you can become one */
1603        if (ioctl == VHOST_SET_OWNER) {
1604                r = vhost_dev_set_owner(d);
1605                goto done;
1606        }
1607
1608        /* You must be the owner to do anything else */
1609        r = vhost_dev_check_owner(d);
1610        if (r)
1611                goto done;
1612
1613        switch (ioctl) {
1614        case VHOST_SET_MEM_TABLE:
1615                r = vhost_set_memory(d, argp);
1616                break;
1617        case VHOST_SET_LOG_BASE:
1618                if (copy_from_user(&p, argp, sizeof p)) {
1619                        r = -EFAULT;
1620                        break;
1621                }
1622                if ((u64)(unsigned long)p != p) {
1623                        r = -EFAULT;
1624                        break;
1625                }
1626                for (i = 0; i < d->nvqs; ++i) {
1627                        struct vhost_virtqueue *vq;
1628                        void __user *base = (void __user *)(unsigned long)p;
1629                        vq = d->vqs[i];
1630                        mutex_lock(&vq->mutex);
1631                        /* If the ring is inactive, the log base is checked when it's enabled. */
1632                        if (vq->private_data && !vq_log_access_ok(vq, base))
1633                                r = -EFAULT;
1634                        else
1635                                vq->log_base = base;
1636                        mutex_unlock(&vq->mutex);
1637                }
1638                break;
1639        case VHOST_SET_LOG_FD:
1640                r = get_user(fd, (int __user *)argp);
1641                if (r < 0)
1642                        break;
1643                eventfp = fd == -1 ? NULL : eventfd_fget(fd);
1644                if (IS_ERR(eventfp)) {
1645                        r = PTR_ERR(eventfp);
1646                        break;
1647                }
1648                if (eventfp != d->log_file) {
1649                        filep = d->log_file;
1650                        ctx = d->log_ctx;
1651                        d->log_ctx = eventfp ?
1652                                eventfd_ctx_fileget(eventfp) : NULL;
1653                } else
1654                        filep = eventfp;
1655                for (i = 0; i < d->nvqs; ++i) {
1656                        mutex_lock(&d->vqs[i]->mutex);
1657                        d->vqs[i]->log_ctx = d->log_ctx;
1658                        mutex_unlock(&d->vqs[i]->mutex);
1659                }
1660                if (ctx)
1661                        eventfd_ctx_put(ctx);
1662                if (filep)
1663                        fput(filep);
1664                break;
1665        default:
1666                r = -ENOIOCTLCMD;
1667                break;
1668        }
1669done:
1670        return r;
1671}
1672EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
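
/*
 * Illustrative sketch, not built with this file: the usual dispatch
 * pattern for a vhost device chardev.  Unknown commands are passed to
 * vhost_dev_ioctl() and, on -ENOIOCTLCMD, on to vhost_vring_ioctl(),
 * all under the device mutex as required by the comment above.
 * "struct example_dev" is an assumption made for the example.
 */
#if 0
struct example_dev {
        struct vhost_dev dev;
};

static long example_ioctl(struct example_dev *n, unsigned int ioctl,
                          void __user *argp)
{
        long r;

        mutex_lock(&n->dev.mutex);
        r = vhost_dev_ioctl(&n->dev, ioctl, argp);
        if (r == -ENOIOCTLCMD)
                r = vhost_vring_ioctl(&n->dev, ioctl, argp);
        mutex_unlock(&n->dev.mutex);
        return r;
}
#endif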
1673
1674/* TODO: This is really inefficient.  We need something like get_user()
1675 * (instruction directly accesses the data, with an exception table entry
1676 * returning -EFAULT). See Documentation/x86/exception-tables.txt.
1677 */
1678static int set_bit_to_user(int nr, void __user *addr)
1679{
1680        unsigned long log = (unsigned long)addr;
1681        struct page *page;
1682        void *base;
1683        int bit = nr + (log % PAGE_SIZE) * 8;
1684        int r;
1685
1686        r = get_user_pages_fast(log, 1, 1, &page);
1687        if (r < 0)
1688                return r;
1689        BUG_ON(r != 1);
1690        base = kmap_atomic(page);
1691        set_bit(bit, base);
1692        kunmap_atomic(base);
1693        set_page_dirty_lock(page);
1694        put_page(page);
1695        return 0;
1696}
1697
1698static int log_write(void __user *log_base,
1699                     u64 write_address, u64 write_length)
1700{
1701        u64 write_page = write_address / VHOST_PAGE_SIZE;
1702        int r;
1703
1704        if (!write_length)
1705                return 0;
1706        write_length += write_address % VHOST_PAGE_SIZE;
1707        for (;;) {
1708                u64 base = (u64)(unsigned long)log_base;
1709                u64 log = base + write_page / 8;
1710                int bit = write_page % 8;
1711                if ((u64)(unsigned long)log != log)
1712                        return -EFAULT;
1713                r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
1714                if (r < 0)
1715                        return r;
1716                if (write_length <= VHOST_PAGE_SIZE)
1717                        break;
1718                write_length -= VHOST_PAGE_SIZE;
1719                write_page += 1;
1720        }
1721        return r;
1722}
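
/*
 * Illustrative sketch, not built with this file: how the VMM that
 * registered the buffer with VHOST_SET_LOG_BASE might consume it.
 * log_write() above sets bit (gpa / VHOST_PAGE_SIZE), i.e. byte
 * page / 8, bit page % 8.  The 4K page size, the log size and the
 * mark_dirty callback are assumptions made for the example; a real
 * VMM clears bytes atomically (e.g. xchg) since the kernel may be
 * setting bits concurrently.
 */
#if 0
static void example_scan_dirty_log(unsigned char *log, size_t log_size,
                                   void (*mark_dirty)(unsigned long gpa))
{
        size_t byte;
        int bit;

        for (byte = 0; byte < log_size; byte++) {
                if (!log[byte])
                        continue;
                for (bit = 0; bit < 8; bit++)
                        if (log[byte] & (1 << bit))
                                mark_dirty((byte * 8 + bit) * 0x1000UL);
                log[byte] = 0;
        }
}
#endif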
1723
1724static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
1725{
1726        struct vhost_umem *umem = vq->umem;
1727        struct vhost_umem_node *u;
1728        u64 start, end, l, min;
1729        int r;
1730        bool hit = false;
1731
1732        while (len) {
1733                min = len;
1734                /* More than one GPA can be mapped into a single HVA, so
1735                 * iterate over all possible umems here to be safe.
1736                 */
1737                list_for_each_entry(u, &umem->umem_list, link) {
1738                        if (u->userspace_addr > hva - 1 + len ||
1739                            u->userspace_addr - 1 + u->size < hva)
1740                                continue;
1741                        start = max(u->userspace_addr, hva);
1742                        end = min(u->userspace_addr - 1 + u->size,
1743                                  hva - 1 + len);
1744                        l = end - start + 1;
1745                        r = log_write(vq->log_base,
1746                                      u->start + start - u->userspace_addr,
1747                                      l);
1748                        if (r < 0)
1749                                return r;
1750                        hit = true;
1751                        min = min(l, min);
1752                }
1753
1754                if (!hit)
1755                        return -EFAULT;
1756
1757                len -= min;
1758                hva += min;
1759        }
1760
1761        return 0;
1762}
1763
1764static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
1765{
1766        struct iovec iov[64];
1767        int i, ret;
1768
1769        if (!vq->iotlb)
1770                return log_write(vq->log_base, vq->log_addr + used_offset, len);
1771
1772        ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
1773                             len, iov, 64, VHOST_ACCESS_WO);
1774        if (ret < 0)
1775                return ret;
1776
1777        for (i = 0; i < ret; i++) {
1778                ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1779                                    iov[i].iov_len);
1780                if (ret)
1781                        return ret;
1782        }
1783
1784        return 0;
1785}
1786
1787int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
1788                    unsigned int log_num, u64 len, struct iovec *iov, int count)
1789{
1790        int i, r;
1791
1792        /* Make sure data written is seen before log. */
1793        smp_wmb();
1794
1795        if (vq->iotlb) {
1796                for (i = 0; i < count; i++) {
1797                        r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1798                                          iov[i].iov_len);
1799                        if (r < 0)
1800                                return r;
1801                }
1802                return 0;
1803        }
1804
1805        for (i = 0; i < log_num; ++i) {
1806                u64 l = min(log[i].len, len);
1807                r = log_write(vq->log_base, log[i].addr, l);
1808                if (r < 0)
1809                        return r;
1810                len -= l;
1811                if (!len) {
1812                        if (vq->log_ctx)
1813                                eventfd_signal(vq->log_ctx, 1);
1814                        return 0;
1815                }
1816        }
1817        /* Length written exceeds what we have stored. This is a bug. */
1818        BUG();
1819        return 0;
1820}
1821EXPORT_SYMBOL_GPL(vhost_log_write);
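
/*
 * Illustrative sketch, not built with this file: a backend reporting a
 * guest-visible write for dirty tracking, in the style of vhost-net's
 * receive path.  "log" and "log_num" are the arrays filled in by
 * vhost_get_vq_desc(); iov/count are only consulted when an IOTLB is
 * in use.  The wrapper name and parameters are assumptions made for
 * the example.
 */
#if 0
static int example_log_guest_write(struct vhost_virtqueue *vq,
                                   struct vhost_log *log,
                                   unsigned int log_num, u64 len,
                                   struct iovec *iov, int count)
{
        if (likely(!vhost_has_feature(vq, VHOST_F_LOG_ALL)))
                return 0;
        return vhost_log_write(vq, log, log_num, len, iov, count);
}
#endif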
1822
1823static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1824{
1825        void __user *used;
1826        if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1827                           &vq->used->flags) < 0)
1828                return -EFAULT;
1829        if (unlikely(vq->log_used)) {
1830                /* Make sure the flag is seen before log. */
1831                smp_wmb();
1832                /* Log used flag write. */
1833                used = &vq->used->flags;
1834                log_used(vq, (used - (void __user *)vq->used),
1835                         sizeof vq->used->flags);
1836                if (vq->log_ctx)
1837                        eventfd_signal(vq->log_ctx, 1);
1838        }
1839        return 0;
1840}
1841
1842static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
1843{
1844        if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1845                           vhost_avail_event(vq)))
1846                return -EFAULT;
1847        if (unlikely(vq->log_used)) {
1848                void __user *used;
1849                /* Make sure the event is seen before log. */
1850                smp_wmb();
1851                /* Log avail event write */
1852                used = vhost_avail_event(vq);
1853                log_used(vq, (used - (void __user *)vq->used),
1854                         sizeof *vhost_avail_event(vq));
1855                if (vq->log_ctx)
1856                        eventfd_signal(vq->log_ctx, 1);
1857        }
1858        return 0;
1859}
1860
1861int vhost_init_used(struct vhost_virtqueue *vq)
1862{
1863        __virtio16 last_used_idx;
1864        int r;
1865        if (!vq->private_data) {
1866                vq->is_le = virtio_legacy_is_little_endian();
1867                return 0;
1868        }
1869
1870        vhost_init_is_le(vq);
1871
1872        r = vhost_update_used_flags(vq);
1873        if (r)
1874                return r;
1875        vq->signalled_used_valid = false;
1876        if (!vq->iotlb &&
1877            !access_ok(VERIFY_READ, &vq->used->idx, sizeof(vq->used->idx)))
1878                return -EFAULT;
1879        r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
1880        if (r) {
1881                vq_err(vq, "Can't access used idx at %p\n",
1882                       &vq->used->idx);
1883                return r;
1884        }
1885        vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
1886        return 0;
1887}
1888EXPORT_SYMBOL_GPL(vhost_init_used);
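
/*
 * Illustrative sketch, not built with this file: the usual backend
 * attach sequence.  The queue is only processed once private_data is
 * set, and vhost_init_used() must run under vq->mutex so the
 * endianness and cached last_used_idx match the ring userspace set up.
 * A real device also validates ring access first; "backend" stands for
 * whatever the device stores in private_data (a socket for net).
 */
#if 0
static int example_attach_backend(struct vhost_virtqueue *vq, void *backend)
{
        int r;

        mutex_lock(&vq->mutex);
        vq->private_data = backend;
        r = vhost_init_used(vq);
        if (r)
                vq->private_data = NULL;
        mutex_unlock(&vq->mutex);
        return r;
}
#endif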
1889
1890static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1891                          struct iovec iov[], int iov_size, int access)
1892{
1893        const struct vhost_umem_node *node;
1894        struct vhost_dev *dev = vq->dev;
1895        struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
1896        struct iovec *_iov;
1897        u64 s = 0;
1898        int ret = 0;
1899
1900        while ((u64)len > s) {
1901                u64 size;
1902                if (unlikely(ret >= iov_size)) {
1903                        ret = -ENOBUFS;
1904                        break;
1905                }
1906
1907                node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1908                                                        addr, addr + len - 1);
1909                if (node == NULL || node->start > addr) {
1910                        if (umem != dev->iotlb) {
1911                                ret = -EFAULT;
1912                                break;
1913                        }
1914                        ret = -EAGAIN;
1915                        break;
1916                } else if (!(node->perm & access)) {
1917                        ret = -EPERM;
1918                        break;
1919                }
1920
1921                _iov = iov + ret;
1922                size = node->size - addr + node->start;
1923                _iov->iov_len = min((u64)len - s, size);
1924                _iov->iov_base = (void __user *)(unsigned long)
1925                        (node->userspace_addr + addr - node->start);
1926                s += size;
1927                addr += size;
1928                ++ret;
1929        }
1930
1931        if (ret == -EAGAIN)
1932                vhost_iotlb_miss(vq, addr, access);
1933        return ret;
1934}
1935
1936/* Each buffer in the virtqueues is actually a chain of descriptors.  This
1937 * function returns the next descriptor in the chain,
1938 * or -1U if we're at the end. */
1939static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
1940{
1941        unsigned int next;
1942
1943        /* If this descriptor says it doesn't chain, we're done. */
1944        if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
1945                return -1U;
1946
1947        /* Check they're not leading us off the end of the descriptors. */
1948        next = vhost16_to_cpu(vq, desc->next);
1949        /* Make sure compiler knows to grab that: we don't want it changing! */
1950        /* We will use the result as an index in an array, so most
1951         * architectures only need a compiler barrier here. */
1952        read_barrier_depends();
1953
1954        return next;
1955}
1956
1957static int get_indirect(struct vhost_virtqueue *vq,
1958                        struct iovec iov[], unsigned int iov_size,
1959                        unsigned int *out_num, unsigned int *in_num,
1960                        struct vhost_log *log, unsigned int *log_num,
1961                        struct vring_desc *indirect)
1962{
1963        struct vring_desc desc;
1964        unsigned int i = 0, count, found = 0;
1965        u32 len = vhost32_to_cpu(vq, indirect->len);
1966        struct iov_iter from;
1967        int ret, access;
1968
1969        /* Sanity check */
1970        if (unlikely(len % sizeof desc)) {
1971                vq_err(vq, "Invalid length in indirect descriptor: "
1972                       "len 0x%llx not multiple of 0x%zx\n",
1973                       (unsigned long long)len,
1974                       sizeof desc);
1975                return -EINVAL;
1976        }
1977
1978        ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
1979                             UIO_MAXIOV, VHOST_ACCESS_RO);
1980        if (unlikely(ret < 0)) {
1981                if (ret != -EAGAIN)
1982                        vq_err(vq, "Translation failure %d in indirect.\n", ret);
1983                return ret;
1984        }
1985        iov_iter_init(&from, vq->indirect, ret, len, 0);
1986
1987        /* We will use the result as an address to read from, so most
1988         * architectures only need a compiler barrier here. */
1989        read_barrier_depends();
1990
1991        count = len / sizeof desc;
1992        /* Buffers are chained via a 16 bit next field, so
1993         * we can have at most 2^16 of these. */
1994        if (unlikely(count > USHRT_MAX + 1)) {
1995                vq_err(vq, "Indirect buffer length too big: %d\n",
1996                       indirect->len);
1997                return -E2BIG;
1998        }
1999
2000        do {
2001                unsigned iov_count = *in_num + *out_num;
2002                if (unlikely(++found > count)) {
2003                        vq_err(vq, "Loop detected: last one at %u "
2004                               "indirect size %u\n",
2005                               i, count);
2006                        return -EINVAL;
2007                }
2008                if (unlikely(memcpy_fromiovec((unsigned char *)&desc,
2009                                              vq->indirect, sizeof desc))) {
2010                        vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
2011                               i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2012                        return -EINVAL;
2013                }
2014                if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
2015                        vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
2016                               i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2017                        return -EINVAL;
2018                }
2019
2020                if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2021                        access = VHOST_ACCESS_WO;
2022                else
2023                        access = VHOST_ACCESS_RO;
2024
2025                ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2026                                     vhost32_to_cpu(vq, desc.len), iov + iov_count,
2027                                     iov_size - iov_count, access);
2028                if (unlikely(ret < 0)) {
2029                        if (ret != -EAGAIN)
2030                                vq_err(vq, "Translation failure %d indirect idx %d\n",
2031                                        ret, i);
2032                        return ret;
2033                }
2034                /* If this is an input descriptor, increment that count. */
2035                if (access == VHOST_ACCESS_WO) {
2036                        *in_num += ret;
2037                        if (unlikely(log)) {
2038                                log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2039                                log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2040                                ++*log_num;
2041                        }
2042                } else {
2043                        /* If it's an output descriptor, they're all supposed
2044                         * to come before any input descriptors. */
2045                        if (unlikely(*in_num)) {
2046                                vq_err(vq, "Indirect descriptor "
2047                                       "has out after in: idx %d\n", i);
2048                                return -EINVAL;
2049                        }
2050                        *out_num += ret;
2051                }
2052        } while ((i = next_desc(vq, &desc)) != -1);
2053        return 0;
2054}
2055
2056/* This looks in the virtqueue for the first available buffer, and converts
2057 * it to an iovec for convenient access.  Since descriptors consist of some
2058 * number of output then some number of input descriptors, it's actually two
2059 * iovecs, but we pack them into one and note how many of each there were.
2060 *
2061 * This function returns the descriptor number found, or vq->num (which is
2062 * never a valid descriptor number) if none was found.  A negative code is
2063 * returned on error. */
2064int vhost_get_vq_desc(struct vhost_virtqueue *vq,
2065                      struct iovec iov[], unsigned int iov_size,
2066                      unsigned int *out_num, unsigned int *in_num,
2067                      struct vhost_log *log, unsigned int *log_num)
2068{
2069        struct vring_desc desc = {0};
2070        unsigned int i, head, found = 0;
2071        u16 last_avail_idx;
2072        __virtio16 avail_idx;
2073        __virtio16 ring_head;
2074        int ret, access;
2075
2076        /* Check it isn't doing very strange things with descriptor numbers. */
2077        last_avail_idx = vq->last_avail_idx;
2078        if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
2079                vq_err(vq, "Failed to access avail idx at %p\n",
2080                       &vq->avail->idx);
2081                return -EFAULT;
2082        }
2083        vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2084
2085        if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
2086                vq_err(vq, "Guest moved avail index from %u to %u",
2087                       last_avail_idx, vq->avail_idx);
2088                return -EFAULT;
2089        }
2090
2091        /* If there's nothing new since last we looked, return invalid. */
2092        if (vq->avail_idx == last_avail_idx)
2093                return vq->num;
2094
2095        /* Only get avail ring entries after they have been exposed by guest. */
2096        smp_rmb();
2097
2098        /* Grab the next descriptor number they're advertising, and increment
2099         * the index we've seen. */
2100        if (unlikely(vhost_get_avail(vq, ring_head,
2101                     &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
2102                vq_err(vq, "Failed to read head: idx %d address %p\n",
2103                       last_avail_idx,
2104                       &vq->avail->ring[last_avail_idx % vq->num]);
2105                return -EFAULT;
2106        }
2107
2108        head = vhost16_to_cpu(vq, ring_head);
2109
2110        /* If their number is silly, that's an error. */
2111        if (unlikely(head >= vq->num)) {
2112                vq_err(vq, "Guest says index %u > %u is available",
2113                       head, vq->num);
2114                return -EINVAL;
2115        }
2116
2117        /* When we start there are neither input nor output descriptors. */
2118        *out_num = *in_num = 0;
2119        if (unlikely(log))
2120                *log_num = 0;
2121
2122        i = head;
2123        do {
2124                unsigned iov_count = *in_num + *out_num;
2125                if (unlikely(i >= vq->num)) {
2126                        vq_err(vq, "Desc index is %u > %u, head = %u",
2127                               i, vq->num, head);
2128                        return -EINVAL;
2129                }
2130                if (unlikely(++found > vq->num)) {
2131                        vq_err(vq, "Loop detected: last one at %u "
2132                               "vq size %u head %u\n",
2133                               i, vq->num, head);
2134                        return -EINVAL;
2135                }
2136                ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
2137                                           sizeof desc);
2138                if (unlikely(ret)) {
2139                        vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2140                               i, vq->desc + i);
2141                        return -EFAULT;
2142                }
2143                if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2144                        ret = get_indirect(vq, iov, iov_size,
2145                                           out_num, in_num,
2146                                           log, log_num, &desc);
2147                        if (unlikely(ret < 0)) {
2148                                if (ret != -EAGAIN)
2149                                        vq_err(vq, "Failure detected "
2150                                               "in indirect descriptor at idx %d\n", i);
2151                                return ret;
2152                        }
2153                        continue;
2154                }
2155
2156                if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2157                        access = VHOST_ACCESS_WO;
2158                else
2159                        access = VHOST_ACCESS_RO;
2160                ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2161                                     vhost32_to_cpu(vq, desc.len), iov + iov_count,
2162                                     iov_size - iov_count, access);
2163                if (unlikely(ret < 0)) {
2164                        if (ret != -EAGAIN)
2165                                vq_err(vq, "Translation failure %d descriptor idx %d\n",
2166                                        ret, i);
2167                        return ret;
2168                }
2169                if (access == VHOST_ACCESS_WO) {
2170                        /* If this is an input descriptor,
2171                         * increment that count. */
2172                        *in_num += ret;
2173                        if (unlikely(log)) {
2174                                log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2175                                log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2176                                ++*log_num;
2177                        }
2178                } else {
2179                        /* If it's an output descriptor, they're all supposed
2180                         * to come before any input descriptors. */
2181                        if (unlikely(*in_num)) {
2182                                vq_err(vq, "Descriptor has out after in: "
2183                                       "idx %d\n", i);
2184                                return -EINVAL;
2185                        }
2186                        *out_num += ret;
2187                }
2188        } while ((i = next_desc(vq, &desc)) != -1);
2189
2190        /* On success, increment avail index. */
2191        vq->last_avail_idx++;
2192
2193        /* Assume notifications from the guest are disabled at this point;
2194         * if they aren't, we would need to update the avail_event index. */
2195        BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2196        return head;
2197}
2198EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
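
/*
 * Illustrative sketch, not built with this file: the canonical consumer
 * loop around vhost_get_vq_desc(), as run from a device's handle_kick
 * work (compare drivers/vhost/net.c and test.c).  vq->iov is the per-vq
 * scratch iovec array from vhost.h; example_consume() and the zero used
 * length are assumptions made for the example.
 */
#if 0
static void example_consume(struct iovec *iov, unsigned int out,
                            unsigned int in);

static void example_handle_queue(struct vhost_dev *dev,
                                 struct vhost_virtqueue *vq)
{
        unsigned int out, in;
        int head;

        mutex_lock(&vq->mutex);
        vhost_disable_notify(dev, vq);

        for (;;) {
                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0)
                        break;
                if (head == vq->num) {
                        /* Ring empty: re-enable kicks, then re-check to
                         * close the race with a fresh avail entry. */
                        if (unlikely(vhost_enable_notify(dev, vq))) {
                                vhost_disable_notify(dev, vq);
                                continue;
                        }
                        break;
                }
                /* "out" (guest->host) iovecs come first, then "in". */
                example_consume(vq->iov, out, in);
                vhost_add_used_and_signal(dev, vq, head, 0);
        }

        mutex_unlock(&vq->mutex);
}
#endif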
2199
2200/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2201void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2202{
2203        vq->last_avail_idx -= n;
2204}
2205EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
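
/*
 * Illustrative sketch, not built with this file: pushing a descriptor
 * back when the backend cannot take it right now.  After
 * vhost_discard_vq_desc() the same head is handed out again by the
 * next vhost_get_vq_desc() call, so the handler can simply stop and
 * wait to be kicked again.  example_backend_busy() is an assumption
 * made for the example.
 */
#if 0
static bool example_backend_busy(struct vhost_virtqueue *vq);

static bool example_try_complete(struct vhost_dev *dev,
                                 struct vhost_virtqueue *vq, int head)
{
        if (example_backend_busy(vq)) {
                /* Rewind last_avail_idx; nothing was added to the
                 * used ring for this head. */
                vhost_discard_vq_desc(vq, 1);
                return false;
        }
        vhost_add_used_and_signal(dev, vq, head, 0);
        return true;
}
#endif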
2206
2207/* After we've used one of their buffers, we tell them about it.  We'll then
2208 * want to notify the guest, using eventfd. */
2209int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2210{
2211        struct vring_used_elem heads = {
2212                cpu_to_vhost32(vq, head),
2213                cpu_to_vhost32(vq, len)
2214        };
2215
2216        return vhost_add_used_n(vq, &heads, 1);
2217}
2218EXPORT_SYMBOL_GPL(vhost_add_used);
2219
2220static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2221                            struct vring_used_elem *heads,
2222                            unsigned count)
2223{
2224        struct vring_used_elem __user *used;
2225        u16 old, new;
2226        int start;
2227
2228        start = vq->last_used_idx & (vq->num - 1);
2229        used = vq->used->ring + start;
2230        if (count == 1) {
2231                if (vhost_put_user(vq, heads[0].id, &used->id)) {
2232                        vq_err(vq, "Failed to write used id");
2233                        return -EFAULT;
2234                }
2235                if (vhost_put_user(vq, heads[0].len, &used->len)) {
2236                        vq_err(vq, "Failed to write used len");
2237                        return -EFAULT;
2238                }
2239        } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
2240                vq_err(vq, "Failed to write used");
2241                return -EFAULT;
2242        }
2243        if (unlikely(vq->log_used)) {
2244                /* Make sure data is seen before log. */
2245                smp_wmb();
2246                /* Log used ring entry write. */
2247                log_used(vq, ((void __user *)used - (void __user *)vq->used),
2248                         count * sizeof *used);
2249        }
2250        old = vq->last_used_idx;
2251        new = (vq->last_used_idx += count);
2252        /* If the driver never bothers to signal in a very long while,
2253         * used index might wrap around. If that happens, invalidate
2254         * signalled_used index we stored. TODO: make sure driver
2255         * signals at least once in 2^16 and remove this. */
2256        if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2257                vq->signalled_used_valid = false;
2258        return 0;
2259}
2260
2261/* After we've used one of their buffers, we tell them about it.  We'll then
2262 * want to notify the guest, using eventfd. */
2263int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2264                     unsigned count)
2265{
2266        int start, n, r;
2267
2268        start = vq->last_used_idx & (vq->num - 1);
2269        n = vq->num - start;
2270        if (n < count) {
2271                r = __vhost_add_used_n(vq, heads, n);
2272                if (r < 0)
2273                        return r;
2274                heads += n;
2275                count -= n;
2276        }
2277        r = __vhost_add_used_n(vq, heads, count);
2278
2279        /* Make sure buffer is written before we update index. */
2280        smp_wmb();
2281        if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
2282                           &vq->used->idx)) {
2283                vq_err(vq, "Failed to increment used idx");
2284                return -EFAULT;
2285        }
2286        if (unlikely(vq->log_used)) {
2287                /* Log used index update. */
2288                log_used(vq, offsetof(struct vring_used, idx),
2289                         sizeof vq->used->idx);
2290                if (vq->log_ctx)
2291                        eventfd_signal(vq->log_ctx, 1);
2292        }
2293        return r;
2294}
2295EXPORT_SYMBOL_GPL(vhost_add_used_n);
2296
2297static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2298{
2299        __u16 old, new;
2300        __virtio16 event;
2301        bool v;
2302        /* Flush out used index updates. This is paired
2303         * with the barrier that the Guest executes when enabling
2304         * interrupts. */
2305        smp_mb();
2306
2307        if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2308            unlikely(vq->avail_idx == vq->last_avail_idx))
2309                return true;
2310
2311        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2312                __virtio16 flags;
2313                if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
2314                        vq_err(vq, "Failed to get flags");
2315                        return true;
2316                }
2317                return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2318        }
2319        old = vq->signalled_used;
2320        v = vq->signalled_used_valid;
2321        new = vq->signalled_used = vq->last_used_idx;
2322        vq->signalled_used_valid = true;
2323
2324        if (unlikely(!v))
2325                return true;
2326
2327        if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
2328                vq_err(vq, "Failed to get used event idx");
2329                return true;
2330        }
2331        return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2332}
2333
2334/* This actually signals the guest, using eventfd. */
2335void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2336{
2337        /* Signal the Guest to tell them we used something up. */
2338        if (vq->call_ctx && vhost_notify(dev, vq))
2339                eventfd_signal(vq->call_ctx, 1);
2340}
2341EXPORT_SYMBOL_GPL(vhost_signal);
2342
2343/* And here's the combo meal deal.  Supersize me! */
2344void vhost_add_used_and_signal(struct vhost_dev *dev,
2345                               struct vhost_virtqueue *vq,
2346                               unsigned int head, int len)
2347{
2348        vhost_add_used(vq, head, len);
2349        vhost_signal(dev, vq);
2350}
2351EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2352
2353/* multi-buffer version of vhost_add_used_and_signal */
2354void vhost_add_used_and_signal_n(struct vhost_dev *dev,
2355                                 struct vhost_virtqueue *vq,
2356                                 struct vring_used_elem *heads, unsigned count)
2357{
2358        vhost_add_used_n(vq, heads, count);
2359        vhost_signal(dev, vq);
2360}
2361EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
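
/*
 * Illustrative sketch, not built with this file: batching completions
 * so the used index is published and the guest interrupted once per
 * batch rather than once per buffer, as the vhost-net receive path
 * does with vq->heads.  The batch size, example_fill_buf() and the
 * caller holding vq->mutex are assumptions made for the example.
 */
#if 0
#define EXAMPLE_BATCH 64

static int example_fill_buf(struct iovec *iov, unsigned int in);

static void example_complete_batch(struct vhost_dev *dev,
                                   struct vhost_virtqueue *vq)
{
        struct vring_used_elem heads[EXAMPLE_BATCH];
        unsigned int out, in, n = 0;
        int head, len;

        while (n < EXAMPLE_BATCH) {
                head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                         &out, &in, NULL, NULL);
                if (head < 0 || head == vq->num)
                        break;
                len = example_fill_buf(vq->iov + out, in);
                heads[n].id = cpu_to_vhost32(vq, head);
                heads[n].len = cpu_to_vhost32(vq, len);
                n++;
        }
        if (n)
                vhost_add_used_and_signal_n(dev, vq, heads, n);
}
#endif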
2362
2363/* return true if we're sure that the available ring is empty */
2364bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2365{
2366        __virtio16 avail_idx;
2367        int r;
2368
2369        if (vq->avail_idx != vq->last_avail_idx)
2370                return false;
2371
2372        r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
2373        if (unlikely(r))
2374                return false;
2375        vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2376
2377        return vq->avail_idx == vq->last_avail_idx;
2378}
2379EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
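
/*
 * Illustrative sketch, not built with this file: busy polling for a new
 * available buffer before going back to sleep, in the style of the
 * vhost-net receive path.  vq->busyloop_timeout is the value set with
 * VHOST_SET_VRING_BUSYLOOP_TIMEOUT; the >> 10 scaling of local_clock()
 * (to roughly microseconds) mirrors the net backend and is an
 * assumption made for the example.
 */
#if 0
static void example_busy_wait_for_avail(struct vhost_dev *dev,
                                        struct vhost_virtqueue *vq)
{
        unsigned long endtime;

        if (!vq->busyloop_timeout)
                return;

        preempt_disable();
        endtime = (unsigned long)(local_clock() >> 10) + vq->busyloop_timeout;
        while (vhost_vq_avail_empty(dev, vq) &&
               !need_resched() &&
               !time_after((unsigned long)(local_clock() >> 10), endtime))
                cpu_relax();
        preempt_enable();
}
#endif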
2380
2381/* OK, now we need to know about added descriptors. */
2382bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2383{
2384        __virtio16 avail_idx;
2385        int r;
2386
2387        if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2388                return false;
2389        vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2390        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2391                r = vhost_update_used_flags(vq);
2392                if (r) {
2393                        vq_err(vq, "Failed to enable notification at %p: %d\n",
2394                               &vq->used->flags, r);
2395                        return false;
2396                }
2397        } else {
2398                r = vhost_update_avail_event(vq, vq->avail_idx);
2399                if (r) {
2400                        vq_err(vq, "Failed to update avail event index at %p: %d\n",
2401                               vhost_avail_event(vq), r);
2402                        return false;
2403                }
2404        }
2405        /* They could have slipped one in as we were doing that: make
2406         * sure it's written, then check again. */
2407        smp_mb();
2408        r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
2409        if (r) {
2410                vq_err(vq, "Failed to check avail idx at %p: %d\n",
2411                       &vq->avail->idx, r);
2412                return false;
2413        }
2414
2415        return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
2416}
2417EXPORT_SYMBOL_GPL(vhost_enable_notify);
2418
2419/* We don't need to be notified again. */
2420void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2421{
2422        int r;
2423
2424        if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2425                return;
2426        vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2427        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2428                r = vhost_update_used_flags(vq);
2429                if (r)
2430                        vq_err(vq, "Failed to disable notification at %p: %d\n",
2431                               &vq->used->flags, r);
2432        }
2433}
2434EXPORT_SYMBOL_GPL(vhost_disable_notify);
2435
2436/* Create a new message. */
2437struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2438{
2439        struct vhost_msg_node *node = kmalloc(sizeof(*node), GFP_KERNEL);
2440        if (!node)
2441                return NULL;
2442
2443        /* Make sure all padding within the structure is initialized. */
2444        memset(&node->msg, 0, sizeof node->msg);
2445        node->vq = vq;
2446        node->msg.type = type;
2447        return node;
2448}
2449EXPORT_SYMBOL_GPL(vhost_new_msg);
2450
2451void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2452                       struct vhost_msg_node *node)
2453{
2454        spin_lock(&dev->iotlb_lock);
2455        list_add_tail(&node->node, head);
2456        spin_unlock(&dev->iotlb_lock);
2457
2458        wake_up_interruptible_poll(&dev->wait, POLLIN | POLLRDNORM);
2459}
2460EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
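
/*
 * Illustrative sketch, not built with this file: reporting an IOTLB
 * miss to userspace with the helpers above, in the style of
 * vhost_iotlb_miss() earlier in this file.  The dev->read_list member
 * name is an assumption made for the example; the VMM reads the
 * message from the device chardev and answers with VHOST_IOTLB_UPDATE.
 */
#if 0
static int example_report_iotlb_miss(struct vhost_virtqueue *vq,
                                     u64 iova, int access)
{
        struct vhost_dev *dev = vq->dev;
        struct vhost_msg_node *node;

        node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
        if (!node)
                return -ENOMEM;

        node->msg.iotlb.type = VHOST_IOTLB_MISS;
        node->msg.iotlb.iova = iova;
        node->msg.iotlb.perm = access;

        /* Wakes up a reader blocked in the device's read/poll path. */
        vhost_enqueue_msg(dev, &dev->read_list, node);
        return 0;
}
#endif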
2461
2462struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2463                                         struct list_head *head)
2464{
2465        struct vhost_msg_node *node = NULL;
2466
2467        spin_lock(&dev->iotlb_lock);
2468        if (!list_empty(head)) {
2469                node = list_first_entry(head, struct vhost_msg_node,
2470                                        node);
2471                list_del(&node->node);
2472        }
2473        spin_unlock(&dev->iotlb_lock);
2474
2475        return node;
2476}
2477EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2478
2479
2480static int __init vhost_init(void)
2481{
2482        return 0;
2483}
2484
2485static void __exit vhost_exit(void)
2486{
2487}
2488
2489module_init(vhost_init);
2490module_exit(vhost_exit);
2491
2492MODULE_VERSION("0.0.1");
2493MODULE_LICENSE("GPL v2");
2494MODULE_AUTHOR("Michael S. Tsirkin");
2495MODULE_DESCRIPTION("Host kernel accelerator for virtio");
2496