linux/drivers/vhost/vhost.c
   1/* Copyright (C) 2009 Red Hat, Inc.
   2 * Copyright (C) 2006 Rusty Russell IBM Corporation
   3 *
   4 * Author: Michael S. Tsirkin <mst@redhat.com>
   5 *
   6 * Inspiration, some code, and most witty comments come from
   7 * Documentation/virtual/lguest/lguest.c, by Rusty Russell
   8 *
   9 * This work is licensed under the terms of the GNU GPL, version 2.
  10 *
  11 * Generic code for virtio server in host kernel.
  12 */
  13
  14#include <linux/eventfd.h>
  15#include <linux/vhost.h>
  16#include <linux/uio.h>
  17#include <linux/mm.h>
  18#include <linux/mmu_context.h>
  19#include <linux/miscdevice.h>
  20#include <linux/mutex.h>
  21#include <linux/poll.h>
  22#include <linux/file.h>
  23#include <linux/highmem.h>
  24#include <linux/slab.h>
  25#include <linux/vmalloc.h>
  26#include <linux/kthread.h>
  27#include <linux/cgroup.h>
  28#include <linux/module.h>
  29#include <linux/sort.h>
  30#include <linux/sched/mm.h>
  31#include <linux/sched/signal.h>
  32#include <linux/interval_tree_generic.h>
  33
  34#include "vhost.h"
  35
  36static ushort max_mem_regions = 64;
  37module_param(max_mem_regions, ushort, 0444);
  38MODULE_PARM_DESC(max_mem_regions,
  39        "Maximum number of memory regions in memory map. (default: 64)");
  40static int max_iotlb_entries = 2048;
  41module_param(max_iotlb_entries, int, 0444);
  42MODULE_PARM_DESC(max_iotlb_entries,
  43        "Maximum number of iotlb entries. (default: 2048)");
  44
  45enum {
  46        VHOST_MEMORY_F_LOG = 0x1,
  47};
  48
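/* With VIRTIO_RING_F_EVENT_IDX negotiated, used_event lives in the last slot
 * of the available ring and avail_event in the last slot of the used ring,
 * which is why each macro below indexes the *other* ring at vq->num.
 */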
  49#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
  50#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
  51
  52INTERVAL_TREE_DEFINE(struct vhost_umem_node,
  53                     rb, __u64, __subtree_last,
  54                     START, LAST, static inline, vhost_umem_interval_tree);
  55
  56#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
  57static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
  58{
  59        vq->user_be = !virtio_legacy_is_little_endian();
  60}
  61
  62static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
  63{
  64        vq->user_be = true;
  65}
  66
  67static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
  68{
  69        vq->user_be = false;
  70}
  71
  72static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
  73{
  74        struct vhost_vring_state s;
  75
  76        if (vq->private_data)
  77                return -EBUSY;
  78
  79        if (copy_from_user(&s, argp, sizeof(s)))
  80                return -EFAULT;
  81
  82        if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
  83            s.num != VHOST_VRING_BIG_ENDIAN)
  84                return -EINVAL;
  85
  86        if (s.num == VHOST_VRING_BIG_ENDIAN)
  87                vhost_enable_cross_endian_big(vq);
  88        else
  89                vhost_enable_cross_endian_little(vq);
  90
  91        return 0;
  92}
  93
  94static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
  95                                   int __user *argp)
  96{
  97        struct vhost_vring_state s = {
  98                .index = idx,
  99                .num = vq->user_be
 100        };
 101
 102        if (copy_to_user(argp, &s, sizeof(s)))
 103                return -EFAULT;
 104
 105        return 0;
 106}
 107
 108static void vhost_init_is_le(struct vhost_virtqueue *vq)
 109{
 110        /* Note for legacy virtio: user_be is initialized at reset time
 111         * according to the host endianness. If userspace does not set an
 112         * explicit endianness, the default behavior is native endian, as
 113         * expected by legacy virtio.
 114         */
 115        vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
 116}
 117#else
 118static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
 119{
 120}
 121
 122static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
 123{
 124        return -ENOIOCTLCMD;
 125}
 126
 127static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
 128                                   int __user *argp)
 129{
 130        return -ENOIOCTLCMD;
 131}
 132
 133static void vhost_init_is_le(struct vhost_virtqueue *vq)
 134{
 135        vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
 136                || virtio_legacy_is_little_endian();
 137}
 138#endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
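
/* Userspace sketch (only meaningful with CONFIG_VHOST_CROSS_ENDIAN_LEGACY and
 * a legacy cross-endian guest; vhost_fd is hypothetical): the ring endianness
 * must be forced before a backend is attached, through vhost_vring_ioctl()
 * below:
 *
 *	struct vhost_vring_state s = {
 *		.index = 0,
 *		.num   = VHOST_VRING_BIG_ENDIAN,
 *	};
 *
 *	ioctl(vhost_fd, VHOST_SET_VRING_ENDIAN, &s);
 */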
 139
 140static void vhost_reset_is_le(struct vhost_virtqueue *vq)
 141{
 142        vhost_init_is_le(vq);
 143}
 144
 145struct vhost_flush_struct {
 146        struct vhost_work work;
 147        struct completion wait_event;
 148};
 149
 150static void vhost_flush_work(struct vhost_work *work)
 151{
 152        struct vhost_flush_struct *s;
 153
 154        s = container_of(work, struct vhost_flush_struct, work);
 155        complete(&s->wait_event);
 156}
 157
 158static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
 159                            poll_table *pt)
 160{
 161        struct vhost_poll *poll;
 162
 163        poll = container_of(pt, struct vhost_poll, table);
 164        poll->wqh = wqh;
 165        add_wait_queue(wqh, &poll->wait);
 166}
 167
 168static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
 169                             void *key)
 170{
 171        struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
 172
 173        if (!(key_to_poll(key) & poll->mask))
 174                return 0;
 175
 176        vhost_poll_queue(poll);
 177        return 0;
 178}
 179
 180void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
 181{
 182        clear_bit(VHOST_WORK_QUEUED, &work->flags);
 183        work->fn = fn;
 184}
 185EXPORT_SYMBOL_GPL(vhost_work_init);
 186
 187/* Init poll structure */
 188void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
 189                     __poll_t mask, struct vhost_dev *dev)
 190{
 191        init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
 192        init_poll_funcptr(&poll->table, vhost_poll_func);
 193        poll->mask = mask;
 194        poll->dev = dev;
 195        poll->wqh = NULL;
 196
 197        vhost_work_init(&poll->work, fn);
 198}
 199EXPORT_SYMBOL_GPL(vhost_poll_init);
 200
  201/* Start polling a file. We add ourselves to the file's wait queue. The caller
  202 * must keep a reference to the file until after vhost_poll_stop is called. */
 203int vhost_poll_start(struct vhost_poll *poll, struct file *file)
 204{
 205        __poll_t mask;
 206        int ret = 0;
 207
 208        if (poll->wqh)
 209                return 0;
 210
 211        mask = file->f_op->poll(file, &poll->table);
 212        if (mask)
 213                vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
 214        if (mask & EPOLLERR) {
 215                vhost_poll_stop(poll);
 216                ret = -EINVAL;
 217        }
 218
 219        return ret;
 220}
 221EXPORT_SYMBOL_GPL(vhost_poll_start);
 222
 223/* Stop polling a file. After this function returns, it becomes safe to drop the
 224 * file reference. You must also flush afterwards. */
 225void vhost_poll_stop(struct vhost_poll *poll)
 226{
 227        if (poll->wqh) {
 228                remove_wait_queue(poll->wqh, &poll->wait);
 229                poll->wqh = NULL;
 230        }
 231}
 232EXPORT_SYMBOL_GPL(vhost_poll_stop);
 233
 234void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
 235{
 236        struct vhost_flush_struct flush;
 237
 238        if (dev->worker) {
 239                init_completion(&flush.wait_event);
 240                vhost_work_init(&flush.work, vhost_flush_work);
 241
 242                vhost_work_queue(dev, &flush.work);
 243                wait_for_completion(&flush.wait_event);
 244        }
 245}
 246EXPORT_SYMBOL_GPL(vhost_work_flush);
 247
 248/* Flush any work that has been scheduled. When calling this, don't hold any
 249 * locks that are also used by the callback. */
 250void vhost_poll_flush(struct vhost_poll *poll)
 251{
 252        vhost_work_flush(poll->dev, &poll->work);
 253}
 254EXPORT_SYMBOL_GPL(vhost_poll_flush);
 255
 256void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
 257{
 258        if (!dev->worker)
 259                return;
 260
 261        if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
 262                /* We can only add the work to the list after we're
 263                 * sure it was not in the list.
 264                 * test_and_set_bit() implies a memory barrier.
 265                 */
 266                llist_add(&work->node, &dev->work_list);
 267                wake_up_process(dev->worker);
 268        }
 269}
 270EXPORT_SYMBOL_GPL(vhost_work_queue);
 271
 272/* A lockless hint for busy polling code to exit the loop */
 273bool vhost_has_work(struct vhost_dev *dev)
 274{
 275        return !llist_empty(&dev->work_list);
 276}
 277EXPORT_SYMBOL_GPL(vhost_has_work);
 278
 279void vhost_poll_queue(struct vhost_poll *poll)
 280{
 281        vhost_work_queue(poll->dev, &poll->work);
 282}
 283EXPORT_SYMBOL_GPL(vhost_poll_queue);
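
/* Usage sketch (hypothetical backend, modeled on how vhost-net consumes this
 * API; not part of this file): a backend supplies a handle_kick work function,
 * lets vhost_dev_init() below wire it into vq->poll, and vhost_vring_ioctl()
 * arms the poll on the kick eventfd via vhost_poll_start():
 *
 *	static void my_handle_kick(struct vhost_work *work)
 *	{
 *		struct vhost_virtqueue *vq =
 *			container_of(work, struct vhost_virtqueue, poll.work);
 *
 *		// ... drain and process descriptors for vq ...
 *	}
 *
 *	// setup:  vq->handle_kick = my_handle_kick; vhost_dev_init(dev, vqs, n);
 *	// deferred work outside a kick path: vhost_work_init(&w, fn);
 *	//                                    vhost_work_queue(dev, &w);
 */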
 284
 285static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
 286{
 287        int j;
 288
 289        for (j = 0; j < VHOST_NUM_ADDRS; j++)
 290                vq->meta_iotlb[j] = NULL;
 291}
 292
 293static void vhost_vq_meta_reset(struct vhost_dev *d)
 294{
 295        int i;
 296
 297        for (i = 0; i < d->nvqs; ++i)
 298                __vhost_vq_meta_reset(d->vqs[i]);
 299}
 300
 301static void vhost_vq_reset(struct vhost_dev *dev,
 302                           struct vhost_virtqueue *vq)
 303{
 304        vq->num = 1;
 305        vq->desc = NULL;
 306        vq->avail = NULL;
 307        vq->used = NULL;
 308        vq->last_avail_idx = 0;
 309        vq->avail_idx = 0;
 310        vq->last_used_idx = 0;
 311        vq->signalled_used = 0;
 312        vq->signalled_used_valid = false;
 313        vq->used_flags = 0;
 314        vq->log_used = false;
 315        vq->log_addr = -1ull;
 316        vq->private_data = NULL;
 317        vq->acked_features = 0;
 318        vq->log_base = NULL;
 319        vq->error_ctx = NULL;
 320        vq->kick = NULL;
 321        vq->call_ctx = NULL;
 322        vq->log_ctx = NULL;
 323        vhost_reset_is_le(vq);
 324        vhost_disable_cross_endian(vq);
 325        vq->busyloop_timeout = 0;
 326        vq->umem = NULL;
 327        vq->iotlb = NULL;
 328        __vhost_vq_meta_reset(vq);
 329}
 330
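/* Per-device worker thread: runs queued vhost_work items in submission order
 * (llist_add() pushes to the head, so the list is reversed before walking it)
 * with the owner's mm made current, so userspace accesses issued by work
 * functions resolve in the owning process' address space.
 */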
 331static int vhost_worker(void *data)
 332{
 333        struct vhost_dev *dev = data;
 334        struct vhost_work *work, *work_next;
 335        struct llist_node *node;
 336        mm_segment_t oldfs = get_fs();
 337
 338        set_fs(USER_DS);
 339        use_mm(dev->mm);
 340
 341        for (;;) {
 342                /* mb paired w/ kthread_stop */
 343                set_current_state(TASK_INTERRUPTIBLE);
 344
 345                if (kthread_should_stop()) {
 346                        __set_current_state(TASK_RUNNING);
 347                        break;
 348                }
 349
 350                node = llist_del_all(&dev->work_list);
 351                if (!node)
 352                        schedule();
 353
 354                node = llist_reverse_order(node);
 355                /* make sure flag is seen after deletion */
 356                smp_wmb();
 357                llist_for_each_entry_safe(work, work_next, node, node) {
 358                        clear_bit(VHOST_WORK_QUEUED, &work->flags);
 359                        __set_current_state(TASK_RUNNING);
 360                        work->fn(work);
 361                        if (need_resched())
 362                                schedule();
 363                }
 364        }
 365        unuse_mm(dev->mm);
 366        set_fs(oldfs);
 367        return 0;
 368}
 369
 370static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
 371{
 372        kfree(vq->indirect);
 373        vq->indirect = NULL;
 374        kfree(vq->log);
 375        vq->log = NULL;
 376        kfree(vq->heads);
 377        vq->heads = NULL;
 378}
 379
 380/* Helper to allocate iovec buffers for all vqs. */
 381static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
 382{
 383        struct vhost_virtqueue *vq;
 384        int i;
 385
 386        for (i = 0; i < dev->nvqs; ++i) {
 387                vq = dev->vqs[i];
 388                vq->indirect = kmalloc(sizeof *vq->indirect * UIO_MAXIOV,
 389                                       GFP_KERNEL);
 390                vq->log = kmalloc(sizeof *vq->log * UIO_MAXIOV, GFP_KERNEL);
 391                vq->heads = kmalloc(sizeof *vq->heads * UIO_MAXIOV, GFP_KERNEL);
 392                if (!vq->indirect || !vq->log || !vq->heads)
 393                        goto err_nomem;
 394        }
 395        return 0;
 396
 397err_nomem:
 398        for (; i >= 0; --i)
 399                vhost_vq_free_iovecs(dev->vqs[i]);
 400        return -ENOMEM;
 401}
 402
 403static void vhost_dev_free_iovecs(struct vhost_dev *dev)
 404{
 405        int i;
 406
 407        for (i = 0; i < dev->nvqs; ++i)
 408                vhost_vq_free_iovecs(dev->vqs[i]);
 409}
 410
 411void vhost_dev_init(struct vhost_dev *dev,
 412                    struct vhost_virtqueue **vqs, int nvqs)
 413{
 414        struct vhost_virtqueue *vq;
 415        int i;
 416
 417        dev->vqs = vqs;
 418        dev->nvqs = nvqs;
 419        mutex_init(&dev->mutex);
 420        dev->log_ctx = NULL;
 421        dev->umem = NULL;
 422        dev->iotlb = NULL;
 423        dev->mm = NULL;
 424        dev->worker = NULL;
 425        init_llist_head(&dev->work_list);
 426        init_waitqueue_head(&dev->wait);
 427        INIT_LIST_HEAD(&dev->read_list);
 428        INIT_LIST_HEAD(&dev->pending_list);
 429        spin_lock_init(&dev->iotlb_lock);
 430
 431
 432        for (i = 0; i < dev->nvqs; ++i) {
 433                vq = dev->vqs[i];
 434                vq->log = NULL;
 435                vq->indirect = NULL;
 436                vq->heads = NULL;
 437                vq->dev = dev;
 438                mutex_init(&vq->mutex);
 439                vhost_vq_reset(dev, vq);
 440                if (vq->handle_kick)
 441                        vhost_poll_init(&vq->poll, vq->handle_kick,
 442                                        EPOLLIN, dev);
 443        }
 444}
 445EXPORT_SYMBOL_GPL(vhost_dev_init);
 446
 447/* Caller should have device mutex */
 448long vhost_dev_check_owner(struct vhost_dev *dev)
 449{
 450        /* Are you the owner? If not, I don't think you mean to do that */
 451        return dev->mm == current->mm ? 0 : -EPERM;
 452}
 453EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
 454
 455struct vhost_attach_cgroups_struct {
 456        struct vhost_work work;
 457        struct task_struct *owner;
 458        int ret;
 459};
 460
 461static void vhost_attach_cgroups_work(struct vhost_work *work)
 462{
 463        struct vhost_attach_cgroups_struct *s;
 464
 465        s = container_of(work, struct vhost_attach_cgroups_struct, work);
 466        s->ret = cgroup_attach_task_all(s->owner, current);
 467}
 468
 469static int vhost_attach_cgroups(struct vhost_dev *dev)
 470{
 471        struct vhost_attach_cgroups_struct attach;
 472
 473        attach.owner = current;
 474        vhost_work_init(&attach.work, vhost_attach_cgroups_work);
 475        vhost_work_queue(dev, &attach.work);
 476        vhost_work_flush(dev, &attach.work);
 477        return attach.ret;
 478}
 479
 480/* Caller should have device mutex */
 481bool vhost_dev_has_owner(struct vhost_dev *dev)
 482{
 483        return dev->mm;
 484}
 485EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
 486
 487/* Caller should have device mutex */
 488long vhost_dev_set_owner(struct vhost_dev *dev)
 489{
 490        struct task_struct *worker;
 491        int err;
 492
 493        /* Is there an owner already? */
 494        if (vhost_dev_has_owner(dev)) {
 495                err = -EBUSY;
 496                goto err_mm;
 497        }
 498
 499        /* No owner, become one */
 500        dev->mm = get_task_mm(current);
 501        worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
 502        if (IS_ERR(worker)) {
 503                err = PTR_ERR(worker);
 504                goto err_worker;
 505        }
 506
 507        dev->worker = worker;
 508        wake_up_process(worker);        /* avoid contributing to loadavg */
 509
 510        err = vhost_attach_cgroups(dev);
 511        if (err)
 512                goto err_cgroup;
 513
 514        err = vhost_dev_alloc_iovecs(dev);
 515        if (err)
 516                goto err_cgroup;
 517
 518        return 0;
 519err_cgroup:
 520        kthread_stop(worker);
 521        dev->worker = NULL;
 522err_worker:
 523        if (dev->mm)
 524                mmput(dev->mm);
 525        dev->mm = NULL;
 526err_mm:
 527        return err;
 528}
 529EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
 530
 531struct vhost_umem *vhost_dev_reset_owner_prepare(void)
 532{
 533        return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL);
 534}
 535EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
 536
 537/* Caller should have device mutex */
 538void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
 539{
 540        int i;
 541
 542        vhost_dev_cleanup(dev);
 543
 544        /* Restore memory to default empty mapping. */
 545        INIT_LIST_HEAD(&umem->umem_list);
 546        dev->umem = umem;
 547        /* We don't need VQ locks below since vhost_dev_cleanup makes sure
 548         * VQs aren't running.
 549         */
 550        for (i = 0; i < dev->nvqs; ++i)
 551                dev->vqs[i]->umem = umem;
 552}
 553EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
 554
 555void vhost_dev_stop(struct vhost_dev *dev)
 556{
 557        int i;
 558
 559        for (i = 0; i < dev->nvqs; ++i) {
 560                if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
 561                        vhost_poll_stop(&dev->vqs[i]->poll);
 562                        vhost_poll_flush(&dev->vqs[i]->poll);
 563                }
 564        }
 565}
 566EXPORT_SYMBOL_GPL(vhost_dev_stop);
 567
 568static void vhost_umem_free(struct vhost_umem *umem,
 569                            struct vhost_umem_node *node)
 570{
 571        vhost_umem_interval_tree_remove(node, &umem->umem_tree);
 572        list_del(&node->link);
 573        kfree(node);
 574        umem->numem--;
 575}
 576
 577static void vhost_umem_clean(struct vhost_umem *umem)
 578{
 579        struct vhost_umem_node *node, *tmp;
 580
 581        if (!umem)
 582                return;
 583
 584        list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
 585                vhost_umem_free(umem, node);
 586
 587        kvfree(umem);
 588}
 589
 590static void vhost_clear_msg(struct vhost_dev *dev)
 591{
 592        struct vhost_msg_node *node, *n;
 593
 594        spin_lock(&dev->iotlb_lock);
 595
 596        list_for_each_entry_safe(node, n, &dev->read_list, node) {
 597                list_del(&node->node);
 598                kfree(node);
 599        }
 600
 601        list_for_each_entry_safe(node, n, &dev->pending_list, node) {
 602                list_del(&node->node);
 603                kfree(node);
 604        }
 605
 606        spin_unlock(&dev->iotlb_lock);
 607}
 608
 609void vhost_dev_cleanup(struct vhost_dev *dev)
 610{
 611        int i;
 612
 613        for (i = 0; i < dev->nvqs; ++i) {
 614                if (dev->vqs[i]->error_ctx)
 615                        eventfd_ctx_put(dev->vqs[i]->error_ctx);
 616                if (dev->vqs[i]->kick)
 617                        fput(dev->vqs[i]->kick);
 618                if (dev->vqs[i]->call_ctx)
 619                        eventfd_ctx_put(dev->vqs[i]->call_ctx);
 620                vhost_vq_reset(dev, dev->vqs[i]);
 621        }
 622        vhost_dev_free_iovecs(dev);
 623        if (dev->log_ctx)
 624                eventfd_ctx_put(dev->log_ctx);
 625        dev->log_ctx = NULL;
 626        /* No one will access memory at this point */
 627        vhost_umem_clean(dev->umem);
 628        dev->umem = NULL;
 629        vhost_umem_clean(dev->iotlb);
 630        dev->iotlb = NULL;
 631        vhost_clear_msg(dev);
 632        wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
 633        WARN_ON(!llist_empty(&dev->work_list));
 634        if (dev->worker) {
 635                kthread_stop(dev->worker);
 636                dev->worker = NULL;
 637        }
 638        if (dev->mm)
 639                mmput(dev->mm);
 640        dev->mm = NULL;
 641}
 642EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
 643
 644static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
 645{
 646        u64 a = addr / VHOST_PAGE_SIZE / 8;
 647
 648        /* Make sure 64 bit math will not overflow. */
 649        if (a > ULONG_MAX - (unsigned long)log_base ||
 650            a + (unsigned long)log_base > ULONG_MAX)
 651                return false;
 652
 653        return access_ok(VERIFY_WRITE, log_base + a,
 654                         (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
 655}
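
/* The dirty log is a userspace bitmap with one bit per VHOST_PAGE_SIZE (0x1000
 * bytes) page of guest memory, so addr / VHOST_PAGE_SIZE / 8 is the byte that
 * holds the bit for addr. Worked example: addr = 0x12345000 -> page 0x12345 ->
 * byte offset 0x2468, bit 5 within that byte (see log_write() further down).
 */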
 656
 657static bool vhost_overflow(u64 uaddr, u64 size)
 658{
 659        /* Make sure 64 bit math will not overflow. */
 660        return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
 661}
 662
 663/* Caller should have vq mutex and device mutex. */
 664static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
 665                                int log_all)
 666{
 667        struct vhost_umem_node *node;
 668
 669        if (!umem)
 670                return false;
 671
 672        list_for_each_entry(node, &umem->umem_list, link) {
 673                unsigned long a = node->userspace_addr;
 674
 675                if (vhost_overflow(node->userspace_addr, node->size))
 676                        return false;
 677
 678
 679                if (!access_ok(VERIFY_WRITE, (void __user *)a,
 680                                    node->size))
 681                        return false;
 682                else if (log_all && !log_access_ok(log_base,
 683                                                   node->start,
 684                                                   node->size))
 685                        return false;
 686        }
 687        return true;
 688}
 689
 690static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
 691                                               u64 addr, unsigned int size,
 692                                               int type)
 693{
 694        const struct vhost_umem_node *node = vq->meta_iotlb[type];
 695
 696        if (!node)
 697                return NULL;
 698
 699        return (void *)(uintptr_t)(node->userspace_addr + addr - node->start);
 700}
 701
 702/* Can we switch to this memory table? */
 703/* Caller should have device mutex but not vq mutex */
 704static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
 705                             int log_all)
 706{
 707        int i;
 708
 709        for (i = 0; i < d->nvqs; ++i) {
 710                bool ok;
 711                bool log;
 712
 713                mutex_lock(&d->vqs[i]->mutex);
 714                log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
 715                /* If ring is inactive, will check when it's enabled. */
 716                if (d->vqs[i]->private_data)
 717                        ok = vq_memory_access_ok(d->vqs[i]->log_base,
 718                                                 umem, log);
 719                else
 720                        ok = true;
 721                mutex_unlock(&d->vqs[i]->mutex);
 722                if (!ok)
 723                        return false;
 724        }
 725        return true;
 726}
 727
 728static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
 729                          struct iovec iov[], int iov_size, int access);
 730
 731static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
 732                              const void *from, unsigned size)
 733{
 734        int ret;
 735
 736        if (!vq->iotlb)
 737                return __copy_to_user(to, from, size);
 738        else {
  739                /* This function should be called after iotlb
  740                 * prefetch, which means we're sure that all vq
  741                 * memory can be accessed through the iotlb, so
  742                 * -EAGAIN should not happen in this case.
  743                 */
 744                struct iov_iter t;
 745                void __user *uaddr = vhost_vq_meta_fetch(vq,
 746                                     (u64)(uintptr_t)to, size,
 747                                     VHOST_ADDR_USED);
 748
 749                if (uaddr)
 750                        return __copy_to_user(uaddr, from, size);
 751
 752                ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
 753                                     ARRAY_SIZE(vq->iotlb_iov),
 754                                     VHOST_ACCESS_WO);
 755                if (ret < 0)
 756                        goto out;
 757                iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
 758                ret = copy_to_iter(from, size, &t);
 759                if (ret == size)
 760                        ret = 0;
 761        }
 762out:
 763        return ret;
 764}
 765
 766static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
 767                                void __user *from, unsigned size)
 768{
 769        int ret;
 770
 771        if (!vq->iotlb)
 772                return __copy_from_user(to, from, size);
 773        else {
  774                /* This function should be called after iotlb
  775                 * prefetch, which means we're sure that the vq
  776                 * can be accessed through the iotlb, so -EAGAIN
  777                 * should not happen in this case.
  778                 */
 779                void __user *uaddr = vhost_vq_meta_fetch(vq,
 780                                     (u64)(uintptr_t)from, size,
 781                                     VHOST_ADDR_DESC);
 782                struct iov_iter f;
 783
 784                if (uaddr)
 785                        return __copy_from_user(to, uaddr, size);
 786
 787                ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
 788                                     ARRAY_SIZE(vq->iotlb_iov),
 789                                     VHOST_ACCESS_RO);
 790                if (ret < 0) {
 791                        vq_err(vq, "IOTLB translation failure: uaddr "
 792                               "%p size 0x%llx\n", from,
 793                               (unsigned long long) size);
 794                        goto out;
 795                }
 796                iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
 797                ret = copy_from_iter(to, size, &f);
 798                if (ret == size)
 799                        ret = 0;
 800        }
 801
 802out:
 803        return ret;
 804}
 805
 806static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
 807                                          void __user *addr, unsigned int size,
 808                                          int type)
 809{
 810        int ret;
 811
 812        ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
 813                             ARRAY_SIZE(vq->iotlb_iov),
 814                             VHOST_ACCESS_RO);
 815        if (ret < 0) {
 816                vq_err(vq, "IOTLB translation failure: uaddr "
 817                        "%p size 0x%llx\n", addr,
 818                        (unsigned long long) size);
 819                return NULL;
 820        }
 821
 822        if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
 823                vq_err(vq, "Non atomic userspace memory access: uaddr "
 824                        "%p size 0x%llx\n", addr,
 825                        (unsigned long long) size);
 826                return NULL;
 827        }
 828
 829        return vq->iotlb_iov[0].iov_base;
 830}
 831
  832/* This function should be called after iotlb
  833 * prefetch, which means we're sure that the vq
  834 * can be accessed through the iotlb, so -EAGAIN
  835 * should not happen in this case.
  836 */
 837static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
 838                                            void *addr, unsigned int size,
 839                                            int type)
 840{
 841        void __user *uaddr = vhost_vq_meta_fetch(vq,
 842                             (u64)(uintptr_t)addr, size, type);
 843        if (uaddr)
 844                return uaddr;
 845
 846        return __vhost_get_user_slow(vq, addr, size, type);
 847}
 848
 849#define vhost_put_user(vq, x, ptr)              \
 850({ \
 851        int ret = -EFAULT; \
 852        if (!vq->iotlb) { \
 853                ret = __put_user(x, ptr); \
 854        } else { \
 855                __typeof__(ptr) to = \
 856                        (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
 857                                          sizeof(*ptr), VHOST_ADDR_USED); \
 858                if (to != NULL) \
 859                        ret = __put_user(x, to); \
 860                else \
 861                        ret = -EFAULT;  \
 862        } \
 863        ret; \
 864})
 865
 866#define vhost_get_user(vq, x, ptr, type)                \
 867({ \
 868        int ret; \
 869        if (!vq->iotlb) { \
 870                ret = __get_user(x, ptr); \
 871        } else { \
 872                __typeof__(ptr) from = \
 873                        (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
 874                                                           sizeof(*ptr), \
 875                                                           type); \
 876                if (from != NULL) \
 877                        ret = __get_user(x, from); \
 878                else \
 879                        ret = -EFAULT; \
 880        } \
 881        ret; \
 882})
 883
 884#define vhost_get_avail(vq, x, ptr) \
 885        vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)
 886
 887#define vhost_get_used(vq, x, ptr) \
 888        vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
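
/* Typical use (sketch, modeled on the ring accessors later in this file):
 * callers pass a pointer into the guest-visible ring, which is either
 * dereferenced directly (__get_user()/__put_user()) or first translated
 * through the device IOTLB by __vhost_get_user():
 *
 *	__virtio16 idx;
 *
 *	if (vhost_get_avail(vq, idx, &vq->avail->idx))
 *		return -EFAULT;
 *	...
 *	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
 *			   &vq->used->idx))
 *		vq_err(vq, "Failed to update used idx");
 */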
 889
 890static void vhost_dev_lock_vqs(struct vhost_dev *d)
 891{
 892        int i = 0;
 893        for (i = 0; i < d->nvqs; ++i)
 894                mutex_lock_nested(&d->vqs[i]->mutex, i);
 895}
 896
 897static void vhost_dev_unlock_vqs(struct vhost_dev *d)
 898{
 899        int i = 0;
 900        for (i = 0; i < d->nvqs; ++i)
 901                mutex_unlock(&d->vqs[i]->mutex);
 902}
 903
 904static int vhost_new_umem_range(struct vhost_umem *umem,
 905                                u64 start, u64 size, u64 end,
 906                                u64 userspace_addr, int perm)
 907{
 908        struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
 909
 910        if (!node)
 911                return -ENOMEM;
 912
 913        if (umem->numem == max_iotlb_entries) {
 914                tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
 915                vhost_umem_free(umem, tmp);
 916        }
 917
 918        node->start = start;
 919        node->size = size;
 920        node->last = end;
 921        node->userspace_addr = userspace_addr;
 922        node->perm = perm;
 923        INIT_LIST_HEAD(&node->link);
 924        list_add_tail(&node->link, &umem->umem_list);
 925        vhost_umem_interval_tree_insert(node, &umem->umem_tree);
 926        umem->numem++;
 927
 928        return 0;
 929}
 930
 931static void vhost_del_umem_range(struct vhost_umem *umem,
 932                                 u64 start, u64 end)
 933{
 934        struct vhost_umem_node *node;
 935
 936        while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
 937                                                           start, end)))
 938                vhost_umem_free(umem, node);
 939}
 940
 941static void vhost_iotlb_notify_vq(struct vhost_dev *d,
 942                                  struct vhost_iotlb_msg *msg)
 943{
 944        struct vhost_msg_node *node, *n;
 945
 946        spin_lock(&d->iotlb_lock);
 947
 948        list_for_each_entry_safe(node, n, &d->pending_list, node) {
 949                struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
 950                if (msg->iova <= vq_msg->iova &&
  951                    msg->iova + msg->size - 1 >= vq_msg->iova &&
 952                    vq_msg->type == VHOST_IOTLB_MISS) {
 953                        vhost_poll_queue(&node->vq->poll);
 954                        list_del(&node->node);
 955                        kfree(node);
 956                }
 957        }
 958
 959        spin_unlock(&d->iotlb_lock);
 960}
 961
 962static bool umem_access_ok(u64 uaddr, u64 size, int access)
 963{
 964        unsigned long a = uaddr;
 965
 966        /* Make sure 64 bit math will not overflow. */
 967        if (vhost_overflow(uaddr, size))
 968                return false;
 969
 970        if ((access & VHOST_ACCESS_RO) &&
 971            !access_ok(VERIFY_READ, (void __user *)a, size))
 972                return false;
 973        if ((access & VHOST_ACCESS_WO) &&
 974            !access_ok(VERIFY_WRITE, (void __user *)a, size))
 975                return false;
 976        return true;
 977}
 978
 979static int vhost_process_iotlb_msg(struct vhost_dev *dev,
 980                                   struct vhost_iotlb_msg *msg)
 981{
 982        int ret = 0;
 983
 984        mutex_lock(&dev->mutex);
 985        vhost_dev_lock_vqs(dev);
 986        switch (msg->type) {
 987        case VHOST_IOTLB_UPDATE:
 988                if (!dev->iotlb) {
 989                        ret = -EFAULT;
 990                        break;
 991                }
 992                if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
 993                        ret = -EFAULT;
 994                        break;
 995                }
 996                vhost_vq_meta_reset(dev);
 997                if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
 998                                         msg->iova + msg->size - 1,
 999                                         msg->uaddr, msg->perm)) {
1000                        ret = -ENOMEM;
1001                        break;
1002                }
1003                vhost_iotlb_notify_vq(dev, msg);
1004                break;
1005        case VHOST_IOTLB_INVALIDATE:
1006                if (!dev->iotlb) {
1007                        ret = -EFAULT;
1008                        break;
1009                }
1010                vhost_vq_meta_reset(dev);
1011                vhost_del_umem_range(dev->iotlb, msg->iova,
1012                                     msg->iova + msg->size - 1);
1013                break;
1014        default:
1015                ret = -EINVAL;
1016                break;
1017        }
1018
1019        vhost_dev_unlock_vqs(dev);
1020        mutex_unlock(&dev->mutex);
1021
1022        return ret;
1023}
1024ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1025                             struct iov_iter *from)
1026{
1027        struct vhost_msg_node node;
1028        unsigned size = sizeof(struct vhost_msg);
1029        size_t ret;
1030        int err;
1031
1032        if (iov_iter_count(from) < size)
1033                return 0;
1034        ret = copy_from_iter(&node.msg, size, from);
1035        if (ret != size)
1036                goto done;
1037
1038        switch (node.msg.type) {
1039        case VHOST_IOTLB_MSG:
1040                err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
1041                if (err)
1042                        ret = err;
1043                break;
1044        default:
1045                ret = -EINVAL;
1046                break;
1047        }
1048
1049done:
1050        return ret;
1051}
1052EXPORT_SYMBOL(vhost_chr_write_iter);
1053
1054__poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1055                            poll_table *wait)
1056{
1057        __poll_t mask = 0;
1058
1059        poll_wait(file, &dev->wait, wait);
1060
1061        if (!list_empty(&dev->read_list))
1062                mask |= EPOLLIN | EPOLLRDNORM;
1063
1064        return mask;
1065}
1066EXPORT_SYMBOL(vhost_chr_poll);
1067
1068ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1069                            int noblock)
1070{
1071        DEFINE_WAIT(wait);
1072        struct vhost_msg_node *node;
1073        ssize_t ret = 0;
1074        unsigned size = sizeof(struct vhost_msg);
1075
1076        if (iov_iter_count(to) < size)
1077                return 0;
1078
1079        while (1) {
1080                if (!noblock)
1081                        prepare_to_wait(&dev->wait, &wait,
1082                                        TASK_INTERRUPTIBLE);
1083
1084                node = vhost_dequeue_msg(dev, &dev->read_list);
1085                if (node)
1086                        break;
1087                if (noblock) {
1088                        ret = -EAGAIN;
1089                        break;
1090                }
1091                if (signal_pending(current)) {
1092                        ret = -ERESTARTSYS;
1093                        break;
1094                }
1095                if (!dev->iotlb) {
1096                        ret = -EBADFD;
1097                        break;
1098                }
1099
1100                schedule();
1101        }
1102
1103        if (!noblock)
1104                finish_wait(&dev->wait, &wait);
1105
1106        if (node) {
1107                ret = copy_to_iter(&node->msg, size, to);
1108
1109                if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
1110                        kfree(node);
1111                        return ret;
1112                }
1113
1114                vhost_enqueue_msg(dev, &dev->pending_list, node);
1115        }
1116
1117        return ret;
1118}
1119EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
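
/* Userspace side of the miss/update protocol (sketch; vhost_fd, region_size
 * and mapping_for() are hypothetical): the read()/write() paths above are the
 * kernel half; a device emulator reads VHOST_IOTLB_MISS messages from the
 * vhost fd and answers by writing a VHOST_IOTLB_UPDATE back:
 *
 *	struct vhost_msg miss, reply;
 *
 *	read(vhost_fd, &miss, sizeof(miss));	// blocks until a miss arrives
 *	reply.type = VHOST_IOTLB_MSG;
 *	reply.iotlb = (struct vhost_iotlb_msg) {
 *		.iova  = miss.iotlb.iova,
 *		.size  = region_size,
 *		.uaddr = (__u64)(uintptr_t)mapping_for(miss.iotlb.iova),
 *		.perm  = VHOST_ACCESS_RW,
 *		.type  = VHOST_IOTLB_UPDATE,
 *	};
 *	write(vhost_fd, &reply, sizeof(reply));
 */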
1120
1121static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1122{
1123        struct vhost_dev *dev = vq->dev;
1124        struct vhost_msg_node *node;
1125        struct vhost_iotlb_msg *msg;
1126
1127        node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
1128        if (!node)
1129                return -ENOMEM;
1130
1131        msg = &node->msg.iotlb;
1132        msg->type = VHOST_IOTLB_MISS;
1133        msg->iova = iova;
1134        msg->perm = access;
1135
1136        vhost_enqueue_msg(dev, &dev->read_list, node);
1137
1138        return 0;
1139}
1140
1141static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1142                         struct vring_desc __user *desc,
1143                         struct vring_avail __user *avail,
1144                         struct vring_used __user *used)
1145
1146{
1147        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1148
1149        return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
1150               access_ok(VERIFY_READ, avail,
1151                         sizeof *avail + num * sizeof *avail->ring + s) &&
1152               access_ok(VERIFY_WRITE, used,
1153                        sizeof *used + num * sizeof *used->ring + s);
1154}
1155
1156static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1157                                 const struct vhost_umem_node *node,
1158                                 int type)
1159{
1160        int access = (type == VHOST_ADDR_USED) ?
1161                     VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1162
1163        if (likely(node->perm & access))
1164                vq->meta_iotlb[type] = node;
1165}
1166
1167static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1168                            int access, u64 addr, u64 len, int type)
1169{
1170        const struct vhost_umem_node *node;
1171        struct vhost_umem *umem = vq->iotlb;
1172        u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1173
1174        if (vhost_vq_meta_fetch(vq, addr, len, type))
1175                return true;
1176
1177        while (len > s) {
1178                node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1179                                                           addr,
1180                                                           last);
1181                if (node == NULL || node->start > addr) {
1182                        vhost_iotlb_miss(vq, addr, access);
1183                        return false;
1184                } else if (!(node->perm & access)) {
 1185                        /* Report the possible access violation by
 1186                         * requesting another translation from userspace.
 1187                         */
1188                        return false;
1189                }
1190
1191                size = node->size - addr + node->start;
1192
1193                if (orig_addr == addr && size >= len)
1194                        vhost_vq_meta_update(vq, node, type);
1195
1196                s += size;
1197                addr += size;
1198        }
1199
1200        return true;
1201}
1202
1203int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
1204{
1205        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1206        unsigned int num = vq->num;
1207
1208        if (!vq->iotlb)
1209                return 1;
1210
1211        return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
1212                               num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
1213               iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
1214                               sizeof *vq->avail +
1215                               num * sizeof(*vq->avail->ring) + s,
1216                               VHOST_ADDR_AVAIL) &&
1217               iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
1218                               sizeof *vq->used +
1219                               num * sizeof(*vq->used->ring) + s,
1220                               VHOST_ADDR_USED);
1221}
1222EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
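
/* Backends typically call this at the top of their handlers (vhost-net style):
 *
 *	if (!vq_iotlb_prefetch(vq))
 *		goto out;	// miss queued; the handler is re-run once
 *				// userspace replies with an IOTLB update
 *
 * iotlb_access_ok() queues the VHOST_IOTLB_MISS message and
 * vhost_iotlb_notify_vq() re-polls the vq when the matching update arrives.
 */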
1223
1224/* Can we log writes? */
1225/* Caller should have device mutex but not vq mutex */
1226bool vhost_log_access_ok(struct vhost_dev *dev)
1227{
1228        return memory_access_ok(dev, dev->umem, 1);
1229}
1230EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1231
1232/* Verify access for write logging. */
1233/* Caller should have vq mutex and device mutex */
1234static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1235                             void __user *log_base)
1236{
1237        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
1238
1239        return vq_memory_access_ok(log_base, vq->umem,
1240                                   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1241                (!vq->log_used || log_access_ok(log_base, vq->log_addr,
1242                                        sizeof *vq->used +
1243                                        vq->num * sizeof *vq->used->ring + s));
1244}
1245
1246/* Can we start vq? */
1247/* Caller should have vq mutex and device mutex */
1248bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1249{
1250        if (!vq_log_access_ok(vq, vq->log_base))
1251                return false;
1252
1253        /* Access validation occurs at prefetch time with IOTLB */
1254        if (vq->iotlb)
1255                return true;
1256
1257        return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1258}
1259EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1260
1261static struct vhost_umem *vhost_umem_alloc(void)
1262{
1263        struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL);
1264
1265        if (!umem)
1266                return NULL;
1267
1268        umem->umem_tree = RB_ROOT_CACHED;
1269        umem->numem = 0;
1270        INIT_LIST_HEAD(&umem->umem_list);
1271
1272        return umem;
1273}
1274
1275static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1276{
1277        struct vhost_memory mem, *newmem;
1278        struct vhost_memory_region *region;
1279        struct vhost_umem *newumem, *oldumem;
1280        unsigned long size = offsetof(struct vhost_memory, regions);
1281        int i;
1282
1283        if (copy_from_user(&mem, m, size))
1284                return -EFAULT;
1285        if (mem.padding)
1286                return -EOPNOTSUPP;
1287        if (mem.nregions > max_mem_regions)
1288                return -E2BIG;
1289        newmem = kvzalloc(size + mem.nregions * sizeof(*m->regions), GFP_KERNEL);
1290        if (!newmem)
1291                return -ENOMEM;
1292
1293        memcpy(newmem, &mem, size);
1294        if (copy_from_user(newmem->regions, m->regions,
1295                           mem.nregions * sizeof *m->regions)) {
1296                kvfree(newmem);
1297                return -EFAULT;
1298        }
1299
1300        newumem = vhost_umem_alloc();
1301        if (!newumem) {
1302                kvfree(newmem);
1303                return -ENOMEM;
1304        }
1305
1306        for (region = newmem->regions;
1307             region < newmem->regions + mem.nregions;
1308             region++) {
1309                if (vhost_new_umem_range(newumem,
1310                                         region->guest_phys_addr,
1311                                         region->memory_size,
1312                                         region->guest_phys_addr +
1313                                         region->memory_size - 1,
1314                                         region->userspace_addr,
1315                                         VHOST_ACCESS_RW))
1316                        goto err;
1317        }
1318
1319        if (!memory_access_ok(d, newumem, 0))
1320                goto err;
1321
1322        oldumem = d->umem;
1323        d->umem = newumem;
1324
1325        /* All memory accesses are done under some VQ mutex. */
1326        for (i = 0; i < d->nvqs; ++i) {
1327                mutex_lock(&d->vqs[i]->mutex);
1328                d->vqs[i]->umem = newumem;
1329                mutex_unlock(&d->vqs[i]->mutex);
1330        }
1331
1332        kvfree(newmem);
1333        vhost_umem_clean(oldumem);
1334        return 0;
1335
1336err:
1337        vhost_umem_clean(newumem);
1338        kvfree(newmem);
1339        return -EFAULT;
1340}
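
/* Userspace sketch (guest_ram_mmap and guest_ram_size are hypothetical): the
 * owner process describes guest memory as guest-physical -> host-virtual
 * regions and hands the table to the kernel, which turns it into the umem
 * interval tree used for address translation:
 *
 *	struct vhost_memory *mem =
 *		calloc(1, sizeof(*mem) + sizeof(struct vhost_memory_region));
 *
 *	mem->nregions = 1;
 *	mem->regions[0] = (struct vhost_memory_region) {
 *		.guest_phys_addr = 0,
 *		.memory_size     = guest_ram_size,
 *		.userspace_addr  = (__u64)(uintptr_t)guest_ram_mmap,
 *	};
 *	if (ioctl(vhost_fd, VHOST_SET_MEM_TABLE, mem) < 0)
 *		err(1, "VHOST_SET_MEM_TABLE");
 */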
1341
1342long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1343{
1344        struct file *eventfp, *filep = NULL;
1345        bool pollstart = false, pollstop = false;
1346        struct eventfd_ctx *ctx = NULL;
1347        u32 __user *idxp = argp;
1348        struct vhost_virtqueue *vq;
1349        struct vhost_vring_state s;
1350        struct vhost_vring_file f;
1351        struct vhost_vring_addr a;
1352        u32 idx;
1353        long r;
1354
1355        r = get_user(idx, idxp);
1356        if (r < 0)
1357                return r;
1358        if (idx >= d->nvqs)
1359                return -ENOBUFS;
1360
1361        vq = d->vqs[idx];
1362
1363        mutex_lock(&vq->mutex);
1364
1365        switch (ioctl) {
1366        case VHOST_SET_VRING_NUM:
1367                /* Resizing ring with an active backend?
1368                 * You don't want to do that. */
1369                if (vq->private_data) {
1370                        r = -EBUSY;
1371                        break;
1372                }
1373                if (copy_from_user(&s, argp, sizeof s)) {
1374                        r = -EFAULT;
1375                        break;
1376                }
1377                if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
1378                        r = -EINVAL;
1379                        break;
1380                }
1381                vq->num = s.num;
1382                break;
1383        case VHOST_SET_VRING_BASE:
1384                /* Moving base with an active backend?
1385                 * You don't want to do that. */
1386                if (vq->private_data) {
1387                        r = -EBUSY;
1388                        break;
1389                }
1390                if (copy_from_user(&s, argp, sizeof s)) {
1391                        r = -EFAULT;
1392                        break;
1393                }
1394                if (s.num > 0xffff) {
1395                        r = -EINVAL;
1396                        break;
1397                }
1398                vq->last_avail_idx = s.num;
1399                /* Forget the cached index value. */
1400                vq->avail_idx = vq->last_avail_idx;
1401                break;
1402        case VHOST_GET_VRING_BASE:
1403                s.index = idx;
1404                s.num = vq->last_avail_idx;
1405                if (copy_to_user(argp, &s, sizeof s))
1406                        r = -EFAULT;
1407                break;
1408        case VHOST_SET_VRING_ADDR:
1409                if (copy_from_user(&a, argp, sizeof a)) {
1410                        r = -EFAULT;
1411                        break;
1412                }
1413                if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
1414                        r = -EOPNOTSUPP;
1415                        break;
1416                }
1417                /* For 32bit, verify that the top 32bits of the user
1418                   data are set to zero. */
1419                if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
1420                    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
1421                    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
1422                        r = -EFAULT;
1423                        break;
1424                }
1425
1426                /* Make sure it's safe to cast pointers to vring types. */
1427                BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
1428                BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
1429                if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
1430                    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
1431                    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) {
1432                        r = -EINVAL;
1433                        break;
1434                }
1435
 1436                /* We only verify access here if the backend is configured.
 1437                 * If it is not, we skip the check, as the size might not have
 1438                 * been set up yet; we will verify once the backend is configured. */
1439                if (vq->private_data) {
1440                        if (!vq_access_ok(vq, vq->num,
1441                                (void __user *)(unsigned long)a.desc_user_addr,
1442                                (void __user *)(unsigned long)a.avail_user_addr,
1443                                (void __user *)(unsigned long)a.used_user_addr)) {
1444                                r = -EINVAL;
1445                                break;
1446                        }
1447
1448                        /* Also validate log access for used ring if enabled. */
1449                        if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
1450                            !log_access_ok(vq->log_base, a.log_guest_addr,
1451                                           sizeof *vq->used +
1452                                           vq->num * sizeof *vq->used->ring)) {
1453                                r = -EINVAL;
1454                                break;
1455                        }
1456                }
1457
1458                vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
1459                vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
1460                vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
1461                vq->log_addr = a.log_guest_addr;
1462                vq->used = (void __user *)(unsigned long)a.used_user_addr;
1463                break;
1464        case VHOST_SET_VRING_KICK:
1465                if (copy_from_user(&f, argp, sizeof f)) {
1466                        r = -EFAULT;
1467                        break;
1468                }
1469                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
1470                if (IS_ERR(eventfp)) {
1471                        r = PTR_ERR(eventfp);
1472                        break;
1473                }
1474                if (eventfp != vq->kick) {
1475                        pollstop = (filep = vq->kick) != NULL;
1476                        pollstart = (vq->kick = eventfp) != NULL;
1477                } else
1478                        filep = eventfp;
1479                break;
1480        case VHOST_SET_VRING_CALL:
1481                if (copy_from_user(&f, argp, sizeof f)) {
1482                        r = -EFAULT;
1483                        break;
1484                }
1485                ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
1486                if (IS_ERR(ctx)) {
1487                        r = PTR_ERR(ctx);
1488                        break;
1489                }
1490                swap(ctx, vq->call_ctx);
1491                break;
1492        case VHOST_SET_VRING_ERR:
1493                if (copy_from_user(&f, argp, sizeof f)) {
1494                        r = -EFAULT;
1495                        break;
1496                }
1497                ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
1498                if (IS_ERR(ctx)) {
1499                        r = PTR_ERR(ctx);
1500                        break;
1501                }
1502                swap(ctx, vq->error_ctx);
1503                break;
1504        case VHOST_SET_VRING_ENDIAN:
1505                r = vhost_set_vring_endian(vq, argp);
1506                break;
1507        case VHOST_GET_VRING_ENDIAN:
1508                r = vhost_get_vring_endian(vq, idx, argp);
1509                break;
1510        case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
1511                if (copy_from_user(&s, argp, sizeof(s))) {
1512                        r = -EFAULT;
1513                        break;
1514                }
1515                vq->busyloop_timeout = s.num;
1516                break;
1517        case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
1518                s.index = idx;
1519                s.num = vq->busyloop_timeout;
1520                if (copy_to_user(argp, &s, sizeof(s)))
1521                        r = -EFAULT;
1522                break;
1523        default:
1524                r = -ENOIOCTLCMD;
1525        }
1526
1527        if (pollstop && vq->handle_kick)
1528                vhost_poll_stop(&vq->poll);
1529
1530        if (!IS_ERR_OR_NULL(ctx))
1531                eventfd_ctx_put(ctx);
1532        if (filep)
1533                fput(filep);
1534
1535        if (pollstart && vq->handle_kick)
1536                r = vhost_poll_start(&vq->poll, vq->kick);
1537
1538        mutex_unlock(&vq->mutex);
1539
1540        if (pollstop && vq->handle_kick)
1541                vhost_poll_flush(&vq->poll);
1542        return r;
1543}
1544EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
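
/* Userspace setup sequence (sketch; desc/avail/used, kick_fd and call_fd are
 * hypothetical values the VMM already holds): each vring is sized, positioned
 * and wired to its eventfds through this ioctl:
 *
 *	struct vhost_vring_state num  = { .index = 0, .num = 256 };
 *	struct vhost_vring_state base = { .index = 0, .num = 0 };
 *	struct vhost_vring_addr  addr = {
 *		.index           = 0,
 *		.desc_user_addr  = (__u64)(uintptr_t)desc,
 *		.avail_user_addr = (__u64)(uintptr_t)avail,
 *		.used_user_addr  = (__u64)(uintptr_t)used,
 *	};
 *	struct vhost_vring_file kick = { .index = 0, .fd = kick_fd };
 *	struct vhost_vring_file call = { .index = 0, .fd = call_fd };
 *
 *	ioctl(vhost_fd, VHOST_SET_VRING_NUM,  &num);
 *	ioctl(vhost_fd, VHOST_SET_VRING_BASE, &base);
 *	ioctl(vhost_fd, VHOST_SET_VRING_ADDR, &addr);
 *	ioctl(vhost_fd, VHOST_SET_VRING_KICK, &kick);
 *	ioctl(vhost_fd, VHOST_SET_VRING_CALL, &call);
 */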
1545
1546int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
1547{
1548        struct vhost_umem *niotlb, *oiotlb;
1549        int i;
1550
1551        niotlb = vhost_umem_alloc();
1552        if (!niotlb)
1553                return -ENOMEM;
1554
1555        oiotlb = d->iotlb;
1556        d->iotlb = niotlb;
1557
1558        for (i = 0; i < d->nvqs; ++i) {
1559                mutex_lock(&d->vqs[i]->mutex);
1560                d->vqs[i]->iotlb = niotlb;
1561                mutex_unlock(&d->vqs[i]->mutex);
1562        }
1563
1564        vhost_umem_clean(oiotlb);
1565
1566        return 0;
1567}
1568EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
1569
1570/* Caller must have device mutex */
1571long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1572{
1573        struct eventfd_ctx *ctx;
1574        u64 p;
1575        long r;
1576        int i, fd;
1577
1578        /* If you are not the owner, you can become one */
1579        if (ioctl == VHOST_SET_OWNER) {
1580                r = vhost_dev_set_owner(d);
1581                goto done;
1582        }
1583
1584        /* You must be the owner to do anything else */
1585        r = vhost_dev_check_owner(d);
1586        if (r)
1587                goto done;
1588
1589        switch (ioctl) {
1590        case VHOST_SET_MEM_TABLE:
1591                r = vhost_set_memory(d, argp);
1592                break;
1593        case VHOST_SET_LOG_BASE:
1594                if (copy_from_user(&p, argp, sizeof p)) {
1595                        r = -EFAULT;
1596                        break;
1597                }
1598                if ((u64)(unsigned long)p != p) {
1599                        r = -EFAULT;
1600                        break;
1601                }
1602                for (i = 0; i < d->nvqs; ++i) {
1603                        struct vhost_virtqueue *vq;
1604                        void __user *base = (void __user *)(unsigned long)p;
1605                        vq = d->vqs[i];
1606                        mutex_lock(&vq->mutex);
1607                        /* If ring is inactive, will check when it's enabled. */
1608                        if (vq->private_data && !vq_log_access_ok(vq, base))
1609                                r = -EFAULT;
1610                        else
1611                                vq->log_base = base;
1612                        mutex_unlock(&vq->mutex);
1613                }
1614                break;
1615        case VHOST_SET_LOG_FD:
1616                r = get_user(fd, (int __user *)argp);
1617                if (r < 0)
1618                        break;
1619                ctx = fd == -1 ? NULL : eventfd_ctx_fdget(fd);
1620                if (IS_ERR(ctx)) {
1621                        r = PTR_ERR(ctx);
1622                        break;
1623                }
1624                swap(ctx, d->log_ctx);
1625                for (i = 0; i < d->nvqs; ++i) {
1626                        mutex_lock(&d->vqs[i]->mutex);
1627                        d->vqs[i]->log_ctx = d->log_ctx;
1628                        mutex_unlock(&d->vqs[i]->mutex);
1629                }
1630                if (ctx)
1631                        eventfd_ctx_put(ctx);
1632                break;
1633        default:
1634                r = -ENOIOCTLCMD;
1635                break;
1636        }
1637done:
1638        return r;
1639}
1640EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
1641
1642/* TODO: This is really inefficient.  We need something like get_user()
1643 * (instruction directly accesses the data, with an exception table entry
1644 * returning -EFAULT). See Documentation/x86/exception-tables.txt.
1645 */
1646static int set_bit_to_user(int nr, void __user *addr)
1647{
1648        unsigned long log = (unsigned long)addr;
1649        struct page *page;
1650        void *base;
1651        int bit = nr + (log % PAGE_SIZE) * 8;
1652        int r;
1653
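            /*
             * The bit index computed above is relative to the start of the
             * page containing addr: nr selects the bit within the byte and
             * (log % PAGE_SIZE) * 8 is that byte's offset in bits.  Pin the
             * single page, set the bit through a temporary mapping, then
             * mark the page dirty and unpin it.
             */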
1654        r = get_user_pages_fast(log, 1, 1, &page);
1655        if (r < 0)
1656                return r;
1657        BUG_ON(r != 1);
1658        base = kmap_atomic(page);
1659        set_bit(bit, base);
1660        kunmap_atomic(base);
1661        set_page_dirty_lock(page);
1662        put_page(page);
1663        return 0;
1664}
1665
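    /* The dirty log is a userspace bitmap at log_base with one bit per
     * VHOST_PAGE_SIZE bytes of logged address space: set the bit for every
     * page touched by a write of write_length bytes at write_address. */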
1666static int log_write(void __user *log_base,
1667                     u64 write_address, u64 write_length)
1668{
1669        u64 write_page = write_address / VHOST_PAGE_SIZE;
1670        int r;
1671
1672        if (!write_length)
1673                return 0;
1674        write_length += write_address % VHOST_PAGE_SIZE;
1675        for (;;) {
1676                u64 base = (u64)(unsigned long)log_base;
1677                u64 log = base + write_page / 8;
1678                int bit = write_page % 8;
1679                if ((u64)(unsigned long)log != log)
1680                        return -EFAULT;
1681                r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
1682                if (r < 0)
1683                        return r;
1684                if (write_length <= VHOST_PAGE_SIZE)
1685                        break;
1686                write_length -= VHOST_PAGE_SIZE;
1687                write_page += 1;
1688        }
1689        return r;
1690}
1691
1692int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
1693                    unsigned int log_num, u64 len)
1694{
1695        int i, r;
1696
1697        /* Make sure data written is seen before log. */
1698        smp_wmb();
1699        for (i = 0; i < log_num; ++i) {
1700                u64 l = min(log[i].len, len);
1701                r = log_write(vq->log_base, log[i].addr, l);
1702                if (r < 0)
1703                        return r;
1704                len -= l;
1705                if (!len) {
1706                        if (vq->log_ctx)
1707                                eventfd_signal(vq->log_ctx, 1);
1708                        return 0;
1709                }
1710        }
1711        /* Length written exceeds what we have stored. This is a bug. */
1712        BUG();
1713        return 0;
1714}
1715EXPORT_SYMBOL_GPL(vhost_log_write);
1716
1717static int vhost_update_used_flags(struct vhost_virtqueue *vq)
1718{
1719        void __user *used;
1720        if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
1721                           &vq->used->flags) < 0)
1722                return -EFAULT;
1723        if (unlikely(vq->log_used)) {
1724                /* Make sure the flag is seen before log. */
1725                smp_wmb();
1726                /* Log used flag write. */
1727                used = &vq->used->flags;
1728                log_write(vq->log_base, vq->log_addr +
1729                          (used - (void __user *)vq->used),
1730                          sizeof vq->used->flags);
1731                if (vq->log_ctx)
1732                        eventfd_signal(vq->log_ctx, 1);
1733        }
1734        return 0;
1735}
1736
1737static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
1738{
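            /* Note: the avail_event argument is currently unused; the vq's
             * cached avail_idx is what gets stored in the event slot. */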
1739        if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
1740                           vhost_avail_event(vq)))
1741                return -EFAULT;
1742        if (unlikely(vq->log_used)) {
1743                void __user *used;
1744                /* Make sure the event is seen before log. */
1745                smp_wmb();
1746                /* Log avail event write */
1747                used = vhost_avail_event(vq);
1748                log_write(vq->log_base, vq->log_addr +
1749                          (used - (void __user *)vq->used),
1750                          sizeof *vhost_avail_event(vq));
1751                if (vq->log_ctx)
1752                        eventfd_signal(vq->log_ctx, 1);
1753        }
1754        return 0;
1755}
1756
1757int vhost_vq_init_access(struct vhost_virtqueue *vq)
1758{
1759        __virtio16 last_used_idx;
1760        int r;
1761        bool is_le = vq->is_le;
1762
1763        if (!vq->private_data)
1764                return 0;
1765
1766        vhost_init_is_le(vq);
1767
1768        r = vhost_update_used_flags(vq);
1769        if (r)
1770                goto err;
1771        vq->signalled_used_valid = false;
1772        if (!vq->iotlb &&
1773            !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
1774                r = -EFAULT;
1775                goto err;
1776        }
1777        r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
1778        if (r) {
1779                vq_err(vq, "Can't access used idx at %p\n",
1780                       &vq->used->idx);
1781                goto err;
1782        }
1783        vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
1784        return 0;
1785
1786err:
1787        vq->is_le = is_le;
1788        return r;
1789}
1790EXPORT_SYMBOL_GPL(vhost_vq_init_access);
1791
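    /* Translate a guest address range into userspace iovecs, using the
     * device IOTLB when one is set up and the static memory table otherwise.
     * Returns the number of iovecs filled in, -ENOBUFS if iov[] is too
     * small, -EPERM on an access type mismatch, or -EAGAIN after queueing
     * an IOTLB miss request for userspace to service. */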
1792static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
1793                          struct iovec iov[], int iov_size, int access)
1794{
1795        const struct vhost_umem_node *node;
1796        struct vhost_dev *dev = vq->dev;
1797        struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem;
1798        struct iovec *_iov;
1799        u64 s = 0;
1800        int ret = 0;
1801
1802        while ((u64)len > s) {
1803                u64 size;
1804                if (unlikely(ret >= iov_size)) {
1805                        ret = -ENOBUFS;
1806                        break;
1807                }
1808
1809                node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
1810                                                        addr, addr + len - 1);
1811                if (node == NULL || node->start > addr) {
1812                        if (umem != dev->iotlb) {
1813                                ret = -EFAULT;
1814                                break;
1815                        }
1816                        ret = -EAGAIN;
1817                        break;
1818                } else if (!(node->perm & access)) {
1819                        ret = -EPERM;
1820                        break;
1821                }
1822
1823                _iov = iov + ret;
1824                size = node->size - addr + node->start;
1825                _iov->iov_len = min((u64)len - s, size);
1826                _iov->iov_base = (void __user *)(unsigned long)
1827                        (node->userspace_addr + addr - node->start);
1828                s += size;
1829                addr += size;
1830                ++ret;
1831        }
1832
1833        if (ret == -EAGAIN)
1834                vhost_iotlb_miss(vq, addr, access);
1835        return ret;
1836}
1837
1838/* Each buffer in the virtqueues is actually a chain of descriptors.  This
1839 * function returns the next descriptor in the chain,
1840 * or -1U if we're at the end. */
1841static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
1842{
1843        unsigned int next;
1844
1845        /* If this descriptor says it doesn't chain, we're done. */
1846        if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
1847                return -1U;
1848
1849        /* Read the next descriptor index; range checking is left to the caller. */
1850        next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
1851        return next;
1852}
1853
1854static int get_indirect(struct vhost_virtqueue *vq,
1855                        struct iovec iov[], unsigned int iov_size,
1856                        unsigned int *out_num, unsigned int *in_num,
1857                        struct vhost_log *log, unsigned int *log_num,
1858                        struct vring_desc *indirect)
1859{
1860        struct vring_desc desc;
1861        unsigned int i = 0, count, found = 0;
1862        u32 len = vhost32_to_cpu(vq, indirect->len);
1863        struct iov_iter from;
1864        int ret, access;
1865
1866        /* Sanity check */
1867        if (unlikely(len % sizeof desc)) {
1868                vq_err(vq, "Invalid length in indirect descriptor: "
1869                       "len 0x%llx not multiple of 0x%zx\n",
1870                       (unsigned long long)len,
1871                       sizeof desc);
1872                return -EINVAL;
1873        }
1874
1875        ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
1876                             UIO_MAXIOV, VHOST_ACCESS_RO);
1877        if (unlikely(ret < 0)) {
1878                if (ret != -EAGAIN)
1879                        vq_err(vq, "Translation failure %d in indirect.\n", ret);
1880                return ret;
1881        }
1882        iov_iter_init(&from, READ, vq->indirect, ret, len);
1883
1884        /* We will use the result as an address to read from, so most
1885         * architectures only need a compiler barrier here. */
1886        read_barrier_depends();
1887
1888        count = len / sizeof desc;
1889        /* Buffers are chained via a 16 bit next field, so
1890         * we can have at most 2^16 of these. */
1891        if (unlikely(count > USHRT_MAX + 1)) {
1892                vq_err(vq, "Indirect buffer length too big: %d\n",
1893                       indirect->len);
1894                return -E2BIG;
1895        }
1896
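            /* Indirect descriptors are read sequentially from the table via
             * the iov_iter; desc.next from next_desc() only terminates the
             * loop and is not used to index the table. */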
1897        do {
1898                unsigned iov_count = *in_num + *out_num;
1899                if (unlikely(++found > count)) {
1900                        vq_err(vq, "Loop detected: last one at %u "
1901                               "indirect size %u\n",
1902                               i, count);
1903                        return -EINVAL;
1904                }
1905                if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
1906                        vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
1907                               i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
1908                        return -EINVAL;
1909                }
1910                if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
1911                        vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
1912                               i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
1913                        return -EINVAL;
1914                }
1915
1916                if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
1917                        access = VHOST_ACCESS_WO;
1918                else
1919                        access = VHOST_ACCESS_RO;
1920
1921                ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
1922                                     vhost32_to_cpu(vq, desc.len), iov + iov_count,
1923                                     iov_size - iov_count, access);
1924                if (unlikely(ret < 0)) {
1925                        if (ret != -EAGAIN)
1926                                vq_err(vq, "Translation failure %d indirect idx %d\n",
1927                                        ret, i);
1928                        return ret;
1929                }
1930                /* If this is an input descriptor, increment that count. */
1931                if (access == VHOST_ACCESS_WO) {
1932                        *in_num += ret;
1933                        if (unlikely(log)) {
1934                                log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
1935                                log[*log_num].len = vhost32_to_cpu(vq, desc.len);
1936                                ++*log_num;
1937                        }
1938                } else {
1939                        /* If it's an output descriptor, they're all supposed
1940                         * to come before any input descriptors. */
1941                        if (unlikely(*in_num)) {
1942                                vq_err(vq, "Indirect descriptor "
1943                                       "has out after in: idx %d\n", i);
1944                                return -EINVAL;
1945                        }
1946                        *out_num += ret;
1947                }
1948        } while ((i = next_desc(vq, &desc)) != -1);
1949        return 0;
1950}
1951
1952/* This looks in the virtqueue for the first available buffer and converts
1953 * it to an iovec for convenient access.  Since descriptors consist of some
1954 * number of output then some number of input descriptors, it's actually two
1955 * iovecs, but we pack them into one and note how many of each there were.
1956 *
1957 * This function returns the descriptor number found, or vq->num (which is
1958 * never a valid descriptor number) if none was found.  A negative code is
1959 * returned on error. */
1960int vhost_get_vq_desc(struct vhost_virtqueue *vq,
1961                      struct iovec iov[], unsigned int iov_size,
1962                      unsigned int *out_num, unsigned int *in_num,
1963                      struct vhost_log *log, unsigned int *log_num)
1964{
1965        struct vring_desc desc;
1966        unsigned int i, head, found = 0;
1967        u16 last_avail_idx;
1968        __virtio16 avail_idx;
1969        __virtio16 ring_head;
1970        int ret, access;
1971
1972        /* Check it isn't doing very strange things with descriptor numbers. */
1973        last_avail_idx = vq->last_avail_idx;
1974
1975        if (vq->avail_idx == vq->last_avail_idx) {
1976                if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
1977                        vq_err(vq, "Failed to access avail idx at %p\n",
1978                                &vq->avail->idx);
1979                        return -EFAULT;
1980                }
1981                vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
1982
1983                if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
1984                        vq_err(vq, "Guest moved avail index from %u to %u",
1985                                last_avail_idx, vq->avail_idx);
1986                        return -EFAULT;
1987                }
1988
1989                /* If there's nothing new since last we looked, return
1990                 * invalid.
1991                 */
1992                if (vq->avail_idx == last_avail_idx)
1993                        return vq->num;
1994
1995                /* Only get avail ring entries after they have been
1996                 * exposed by guest.
1997                 */
1998                smp_rmb();
1999        }
2000
2001        /* Grab the next descriptor number they're advertising, and increment
2002         * the index we've seen. */
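            /* The ring size is a power of two (enforced when the ring size
             * is configured), so masking with (num - 1) wraps the
             * free-running avail index. */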
2003        if (unlikely(vhost_get_avail(vq, ring_head,
2004                     &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
2005                vq_err(vq, "Failed to read head: idx %d address %p\n",
2006                       last_avail_idx,
2007                       &vq->avail->ring[last_avail_idx % vq->num]);
2008                return -EFAULT;
2009        }
2010
2011        head = vhost16_to_cpu(vq, ring_head);
2012
2013        /* If their number is silly, that's an error. */
2014        if (unlikely(head >= vq->num)) {
2015                vq_err(vq, "Guest says index %u >= %u is available",
2016                       head, vq->num);
2017                return -EINVAL;
2018        }
2019
2020        /* When we start there are none of either input nor output. */
2021        *out_num = *in_num = 0;
2022        if (unlikely(log))
2023                *log_num = 0;
2024
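            /* Walk the descriptor chain, translating each descriptor into
             * iovecs.  All output (read-only) descriptors must precede any
             * input (write) descriptors, and 'found' bounds the walk so a
             * malicious ring cannot loop us forever. */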
2025        i = head;
2026        do {
2027                unsigned iov_count = *in_num + *out_num;
2028                if (unlikely(i >= vq->num)) {
2029                        vq_err(vq, "Desc index is %u >= %u, head = %u",
2030                               i, vq->num, head);
2031                        return -EINVAL;
2032                }
2033                if (unlikely(++found > vq->num)) {
2034                        vq_err(vq, "Loop detected: last one at %u "
2035                               "vq size %u head %u\n",
2036                               i, vq->num, head);
2037                        return -EINVAL;
2038                }
2039                ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
2040                                           sizeof desc);
2041                if (unlikely(ret)) {
2042                        vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2043                               i, vq->desc + i);
2044                        return -EFAULT;
2045                }
2046                if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2047                        ret = get_indirect(vq, iov, iov_size,
2048                                           out_num, in_num,
2049                                           log, log_num, &desc);
2050                        if (unlikely(ret < 0)) {
2051                                if (ret != -EAGAIN)
2052                                        vq_err(vq, "Failure detected "
2053                                                "in indirect descriptor at idx %d\n", i);
2054                                return ret;
2055                        }
2056                        continue;
2057                }
2058
2059                if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2060                        access = VHOST_ACCESS_WO;
2061                else
2062                        access = VHOST_ACCESS_RO;
2063                ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2064                                     vhost32_to_cpu(vq, desc.len), iov + iov_count,
2065                                     iov_size - iov_count, access);
2066                if (unlikely(ret < 0)) {
2067                        if (ret != -EAGAIN)
2068                                vq_err(vq, "Translation failure %d descriptor idx %d\n",
2069                                        ret, i);
2070                        return ret;
2071                }
2072                if (access == VHOST_ACCESS_WO) {
2073                        /* If this is an input descriptor,
2074                         * increment that count. */
2075                        *in_num += ret;
2076                        if (unlikely(log)) {
2077                                log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2078                                log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2079                                ++*log_num;
2080                        }
2081                } else {
2082                        /* If it's an output descriptor, they're all supposed
2083                         * to come before any input descriptors. */
2084                        if (unlikely(*in_num)) {
2085                                vq_err(vq, "Descriptor has out after in: "
2086                                       "idx %d\n", i);
2087                                return -EINVAL;
2088                        }
2089                        *out_num += ret;
2090                }
2091        } while ((i = next_desc(vq, &desc)) != -1);
2092
2093        /* On success, increment avail index. */
2094        vq->last_avail_idx++;
2095
2096        /* Assume notifications from guest are disabled at this point,
2097         * if they aren't we would need to update avail_event index. */
2098        BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2099        return head;
2100}
2101EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
2102
2103/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2104void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2105{
2106        vq->last_avail_idx -= n;
2107}
2108EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2109
2110/* After we've used one of their buffers, we tell them about it.  We'll then
2111 * want to notify the guest, using eventfd. */
2112int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2113{
2114        struct vring_used_elem heads = {
2115                cpu_to_vhost32(vq, head),
2116                cpu_to_vhost32(vq, len)
2117        };
2118
2119        return vhost_add_used_n(vq, &heads, 1);
2120}
2121EXPORT_SYMBOL_GPL(vhost_add_used);
2122
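    /* Copy 'count' used elements into the used ring at last_used_idx and
     * advance our private copy of the index; the guest-visible used->idx is
     * only updated by the caller, after a write barrier. */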
2123static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2124                            struct vring_used_elem *heads,
2125                            unsigned count)
2126{
2127        struct vring_used_elem __user *used;
2128        u16 old, new;
2129        int start;
2130
2131        start = vq->last_used_idx & (vq->num - 1);
2132        used = vq->used->ring + start;
2133        if (count == 1) {
2134                if (vhost_put_user(vq, heads[0].id, &used->id)) {
2135                        vq_err(vq, "Failed to write used id");
2136                        return -EFAULT;
2137                }
2138                if (vhost_put_user(vq, heads[0].len, &used->len)) {
2139                        vq_err(vq, "Failed to write used len");
2140                        return -EFAULT;
2141                }
2142        } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
2143                vq_err(vq, "Failed to write used");
2144                return -EFAULT;
2145        }
2146        if (unlikely(vq->log_used)) {
2147                /* Make sure data is seen before log. */
2148                smp_wmb();
2149                /* Log used ring entry write. */
2150                log_write(vq->log_base,
2151                          vq->log_addr +
2152                           ((void __user *)used - (void __user *)vq->used),
2153                          count * sizeof *used);
2154        }
2155        old = vq->last_used_idx;
2156        new = (vq->last_used_idx += count);
2157        /* If the driver never bothers to signal in a very long while,
2158         * used index might wrap around. If that happens, invalidate
2159         * signalled_used index we stored. TODO: make sure driver
2160         * signals at least once in 2^16 and remove this. */
2161        if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2162                vq->signalled_used_valid = false;
2163        return 0;
2164}
2165
2166/* After we've used one of their buffers, we tell them about it.  We'll then
2167 * want to notify the guest, using eventfd. */
2168int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2169                     unsigned count)
2170{
2171        int start, n, r;
2172
2173        start = vq->last_used_idx & (vq->num - 1);
2174        n = vq->num - start;
2175        if (n < count) {
2176                r = __vhost_add_used_n(vq, heads, n);
2177                if (r < 0)
2178                        return r;
2179                heads += n;
2180                count -= n;
2181        }
2182        r = __vhost_add_used_n(vq, heads, count);
2183
2184        /* Make sure buffer is written before we update index. */
2185        smp_wmb();
2186        if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
2187                           &vq->used->idx)) {
2188                vq_err(vq, "Failed to increment used idx");
2189                return -EFAULT;
2190        }
2191        if (unlikely(vq->log_used)) {
2192                /* Log used index update. */
2193                log_write(vq->log_base,
2194                          vq->log_addr + offsetof(struct vring_used, idx),
2195                          sizeof vq->used->idx);
2196                if (vq->log_ctx)
2197                        eventfd_signal(vq->log_ctx, 1);
2198        }
2199        return r;
2200}
2201EXPORT_SYMBOL_GPL(vhost_add_used_n);
2202
2203static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2204{
2205        __u16 old, new;
2206        __virtio16 event;
2207        bool v;
2208        /* Flush out used index updates. This is paired
2209         * with the barrier that the Guest executes when enabling
2210         * interrupts. */
2211        smp_mb();
2212
2213        if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2214            unlikely(vq->avail_idx == vq->last_avail_idx))
2215                return true;
2216
2217        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2218                __virtio16 flags;
2219                if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
2220                        vq_err(vq, "Failed to get flags");
2221                        return true;
2222                }
2223                return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2224        }
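            /* With EVENT_IDX the guest publishes a used_event index; signal
             * only if last_used_idx has moved past it since we last
             * signalled, which is what vring_need_event() below checks
             * (roughly (u16)(new - event - 1) < (u16)(new - old)). */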
2225        old = vq->signalled_used;
2226        v = vq->signalled_used_valid;
2227        new = vq->signalled_used = vq->last_used_idx;
2228        vq->signalled_used_valid = true;
2229
2230        if (unlikely(!v))
2231                return true;
2232
2233        if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
2234                vq_err(vq, "Failed to get used event idx");
2235                return true;
2236        }
2237        return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2238}
2239
2240/* This actually signals the guest, using eventfd. */
2241void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2242{
2243        /* Signal the Guest to tell them we used something up. */
2244        if (vq->call_ctx && vhost_notify(dev, vq))
2245                eventfd_signal(vq->call_ctx, 1);
2246}
2247EXPORT_SYMBOL_GPL(vhost_signal);
2248
2249/* And here's the combo meal deal.  Supersize me! */
2250void vhost_add_used_and_signal(struct vhost_dev *dev,
2251                               struct vhost_virtqueue *vq,
2252                               unsigned int head, int len)
2253{
2254        vhost_add_used(vq, head, len);
2255        vhost_signal(dev, vq);
2256}
2257EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2258
2259/* multi-buffer version of vhost_add_used_and_signal */
2260void vhost_add_used_and_signal_n(struct vhost_dev *dev,
2261                                 struct vhost_virtqueue *vq,
2262                                 struct vring_used_elem *heads, unsigned count)
2263{
2264        vhost_add_used_n(vq, heads, count);
2265        vhost_signal(dev, vq);
2266}
2267EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2268
2269/* return true if we're sure that the available ring is empty */
2270bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2271{
2272        __virtio16 avail_idx;
2273        int r;
2274
2275        if (vq->avail_idx != vq->last_avail_idx)
2276                return false;
2277
2278        r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
2279        if (unlikely(r))
2280                return false;
2281        vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2282
2283        return vq->avail_idx == vq->last_avail_idx;
2284}
2285EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2286
2287/* OK, now we need to know about added descriptors. */
2288bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2289{
2290        __virtio16 avail_idx;
2291        int r;
2292
2293        if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2294                return false;
2295        vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2296        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2297                r = vhost_update_used_flags(vq);
2298                if (r) {
2299                        vq_err(vq, "Failed to enable notification at %p: %d\n",
2300                               &vq->used->flags, r);
2301                        return false;
2302                }
2303        } else {
2304                r = vhost_update_avail_event(vq, vq->avail_idx);
2305                if (r) {
2306                        vq_err(vq, "Failed to update avail event index at %p: %d\n",
2307                               vhost_avail_event(vq), r);
2308                        return false;
2309                }
2310        }
2311        /* They could have slipped one in as we were doing that: make
2312         * sure it's written, then check again. */
2313        smp_mb();
2314        r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
2315        if (r) {
2316                vq_err(vq, "Failed to check avail idx at %p: %d\n",
2317                       &vq->avail->idx, r);
2318                return false;
2319        }
2320
2321        return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
2322}
2323EXPORT_SYMBOL_GPL(vhost_enable_notify);
2324
2325/* We don't need to be notified again. */
2326void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2327{
2328        int r;
2329
2330        if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2331                return;
2332        vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2333        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2334                r = vhost_update_used_flags(vq);
2335                if (r)
2336                        vq_err(vq, "Failed to disable notification at %p: %d\n",
2337                               &vq->used->flags, r);
2338        }
2339}
2340EXPORT_SYMBOL_GPL(vhost_disable_notify);
2341
2342/* Create a new message. */
2343struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2344{
2345        struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
2346        if (!node)
2347                return NULL;
2348        node->vq = vq;
2349        node->msg.type = type;
2350        return node;
2351}
2352EXPORT_SYMBOL_GPL(vhost_new_msg);
2353
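    /* Queue a message (e.g. an IOTLB miss) for userspace and wake anyone
     * blocked in poll()/read() on the vhost file. */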
2354void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2355                       struct vhost_msg_node *node)
2356{
2357        spin_lock(&dev->iotlb_lock);
2358        list_add_tail(&node->node, head);
2359        spin_unlock(&dev->iotlb_lock);
2360
2361        wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2362}
2363EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2364
2365struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2366                                         struct list_head *head)
2367{
2368        struct vhost_msg_node *node = NULL;
2369
2370        spin_lock(&dev->iotlb_lock);
2371        if (!list_empty(head)) {
2372                node = list_first_entry(head, struct vhost_msg_node,
2373                                        node);
2374                list_del(&node->node);
2375        }
2376        spin_unlock(&dev->iotlb_lock);
2377
2378        return node;
2379}
2380EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2381
2382
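    /* vhost.c is a library module with no device node or global state of
     * its own; users such as vhost_net register their own misc devices and
     * drive this code, so module init/exit have nothing to do. */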
2383static int __init vhost_init(void)
2384{
2385        return 0;
2386}
2387
2388static void __exit vhost_exit(void)
2389{
2390}
2391
2392module_init(vhost_init);
2393module_exit(vhost_exit);
2394
2395MODULE_VERSION("0.0.1");
2396MODULE_LICENSE("GPL v2");
2397MODULE_AUTHOR("Michael S. Tsirkin");
2398MODULE_DESCRIPTION("Host kernel accelerator for virtio");
2399