   1/*
   2 * Vhost User library
   3 *
   4 * Copyright IBM, Corp. 2007
   5 * Copyright (c) 2016 Red Hat, Inc.
   6 *
   7 * Authors:
   8 *  Anthony Liguori <aliguori@us.ibm.com>
   9 *  Marc-André Lureau <mlureau@redhat.com>
  10 *  Victor Kaplansky <victork@redhat.com>
  11 *
  12 * This work is licensed under the terms of the GNU GPL, version 2 or
  13 * later.  See the COPYING file in the top-level directory.
  14 */
  15
  16/* this code avoids GLib dependency */
  17#include <stdlib.h>
  18#include <stdio.h>
  19#include <unistd.h>
  20#include <stdarg.h>
  21#include <errno.h>
  22#include <string.h>
  23#include <assert.h>
  24#include <inttypes.h>
  25#include <sys/types.h>
  26#include <sys/socket.h>
  27#include <sys/eventfd.h>
  28#include <sys/mman.h>
  29#include "qemu/compiler.h"
  30
  31#if defined(__linux__)
  32#include <sys/syscall.h>
  33#include <fcntl.h>
  34#include <sys/ioctl.h>
  35#include <linux/vhost.h>
  36
  37#ifdef __NR_userfaultfd
  38#include <linux/userfaultfd.h>
  39#endif
  40
  41#endif
  42
  43#include "qemu/atomic.h"
  44
  45#include "libvhost-user.h"
  46
  47/* usually provided by GLib */
  48#ifndef MIN
  49#define MIN(x, y) ({                            \
  50            typeof(x) _min1 = (x);              \
  51            typeof(y) _min2 = (y);              \
  52            (void) (&_min1 == &_min2);          \
  53            _min1 < _min2 ? _min1 : _min2; })
  54#endif
  55
  56#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
  57
  58/* The version of the protocol we support */
  59#define VHOST_USER_VERSION 1
  60#define LIBVHOST_USER_DEBUG 0
  61
  62#define DPRINT(...)                             \
  63    do {                                        \
  64        if (LIBVHOST_USER_DEBUG) {              \
  65            fprintf(stderr, __VA_ARGS__);        \
  66        }                                       \
  67    } while (0)
  68
  69static const char *
  70vu_request_to_string(unsigned int req)
  71{
  72#define REQ(req) [req] = #req
  73    static const char *vu_request_str[] = {
  74        REQ(VHOST_USER_NONE),
  75        REQ(VHOST_USER_GET_FEATURES),
  76        REQ(VHOST_USER_SET_FEATURES),
  77        REQ(VHOST_USER_SET_OWNER),
  78        REQ(VHOST_USER_RESET_OWNER),
  79        REQ(VHOST_USER_SET_MEM_TABLE),
  80        REQ(VHOST_USER_SET_LOG_BASE),
  81        REQ(VHOST_USER_SET_LOG_FD),
  82        REQ(VHOST_USER_SET_VRING_NUM),
  83        REQ(VHOST_USER_SET_VRING_ADDR),
  84        REQ(VHOST_USER_SET_VRING_BASE),
  85        REQ(VHOST_USER_GET_VRING_BASE),
  86        REQ(VHOST_USER_SET_VRING_KICK),
  87        REQ(VHOST_USER_SET_VRING_CALL),
  88        REQ(VHOST_USER_SET_VRING_ERR),
  89        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
  90        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
  91        REQ(VHOST_USER_GET_QUEUE_NUM),
  92        REQ(VHOST_USER_SET_VRING_ENABLE),
  93        REQ(VHOST_USER_SEND_RARP),
  94        REQ(VHOST_USER_NET_SET_MTU),
  95        REQ(VHOST_USER_SET_SLAVE_REQ_FD),
  96        REQ(VHOST_USER_IOTLB_MSG),
  97        REQ(VHOST_USER_SET_VRING_ENDIAN),
  98        REQ(VHOST_USER_GET_CONFIG),
  99        REQ(VHOST_USER_SET_CONFIG),
 100        REQ(VHOST_USER_POSTCOPY_ADVISE),
 101        REQ(VHOST_USER_POSTCOPY_LISTEN),
 102        REQ(VHOST_USER_POSTCOPY_END),
 103        REQ(VHOST_USER_MAX),
 104    };
 105#undef REQ
 106
 107    if (req < VHOST_USER_MAX) {
 108        return vu_request_str[req];
 109    } else {
 110        return "unknown";
 111    }
 112}
 113
 114static void
 115vu_panic(VuDev *dev, const char *msg, ...)
 116{
 117    char *buf = NULL;
 118    va_list ap;
 119
 120    va_start(ap, msg);
 121    if (vasprintf(&buf, msg, ap) < 0) {
 122        buf = NULL;
 123    }
 124    va_end(ap);
 125
 126    dev->broken = true;
 127    dev->panic(dev, buf);
 128    free(buf);
 129
 130    /* FIXME: find a way to call virtio_error? */
 131}
 132
 133/* Translate guest physical address to our virtual address.  */
 134void *
 135vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
 136{
 137    int i;
 138
 139    if (*plen == 0) {
 140        return NULL;
 141    }
 142
 143    /* Find matching memory region.  */
 144    for (i = 0; i < dev->nregions; i++) {
 145        VuDevRegion *r = &dev->regions[i];
 146
 147        if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
 148            if ((guest_addr + *plen) > (r->gpa + r->size)) {
 149                *plen = r->gpa + r->size - guest_addr;
 150            }
 151            return (void *)(uintptr_t)
 152                guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
 153        }
 154    }
 155
 156    return NULL;
 157}
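/*
 * Usage sketch (illustrative, not part of the library): vu_gpa_to_va() may
 * shrink *plen when the requested range crosses the end of a memory region,
 * so a caller that needs a whole guest range has to iterate.  The helper
 * below is hypothetical and only demonstrates the calling convention.
 */
#if 0
static void
example_copy_from_guest(VuDev *dev, void *dst, uint64_t guest_addr, size_t size)
{
    while (size) {
        uint64_t len = size;
        void *src = vu_gpa_to_va(dev, &len, guest_addr);

        if (!src) {
            vu_panic(dev, "example: bad guest address 0x%"PRIx64, guest_addr);
            return;
        }
        memcpy(dst, src, len);      /* len may be smaller than requested */
        dst = (char *)dst + len;
        guest_addr += len;
        size -= len;
    }
}
#endif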
 158
 159/* Translate qemu virtual address to our virtual address.  */
 160static void *
 161qva_to_va(VuDev *dev, uint64_t qemu_addr)
 162{
 163    int i;
 164
 165    /* Find matching memory region.  */
 166    for (i = 0; i < dev->nregions; i++) {
 167        VuDevRegion *r = &dev->regions[i];
 168
 169        if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
 170            return (void *)(uintptr_t)
 171                qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
 172        }
 173    }
 174
 175    return NULL;
 176}
 177
 178static void
 179vmsg_close_fds(VhostUserMsg *vmsg)
 180{
 181    int i;
 182
 183    for (i = 0; i < vmsg->fd_num; i++) {
 184        close(vmsg->fds[i]);
 185    }
 186}
 187
 188/* A test to see if we have userfault available */
 189static bool
 190have_userfault(void)
 191{
 192#if defined(__linux__) && defined(__NR_userfaultfd) &&\
 193        defined(UFFD_FEATURE_MISSING_SHMEM) &&\
 194        defined(UFFD_FEATURE_MISSING_HUGETLBFS)
 195    /* Now test the kernel we're running on really has the features */
 196    int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
 197    struct uffdio_api api_struct;
 198    if (ufd < 0) {
 199        return false;
 200    }
 201
 202    api_struct.api = UFFD_API;
 203    api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
 204                          UFFD_FEATURE_MISSING_HUGETLBFS;
 205    if (ioctl(ufd, UFFDIO_API, &api_struct)) {
 206        close(ufd);
 207        return false;
 208    }
 209    close(ufd);
 210    return true;
 211
 212#else
 213    return false;
 214#endif
 215}
 216
 217static bool
 218vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 219{
 220    char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
 221    struct iovec iov = {
 222        .iov_base = (char *)vmsg,
 223        .iov_len = VHOST_USER_HDR_SIZE,
 224    };
 225    struct msghdr msg = {
 226        .msg_iov = &iov,
 227        .msg_iovlen = 1,
 228        .msg_control = control,
 229        .msg_controllen = sizeof(control),
 230    };
 231    size_t fd_size;
 232    struct cmsghdr *cmsg;
 233    int rc;
 234
 235    do {
 236        rc = recvmsg(conn_fd, &msg, 0);
 237    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 238
 239    if (rc < 0) {
 240        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
 241        return false;
 242    }
 243
 244    vmsg->fd_num = 0;
 245    for (cmsg = CMSG_FIRSTHDR(&msg);
 246         cmsg != NULL;
 247         cmsg = CMSG_NXTHDR(&msg, cmsg))
 248    {
 249        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
 250            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
 251            vmsg->fd_num = fd_size / sizeof(int);
 252            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
 253            break;
 254        }
 255    }
 256
 257    if (vmsg->size > sizeof(vmsg->payload)) {
  258        vu_panic(dev,
  259                 "Error: oversized message: request %d, vmsg->size: %u, "
  260                 "while sizeof(vmsg->payload) = %zu\n",
  261                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
 262        goto fail;
 263    }
 264
 265    if (vmsg->size) {
 266        do {
 267            rc = read(conn_fd, &vmsg->payload, vmsg->size);
 268        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 269
 270        if (rc <= 0) {
 271            vu_panic(dev, "Error while reading: %s", strerror(errno));
 272            goto fail;
 273        }
 274
 275        assert(rc == vmsg->size);
 276    }
 277
 278    return true;
 279
 280fail:
 281    vmsg_close_fds(vmsg);
 282
 283    return false;
 284}
 285
 286static bool
 287vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 288{
 289    int rc;
 290    uint8_t *p = (uint8_t *)vmsg;
 291    char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
 292    struct iovec iov = {
 293        .iov_base = (char *)vmsg,
 294        .iov_len = VHOST_USER_HDR_SIZE,
 295    };
 296    struct msghdr msg = {
 297        .msg_iov = &iov,
 298        .msg_iovlen = 1,
 299        .msg_control = control,
 300    };
 301    struct cmsghdr *cmsg;
 302
 303    memset(control, 0, sizeof(control));
 304    assert(vmsg->fd_num <= VHOST_MEMORY_MAX_NREGIONS);
 305    if (vmsg->fd_num > 0) {
 306        size_t fdsize = vmsg->fd_num * sizeof(int);
 307        msg.msg_controllen = CMSG_SPACE(fdsize);
 308        cmsg = CMSG_FIRSTHDR(&msg);
 309        cmsg->cmsg_len = CMSG_LEN(fdsize);
 310        cmsg->cmsg_level = SOL_SOCKET;
 311        cmsg->cmsg_type = SCM_RIGHTS;
 312        memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
 313    } else {
 314        msg.msg_controllen = 0;
 315    }
 316
 317    /* Set the version in the flags when sending the reply */
 318    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
 319    vmsg->flags |= VHOST_USER_VERSION;
 320    vmsg->flags |= VHOST_USER_REPLY_MASK;
 321
 322    do {
 323        rc = sendmsg(conn_fd, &msg, 0);
 324    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 325
  326    if (vmsg->size) {
  327        do {
  328            rc = write(conn_fd,
  329                       vmsg->data ? vmsg->data : p + VHOST_USER_HDR_SIZE,
  330                       vmsg->size);
  331        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
  332    }
 333
 334    if (rc <= 0) {
 335        vu_panic(dev, "Error while writing: %s", strerror(errno));
 336        return false;
 337    }
 338
 339    return true;
 340}
 341
 342/* Kick the log_call_fd if required. */
 343static void
 344vu_log_kick(VuDev *dev)
 345{
 346    if (dev->log_call_fd != -1) {
 347        DPRINT("Kicking the QEMU's log...\n");
 348        if (eventfd_write(dev->log_call_fd, 1) < 0) {
 349            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
 350        }
 351    }
 352}
 353
 354static void
 355vu_log_page(uint8_t *log_table, uint64_t page)
 356{
 357    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
 358    atomic_or(&log_table[page / 8], 1 << (page % 8));
 359}
 360
 361static void
 362vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
 363{
 364    uint64_t page;
 365
 366    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
 367        !dev->log_table || !length) {
 368        return;
 369    }
 370
 371    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
 372
 373    page = address / VHOST_LOG_PAGE;
 374    while (page * VHOST_LOG_PAGE < address + length) {
 375        vu_log_page(dev->log_table, page);
  376        page += 1;
 377    }
 378
 379    vu_log_kick(dev);
 380}
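/*
 * Worked example for the dirty log above (assuming VHOST_LOG_PAGE is 0x1000):
 * a write of length 0x1000 at address 0x5400 spans pages 5 and 6, so
 * vu_log_page() sets bits 5 and 6 of log_table[0] (byte 0 becomes 0x60)
 * before the log eventfd is kicked.
 */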
 381
 382static void
 383vu_kick_cb(VuDev *dev, int condition, void *data)
 384{
 385    int index = (intptr_t)data;
 386    VuVirtq *vq = &dev->vq[index];
 387    int sock = vq->kick_fd;
 388    eventfd_t kick_data;
 389    ssize_t rc;
 390
 391    rc = eventfd_read(sock, &kick_data);
 392    if (rc == -1) {
 393        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
 394        dev->remove_watch(dev, dev->vq[index].kick_fd);
 395    } else {
 396        DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
 397               kick_data, vq->handler, index);
 398        if (vq->handler) {
 399            vq->handler(dev, index);
 400        }
 401    }
 402}
 403
 404static bool
 405vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 406{
 407    vmsg->payload.u64 =
 408        1ULL << VHOST_F_LOG_ALL |
 409        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
 410
 411    if (dev->iface->get_features) {
 412        vmsg->payload.u64 |= dev->iface->get_features(dev);
 413    }
 414
 415    vmsg->size = sizeof(vmsg->payload.u64);
 416    vmsg->fd_num = 0;
 417
 418    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 419
 420    return true;
 421}
 422
 423static void
 424vu_set_enable_all_rings(VuDev *dev, bool enabled)
 425{
 426    int i;
 427
 428    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
 429        dev->vq[i].enable = enabled;
 430    }
 431}
 432
 433static bool
 434vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 435{
 436    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 437
 438    dev->features = vmsg->payload.u64;
 439
  440    if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
 441        vu_set_enable_all_rings(dev, true);
 442    }
 443
 444    if (dev->iface->set_features) {
 445        dev->iface->set_features(dev, dev->features);
 446    }
 447
 448    return false;
 449}
 450
 451static bool
 452vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
 453{
 454    return false;
 455}
 456
 457static void
 458vu_close_log(VuDev *dev)
 459{
 460    if (dev->log_table) {
 461        if (munmap(dev->log_table, dev->log_size) != 0) {
 462            perror("close log munmap() error");
 463        }
 464
 465        dev->log_table = NULL;
 466    }
 467    if (dev->log_call_fd != -1) {
 468        close(dev->log_call_fd);
 469        dev->log_call_fd = -1;
 470    }
 471}
 472
 473static bool
 474vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
 475{
 476    vu_set_enable_all_rings(dev, false);
 477
 478    return false;
 479}
 480
 481static bool
 482vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
 483{
 484    int i;
 485    VhostUserMemory *memory = &vmsg->payload.memory;
 486    dev->nregions = memory->nregions;
 487
 488    DPRINT("Nregions: %d\n", memory->nregions);
 489    for (i = 0; i < dev->nregions; i++) {
 490        void *mmap_addr;
 491        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 492        VuDevRegion *dev_region = &dev->regions[i];
 493
 494        DPRINT("Region %d\n", i);
 495        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 496               msg_region->guest_phys_addr);
 497        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 498               msg_region->memory_size);
 499        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 500               msg_region->userspace_addr);
 501        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 502               msg_region->mmap_offset);
 503
 504        dev_region->gpa = msg_region->guest_phys_addr;
 505        dev_region->size = msg_region->memory_size;
 506        dev_region->qva = msg_region->userspace_addr;
 507        dev_region->mmap_offset = msg_region->mmap_offset;
 508
  509        /* We don't use the offset argument of mmap() since the
  510         * mapped address has to be page aligned, and we use huge
  511         * pages.
  512         * In postcopy mode we're using PROT_NONE here to catch anyone
  513         * accessing the region before we userfault it.
  514         */
 515        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 516                         PROT_NONE, MAP_SHARED,
 517                         vmsg->fds[i], 0);
 518
 519        if (mmap_addr == MAP_FAILED) {
 520            vu_panic(dev, "region mmap error: %s", strerror(errno));
 521        } else {
 522            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 523            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 524                   dev_region->mmap_addr);
 525        }
 526
 527        /* Return the address to QEMU so that it can translate the ufd
 528         * fault addresses back.
 529         */
 530        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
 531                                                 dev_region->mmap_offset);
 532        close(vmsg->fds[i]);
 533    }
 534
 535    /* Send the message back to qemu with the addresses filled in */
 536    vmsg->fd_num = 0;
 537    if (!vu_message_write(dev, dev->sock, vmsg)) {
 538        vu_panic(dev, "failed to respond to set-mem-table for postcopy");
 539        return false;
 540    }
 541
 542    /* Wait for QEMU to confirm that it's registered the handler for the
 543     * faults.
 544     */
 545    if (!vu_message_read(dev, dev->sock, vmsg) ||
 546        vmsg->size != sizeof(vmsg->payload.u64) ||
 547        vmsg->payload.u64 != 0) {
 548        vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
 549        return false;
 550    }
 551
 552    /* OK, now we can go and register the memory and generate faults */
 553    for (i = 0; i < dev->nregions; i++) {
 554        VuDevRegion *dev_region = &dev->regions[i];
 555        int ret;
 556#ifdef UFFDIO_REGISTER
  557        /* We should already have an open ufd. Mark each memory
  558         * range as userfault.
  559         * Discard any mapping we have here; we can't use MADV_REMOVE
  560         * or fallocate to make the hole, since we don't want to lose
  561         * data that has already arrived in the shared process.
  562         * TODO: How to handle hugepages?
  563         */
 564        ret = madvise((void *)dev_region->mmap_addr,
 565                      dev_region->size + dev_region->mmap_offset,
 566                      MADV_DONTNEED);
 567        if (ret) {
 568            fprintf(stderr,
 569                    "%s: Failed to madvise(DONTNEED) region %d: %s\n",
 570                    __func__, i, strerror(errno));
 571        }
  572        /* Turn off transparent hugepages so that we don't lose wakeups
  573         * in neighbouring pages.
  574         * TODO: Turn this back on later.
  575         */
 576        ret = madvise((void *)dev_region->mmap_addr,
 577                      dev_region->size + dev_region->mmap_offset,
 578                      MADV_NOHUGEPAGE);
 579        if (ret) {
 580            /* Note: This can happen legally on kernels that are configured
 581             * without madvise'able hugepages
 582             */
 583            fprintf(stderr,
 584                    "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
 585                    __func__, i, strerror(errno));
 586        }
 587        struct uffdio_register reg_struct;
 588        reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
 589        reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
 590        reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
 591
 592        if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
  593            vu_panic(dev, "%s: Failed to userfault region %d "
  594                          "@%p + size:%zx offset: %zx: (ufd=%d)%s\n",
  595                     __func__, i,
  596                     (void *)(uintptr_t)dev_region->mmap_addr,
  597                     (size_t)dev_region->size, (size_t)dev_region->mmap_offset,
  598                     dev->postcopy_ufd, strerror(errno));
 599            return false;
 600        }
 601        if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
 602            vu_panic(dev, "%s Region (%d) doesn't support COPY",
 603                     __func__, i);
 604            return false;
 605        }
 606        DPRINT("%s: region %d: Registered userfault for %llx + %llx\n",
 607                __func__, i, reg_struct.range.start, reg_struct.range.len);
 608        /* Now it's registered we can let the client at it */
 609        if (mprotect((void *)dev_region->mmap_addr,
 610                     dev_region->size + dev_region->mmap_offset,
 611                     PROT_READ | PROT_WRITE)) {
 612            vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
 613                     i, strerror(errno));
 614            return false;
 615        }
 616        /* TODO: Stash 'zero' support flags somewhere */
 617#endif
 618    }
 619
 620    return false;
 621}
 622
 623static bool
 624vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
 625{
 626    int i;
 627    VhostUserMemory *memory = &vmsg->payload.memory;
 628
 629    for (i = 0; i < dev->nregions; i++) {
 630        VuDevRegion *r = &dev->regions[i];
 631        void *m = (void *) (uintptr_t) r->mmap_addr;
 632
 633        if (m) {
 634            munmap(m, r->size + r->mmap_offset);
 635        }
 636    }
 637    dev->nregions = memory->nregions;
 638
 639    if (dev->postcopy_listening) {
 640        return vu_set_mem_table_exec_postcopy(dev, vmsg);
 641    }
 642
 643    DPRINT("Nregions: %d\n", memory->nregions);
 644    for (i = 0; i < dev->nregions; i++) {
 645        void *mmap_addr;
 646        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 647        VuDevRegion *dev_region = &dev->regions[i];
 648
 649        DPRINT("Region %d\n", i);
 650        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 651               msg_region->guest_phys_addr);
 652        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 653               msg_region->memory_size);
 654        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 655               msg_region->userspace_addr);
 656        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 657               msg_region->mmap_offset);
 658
 659        dev_region->gpa = msg_region->guest_phys_addr;
 660        dev_region->size = msg_region->memory_size;
 661        dev_region->qva = msg_region->userspace_addr;
 662        dev_region->mmap_offset = msg_region->mmap_offset;
 663
  664        /* We don't use the offset argument of mmap() since the
  665         * mapped address has to be page aligned, and we use huge
  666         * pages.  */
 667        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 668                         PROT_READ | PROT_WRITE, MAP_SHARED,
 669                         vmsg->fds[i], 0);
 670
 671        if (mmap_addr == MAP_FAILED) {
 672            vu_panic(dev, "region mmap error: %s", strerror(errno));
 673        } else {
 674            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 675            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 676                   dev_region->mmap_addr);
 677        }
 678
 679        close(vmsg->fds[i]);
 680    }
 681
 682    return false;
 683}
 684
 685static bool
 686vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 687{
 688    int fd;
 689    uint64_t log_mmap_size, log_mmap_offset;
 690    void *rc;
 691
 692    if (vmsg->fd_num != 1 ||
 693        vmsg->size != sizeof(vmsg->payload.log)) {
 694        vu_panic(dev, "Invalid log_base message");
 695        return true;
 696    }
 697
 698    fd = vmsg->fds[0];
 699    log_mmap_offset = vmsg->payload.log.mmap_offset;
 700    log_mmap_size = vmsg->payload.log.mmap_size;
 701    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
 702    DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
 703
 704    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
 705              log_mmap_offset);
 706    close(fd);
 707    if (rc == MAP_FAILED) {
 708        perror("log mmap error");
 709    }
 710
 711    if (dev->log_table) {
 712        munmap(dev->log_table, dev->log_size);
 713    }
 714    dev->log_table = rc;
 715    dev->log_size = log_mmap_size;
 716
 717    vmsg->size = sizeof(vmsg->payload.u64);
 718    vmsg->fd_num = 0;
 719
 720    return true;
 721}
 722
 723static bool
 724vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
 725{
 726    if (vmsg->fd_num != 1) {
 727        vu_panic(dev, "Invalid log_fd message");
 728        return false;
 729    }
 730
 731    if (dev->log_call_fd != -1) {
 732        close(dev->log_call_fd);
 733    }
 734    dev->log_call_fd = vmsg->fds[0];
 735    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
 736
 737    return false;
 738}
 739
 740static bool
 741vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
 742{
 743    unsigned int index = vmsg->payload.state.index;
 744    unsigned int num = vmsg->payload.state.num;
 745
 746    DPRINT("State.index: %d\n", index);
 747    DPRINT("State.num:   %d\n", num);
 748    dev->vq[index].vring.num = num;
 749
 750    return false;
 751}
 752
 753static bool
 754vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
 755{
 756    struct vhost_vring_addr *vra = &vmsg->payload.addr;
 757    unsigned int index = vra->index;
 758    VuVirtq *vq = &dev->vq[index];
 759
 760    DPRINT("vhost_vring_addr:\n");
 761    DPRINT("    index:  %d\n", vra->index);
 762    DPRINT("    flags:  %d\n", vra->flags);
 763    DPRINT("    desc_user_addr:   0x%016llx\n", vra->desc_user_addr);
 764    DPRINT("    used_user_addr:   0x%016llx\n", vra->used_user_addr);
 765    DPRINT("    avail_user_addr:  0x%016llx\n", vra->avail_user_addr);
 766    DPRINT("    log_guest_addr:   0x%016llx\n", vra->log_guest_addr);
 767
 768    vq->vring.flags = vra->flags;
 769    vq->vring.desc = qva_to_va(dev, vra->desc_user_addr);
 770    vq->vring.used = qva_to_va(dev, vra->used_user_addr);
 771    vq->vring.avail = qva_to_va(dev, vra->avail_user_addr);
 772    vq->vring.log_guest_addr = vra->log_guest_addr;
 773
 774    DPRINT("Setting virtq addresses:\n");
 775    DPRINT("    vring_desc  at %p\n", vq->vring.desc);
 776    DPRINT("    vring_used  at %p\n", vq->vring.used);
 777    DPRINT("    vring_avail at %p\n", vq->vring.avail);
 778
 779    if (!(vq->vring.desc && vq->vring.used && vq->vring.avail)) {
 780        vu_panic(dev, "Invalid vring_addr message");
 781        return false;
 782    }
 783
 784    vq->used_idx = vq->vring.used->idx;
 785
 786    if (vq->last_avail_idx != vq->used_idx) {
 787        bool resume = dev->iface->queue_is_processed_in_order &&
 788            dev->iface->queue_is_processed_in_order(dev, index);
 789
 790        DPRINT("Last avail index != used index: %u != %u%s\n",
 791               vq->last_avail_idx, vq->used_idx,
 792               resume ? ", resuming" : "");
 793
 794        if (resume) {
 795            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
 796        }
 797    }
 798
 799    return false;
 800}
 801
 802static bool
 803vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 804{
 805    unsigned int index = vmsg->payload.state.index;
 806    unsigned int num = vmsg->payload.state.num;
 807
 808    DPRINT("State.index: %d\n", index);
 809    DPRINT("State.num:   %d\n", num);
 810    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
 811
 812    return false;
 813}
 814
 815static bool
 816vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 817{
 818    unsigned int index = vmsg->payload.state.index;
 819
 820    DPRINT("State.index: %d\n", index);
 821    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
 822    vmsg->size = sizeof(vmsg->payload.state);
 823
 824    dev->vq[index].started = false;
 825    if (dev->iface->queue_set_started) {
 826        dev->iface->queue_set_started(dev, index, false);
 827    }
 828
 829    if (dev->vq[index].call_fd != -1) {
 830        close(dev->vq[index].call_fd);
 831        dev->vq[index].call_fd = -1;
 832    }
 833    if (dev->vq[index].kick_fd != -1) {
 834        dev->remove_watch(dev, dev->vq[index].kick_fd);
 835        close(dev->vq[index].kick_fd);
 836        dev->vq[index].kick_fd = -1;
 837    }
 838
 839    return true;
 840}
 841
 842static bool
 843vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
 844{
 845    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 846
 847    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
 848        vmsg_close_fds(vmsg);
 849        vu_panic(dev, "Invalid queue index: %u", index);
 850        return false;
 851    }
 852
 853    if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK ||
 854        vmsg->fd_num != 1) {
 855        vmsg_close_fds(vmsg);
 856        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
 857        return false;
 858    }
 859
 860    return true;
 861}
 862
 863static bool
 864vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
 865{
 866    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 867
 868    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 869
 870    if (!vu_check_queue_msg_file(dev, vmsg)) {
 871        return false;
 872    }
 873
 874    if (dev->vq[index].kick_fd != -1) {
 875        dev->remove_watch(dev, dev->vq[index].kick_fd);
 876        close(dev->vq[index].kick_fd);
 877        dev->vq[index].kick_fd = -1;
 878    }
 879
 880    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
 881        dev->vq[index].kick_fd = vmsg->fds[0];
 882        DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
 883    }
 884
 885    dev->vq[index].started = true;
 886    if (dev->iface->queue_set_started) {
 887        dev->iface->queue_set_started(dev, index, true);
 888    }
 889
 890    if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
 891        dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
 892                       vu_kick_cb, (void *)(long)index);
 893
 894        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
 895               dev->vq[index].kick_fd, index);
 896    }
 897
 898    return false;
 899}
 900
 901void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
 902                          vu_queue_handler_cb handler)
 903{
 904    int qidx = vq - dev->vq;
 905
 906    vq->handler = handler;
 907    if (vq->kick_fd >= 0) {
 908        if (handler) {
 909            dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
 910                           vu_kick_cb, (void *)(long)qidx);
 911        } else {
 912            dev->remove_watch(dev, vq->kick_fd);
 913        }
 914    }
 915}
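/*
 * Usage sketch (illustrative, not part of the library): a backend typically
 * installs or removes a per-queue handler from its queue_set_started hook,
 * so that the library only watches the kick fd while the ring is live.  The
 * example_* names below are hypothetical.
 */
#if 0
static void
example_handle_queue(VuDev *dev, int qidx)
{
    VuVirtq *vq = vu_get_queue(dev, qidx);

    /* drain vq here with vu_queue_pop()/vu_queue_push() */
    (void)vq;
}

static void
example_queue_set_started(VuDev *dev, int qidx, bool started)
{
    VuVirtq *vq = vu_get_queue(dev, qidx);

    vu_set_queue_handler(dev, vq, started ? example_handle_queue : NULL);
}
#endif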
 916
 917static bool
 918vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
 919{
 920    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 921
 922    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 923
 924    if (!vu_check_queue_msg_file(dev, vmsg)) {
 925        return false;
 926    }
 927
 928    if (dev->vq[index].call_fd != -1) {
 929        close(dev->vq[index].call_fd);
 930        dev->vq[index].call_fd = -1;
 931    }
 932
 933    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
 934        dev->vq[index].call_fd = vmsg->fds[0];
 935    }
 936
 937    DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);
 938
 939    return false;
 940}
 941
 942static bool
 943vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
 944{
 945    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 946
 947    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 948
 949    if (!vu_check_queue_msg_file(dev, vmsg)) {
 950        return false;
 951    }
 952
 953    if (dev->vq[index].err_fd != -1) {
 954        close(dev->vq[index].err_fd);
 955        dev->vq[index].err_fd = -1;
 956    }
 957
 958    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
 959        dev->vq[index].err_fd = vmsg->fds[0];
 960    }
 961
 962    return false;
 963}
 964
 965static bool
 966vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 967{
 968    uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
 969                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
 970
 971    if (have_userfault()) {
 972        features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
 973    }
 974
 975    if (dev->iface->get_protocol_features) {
 976        features |= dev->iface->get_protocol_features(dev);
 977    }
 978
 979    vmsg->payload.u64 = features;
 980    vmsg->size = sizeof(vmsg->payload.u64);
 981    vmsg->fd_num = 0;
 982
 983    return true;
 984}
 985
 986static bool
 987vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 988{
 989    uint64_t features = vmsg->payload.u64;
 990
 991    DPRINT("u64: 0x%016"PRIx64"\n", features);
 992
 993    dev->protocol_features = vmsg->payload.u64;
 994
 995    if (dev->iface->set_protocol_features) {
 996        dev->iface->set_protocol_features(dev, features);
 997    }
 998
 999    return false;
1000}
1001
1002static bool
1003vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1004{
1005    DPRINT("Function %s() not implemented yet.\n", __func__);
1006    return false;
1007}
1008
1009static bool
1010vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1011{
1012    unsigned int index = vmsg->payload.state.index;
1013    unsigned int enable = vmsg->payload.state.num;
1014
1015    DPRINT("State.index: %d\n", index);
1016    DPRINT("State.enable:   %d\n", enable);
1017
1018    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
1019        vu_panic(dev, "Invalid vring_enable index: %u", index);
1020        return false;
1021    }
1022
1023    dev->vq[index].enable = enable;
1024    return false;
1025}
1026
1027static bool
1028vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1029{
1030    if (vmsg->fd_num != 1) {
1031        vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
1032        return false;
1033    }
1034
1035    if (dev->slave_fd != -1) {
1036        close(dev->slave_fd);
1037    }
1038    dev->slave_fd = vmsg->fds[0];
1039    DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1040
1041    return false;
1042}
1043
1044static bool
1045vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1046{
1047    int ret = -1;
1048
1049    if (dev->iface->get_config) {
1050        ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1051                                     vmsg->payload.config.size);
1052    }
1053
1054    if (ret) {
1055        /* resize to zero to indicate an error to master */
1056        vmsg->size = 0;
1057    }
1058
1059    return true;
1060}
1061
1062static bool
1063vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1064{
1065    int ret = -1;
1066
1067    if (dev->iface->set_config) {
1068        ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1069                                     vmsg->payload.config.offset,
1070                                     vmsg->payload.config.size,
1071                                     vmsg->payload.config.flags);
1072        if (ret) {
1073            vu_panic(dev, "Set virtio configuration space failed");
1074        }
1075    }
1076
1077    return false;
1078}
1079
1080static bool
1081vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1082{
1083    dev->postcopy_ufd = -1;
1084#ifdef UFFDIO_API
1085    struct uffdio_api api_struct;
1086
1087    dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1088    vmsg->size = 0;
1089#endif
1090
1091    if (dev->postcopy_ufd == -1) {
1092        vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1093        goto out;
1094    }
1095
1096#ifdef UFFDIO_API
1097    api_struct.api = UFFD_API;
1098    api_struct.features = 0;
1099    if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1100        vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1101        close(dev->postcopy_ufd);
1102        dev->postcopy_ufd = -1;
1103        goto out;
1104    }
1105    /* TODO: Stash feature flags somewhere */
1106#endif
1107
1108out:
 1109    /* Return a ufd to QEMU */
1110    vmsg->fd_num = 1;
1111    vmsg->fds[0] = dev->postcopy_ufd;
1112    return true; /* = send a reply */
1113}
1114
1115static bool
1116vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1117{
1118    vmsg->payload.u64 = -1;
1119    vmsg->size = sizeof(vmsg->payload.u64);
1120
1121    if (dev->nregions) {
1122        vu_panic(dev, "Regions already registered at postcopy-listen");
1123        return true;
1124    }
1125    dev->postcopy_listening = true;
1126
1127    vmsg->flags = VHOST_USER_VERSION |  VHOST_USER_REPLY_MASK;
1128    vmsg->payload.u64 = 0; /* Success */
1129    return true;
1130}
1131
1132static bool
1133vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1134{
1135    DPRINT("%s: Entry\n", __func__);
1136    dev->postcopy_listening = false;
1137    if (dev->postcopy_ufd > 0) {
1138        close(dev->postcopy_ufd);
1139        dev->postcopy_ufd = -1;
1140        DPRINT("%s: Done close\n", __func__);
1141    }
1142
1143    vmsg->fd_num = 0;
1144    vmsg->payload.u64 = 0;
1145    vmsg->size = sizeof(vmsg->payload.u64);
1146    vmsg->flags = VHOST_USER_VERSION |  VHOST_USER_REPLY_MASK;
1147    DPRINT("%s: exit\n", __func__);
1148    return true;
1149}
1150
1151static bool
1152vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1153{
1154    int do_reply = 0;
1155
1156    /* Print out generic part of the request. */
1157    DPRINT("================ Vhost user message ================\n");
1158    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1159           vmsg->request);
1160    DPRINT("Flags:   0x%x\n", vmsg->flags);
1161    DPRINT("Size:    %d\n", vmsg->size);
1162
1163    if (vmsg->fd_num) {
1164        int i;
1165        DPRINT("Fds:");
1166        for (i = 0; i < vmsg->fd_num; i++) {
1167            DPRINT(" %d", vmsg->fds[i]);
1168        }
1169        DPRINT("\n");
1170    }
1171
1172    if (dev->iface->process_msg &&
1173        dev->iface->process_msg(dev, vmsg, &do_reply)) {
1174        return do_reply;
1175    }
1176
1177    switch (vmsg->request) {
1178    case VHOST_USER_GET_FEATURES:
1179        return vu_get_features_exec(dev, vmsg);
1180    case VHOST_USER_SET_FEATURES:
1181        return vu_set_features_exec(dev, vmsg);
1182    case VHOST_USER_GET_PROTOCOL_FEATURES:
1183        return vu_get_protocol_features_exec(dev, vmsg);
1184    case VHOST_USER_SET_PROTOCOL_FEATURES:
1185        return vu_set_protocol_features_exec(dev, vmsg);
1186    case VHOST_USER_SET_OWNER:
1187        return vu_set_owner_exec(dev, vmsg);
1188    case VHOST_USER_RESET_OWNER:
1189        return vu_reset_device_exec(dev, vmsg);
1190    case VHOST_USER_SET_MEM_TABLE:
1191        return vu_set_mem_table_exec(dev, vmsg);
1192    case VHOST_USER_SET_LOG_BASE:
1193        return vu_set_log_base_exec(dev, vmsg);
1194    case VHOST_USER_SET_LOG_FD:
1195        return vu_set_log_fd_exec(dev, vmsg);
1196    case VHOST_USER_SET_VRING_NUM:
1197        return vu_set_vring_num_exec(dev, vmsg);
1198    case VHOST_USER_SET_VRING_ADDR:
1199        return vu_set_vring_addr_exec(dev, vmsg);
1200    case VHOST_USER_SET_VRING_BASE:
1201        return vu_set_vring_base_exec(dev, vmsg);
1202    case VHOST_USER_GET_VRING_BASE:
1203        return vu_get_vring_base_exec(dev, vmsg);
1204    case VHOST_USER_SET_VRING_KICK:
1205        return vu_set_vring_kick_exec(dev, vmsg);
1206    case VHOST_USER_SET_VRING_CALL:
1207        return vu_set_vring_call_exec(dev, vmsg);
1208    case VHOST_USER_SET_VRING_ERR:
1209        return vu_set_vring_err_exec(dev, vmsg);
1210    case VHOST_USER_GET_QUEUE_NUM:
1211        return vu_get_queue_num_exec(dev, vmsg);
1212    case VHOST_USER_SET_VRING_ENABLE:
1213        return vu_set_vring_enable_exec(dev, vmsg);
1214    case VHOST_USER_SET_SLAVE_REQ_FD:
1215        return vu_set_slave_req_fd(dev, vmsg);
1216    case VHOST_USER_GET_CONFIG:
1217        return vu_get_config(dev, vmsg);
1218    case VHOST_USER_SET_CONFIG:
1219        return vu_set_config(dev, vmsg);
1220    case VHOST_USER_NONE:
1221        break;
1222    case VHOST_USER_POSTCOPY_ADVISE:
1223        return vu_set_postcopy_advise(dev, vmsg);
1224    case VHOST_USER_POSTCOPY_LISTEN:
1225        return vu_set_postcopy_listen(dev, vmsg);
1226    case VHOST_USER_POSTCOPY_END:
1227        return vu_set_postcopy_end(dev, vmsg);
1228    default:
1229        vmsg_close_fds(vmsg);
1230        vu_panic(dev, "Unhandled request: %d", vmsg->request);
1231    }
1232
1233    return false;
1234}
1235
1236bool
1237vu_dispatch(VuDev *dev)
1238{
1239    VhostUserMsg vmsg = { 0, };
1240    int reply_requested;
1241    bool success = false;
1242
1243    if (!vu_message_read(dev, dev->sock, &vmsg)) {
1244        goto end;
1245    }
1246
1247    reply_requested = vu_process_message(dev, &vmsg);
1248    if (!reply_requested) {
1249        success = true;
1250        goto end;
1251    }
1252
1253    if (!vu_message_write(dev, dev->sock, &vmsg)) {
1254        goto end;
1255    }
1256
1257    success = true;
1258
1259end:
1260    free(vmsg.data);
1261    return success;
1262}
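/*
 * Usage sketch (illustrative, not part of the library): without GLib a
 * minimal backend can poll the vhost-user socket itself and feed it to
 * vu_dispatch(); a false return means the connection failed or the device
 * was marked broken.  A real backend must also poll the fds handed to its
 * set_watch callback (the vring kick fds).  The loop below is an assumption
 * about the caller, not an API provided here; it needs <poll.h>.
 */
#if 0
static void
example_main_loop(VuDev *dev)
{
    for (;;) {
        struct pollfd pfd = { .fd = dev->sock, .events = POLLIN };

        if (poll(&pfd, 1, -1) < 0 && errno != EINTR) {
            break;
        }
        if (pfd.revents & (POLLERR | POLLHUP)) {
            break;
        }
        if ((pfd.revents & POLLIN) && !vu_dispatch(dev)) {
            break;                  /* disconnect or fatal protocol error */
        }
    }
    vu_deinit(dev);
}
#endif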
1263
1264void
1265vu_deinit(VuDev *dev)
1266{
1267    int i;
1268
1269    for (i = 0; i < dev->nregions; i++) {
1270        VuDevRegion *r = &dev->regions[i];
1271        void *m = (void *) (uintptr_t) r->mmap_addr;
1272        if (m != MAP_FAILED) {
1273            munmap(m, r->size + r->mmap_offset);
1274        }
1275    }
1276    dev->nregions = 0;
1277
1278    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
1279        VuVirtq *vq = &dev->vq[i];
1280
1281        if (vq->call_fd != -1) {
1282            close(vq->call_fd);
1283            vq->call_fd = -1;
1284        }
1285
1286        if (vq->kick_fd != -1) {
1287            close(vq->kick_fd);
1288            vq->kick_fd = -1;
1289        }
1290
1291        if (vq->err_fd != -1) {
1292            close(vq->err_fd);
1293            vq->err_fd = -1;
1294        }
1295    }
1296
1297
1298    vu_close_log(dev);
1299    if (dev->slave_fd != -1) {
1300        close(dev->slave_fd);
1301        dev->slave_fd = -1;
1302    }
1303
1304    if (dev->sock != -1) {
1305        close(dev->sock);
1306    }
1307}
1308
1309void
1310vu_init(VuDev *dev,
1311        int socket,
1312        vu_panic_cb panic,
1313        vu_set_watch_cb set_watch,
1314        vu_remove_watch_cb remove_watch,
1315        const VuDevIface *iface)
1316{
1317    int i;
1318
1319    assert(socket >= 0);
1320    assert(set_watch);
1321    assert(remove_watch);
1322    assert(iface);
1323    assert(panic);
1324
1325    memset(dev, 0, sizeof(*dev));
1326
1327    dev->sock = socket;
1328    dev->panic = panic;
1329    dev->set_watch = set_watch;
1330    dev->remove_watch = remove_watch;
1331    dev->iface = iface;
1332    dev->log_call_fd = -1;
1333    dev->slave_fd = -1;
1334    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
1335        dev->vq[i] = (VuVirtq) {
1336            .call_fd = -1, .kick_fd = -1, .err_fd = -1,
1337            .notification = true,
1338        };
1339    }
1340}
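/*
 * Usage sketch (illustrative, not part of the library): vu_init() takes an
 * already-connected vhost-user socket plus the backend's panic and watch
 * callbacks.  Every VuDevIface hook used in this file is checked for NULL
 * before being called, so unimplemented ones may be left unset.  The
 * example_* names are hypothetical backend functions.
 */
#if 0
static const VuDevIface example_iface = {
    .get_features = example_get_features,
    .set_features = example_set_features,
    .queue_set_started = example_queue_set_started,
};

static void
example_start_backend(VuDev *dev, int connected_socket)
{
    vu_init(dev, connected_socket,
            example_panic,          /* called on fatal protocol errors */
            example_set_watch,      /* start polling an fd for the library */
            example_remove_watch,   /* stop polling an fd */
            &example_iface);
}
#endif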
1341
1342VuVirtq *
1343vu_get_queue(VuDev *dev, int qidx)
1344{
1345    assert(qidx < VHOST_MAX_NR_VIRTQUEUE);
1346    return &dev->vq[qidx];
1347}
1348
1349bool
1350vu_queue_enabled(VuDev *dev, VuVirtq *vq)
1351{
1352    return vq->enable;
1353}
1354
1355bool
1356vu_queue_started(const VuDev *dev, const VuVirtq *vq)
1357{
1358    return vq->started;
1359}
1360
1361static inline uint16_t
1362vring_avail_flags(VuVirtq *vq)
1363{
1364    return vq->vring.avail->flags;
1365}
1366
1367static inline uint16_t
1368vring_avail_idx(VuVirtq *vq)
1369{
1370    vq->shadow_avail_idx = vq->vring.avail->idx;
1371
1372    return vq->shadow_avail_idx;
1373}
1374
1375static inline uint16_t
1376vring_avail_ring(VuVirtq *vq, int i)
1377{
1378    return vq->vring.avail->ring[i];
1379}
1380
1381static inline uint16_t
1382vring_get_used_event(VuVirtq *vq)
1383{
1384    return vring_avail_ring(vq, vq->vring.num);
1385}
1386
1387static int
1388virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
1389{
1390    uint16_t num_heads = vring_avail_idx(vq) - idx;
1391
1392    /* Check it isn't doing very strange things with descriptor numbers. */
1393    if (num_heads > vq->vring.num) {
1394        vu_panic(dev, "Guest moved used index from %u to %u",
1395                 idx, vq->shadow_avail_idx);
1396        return -1;
1397    }
1398    if (num_heads) {
1399        /* On success, callers read a descriptor at vq->last_avail_idx.
1400         * Make sure descriptor read does not bypass avail index read. */
1401        smp_rmb();
1402    }
1403
1404    return num_heads;
1405}
1406
1407static bool
1408virtqueue_get_head(VuDev *dev, VuVirtq *vq,
1409                   unsigned int idx, unsigned int *head)
1410{
1411    /* Grab the next descriptor number they're advertising, and increment
1412     * the index we've seen. */
1413    *head = vring_avail_ring(vq, idx % vq->vring.num);
1414
1415    /* If their number is silly, that's a fatal mistake. */
1416    if (*head >= vq->vring.num) {
1417        vu_panic(dev, "Guest says index %u is available", head);
1418        return false;
1419    }
1420
1421    return true;
1422}
1423
1424static int
1425virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
1426                             uint64_t addr, size_t len)
1427{
1428    struct vring_desc *ori_desc;
1429    uint64_t read_len;
1430
1431    if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
1432        return -1;
1433    }
1434
1435    if (len == 0) {
1436        return -1;
1437    }
1438
1439    while (len) {
1440        read_len = len;
1441        ori_desc = vu_gpa_to_va(dev, &read_len, addr);
1442        if (!ori_desc) {
1443            return -1;
1444        }
1445
1446        memcpy(desc, ori_desc, read_len);
1447        len -= read_len;
1448        addr += read_len;
1449        desc += read_len;
1450    }
1451
1452    return 0;
1453}
1454
1455enum {
1456    VIRTQUEUE_READ_DESC_ERROR = -1,
1457    VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
1458    VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
1459};
1460
1461static int
1462virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
1463                         int i, unsigned int max, unsigned int *next)
1464{
1465    /* If this descriptor says it doesn't chain, we're done. */
1466    if (!(desc[i].flags & VRING_DESC_F_NEXT)) {
1467        return VIRTQUEUE_READ_DESC_DONE;
1468    }
1469
1470    /* Check they're not leading us off end of descriptors. */
1471    *next = desc[i].next;
1472    /* Make sure compiler knows to grab that: we don't want it changing! */
1473    smp_wmb();
1474
1475    if (*next >= max) {
1476        vu_panic(dev, "Desc next is %u", next);
1477        return VIRTQUEUE_READ_DESC_ERROR;
1478    }
1479
1480    return VIRTQUEUE_READ_DESC_MORE;
1481}
1482
1483void
1484vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
1485                         unsigned int *out_bytes,
1486                         unsigned max_in_bytes, unsigned max_out_bytes)
1487{
1488    unsigned int idx;
1489    unsigned int total_bufs, in_total, out_total;
1490    int rc;
1491
1492    idx = vq->last_avail_idx;
1493
1494    total_bufs = in_total = out_total = 0;
1495    if (unlikely(dev->broken) ||
1496        unlikely(!vq->vring.avail)) {
1497        goto done;
1498    }
1499
1500    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
1501        unsigned int max, desc_len, num_bufs, indirect = 0;
1502        uint64_t desc_addr, read_len;
1503        struct vring_desc *desc;
1504        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
1505        unsigned int i;
1506
1507        max = vq->vring.num;
1508        num_bufs = total_bufs;
1509        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
1510            goto err;
1511        }
1512        desc = vq->vring.desc;
1513
1514        if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1515            if (desc[i].len % sizeof(struct vring_desc)) {
1516                vu_panic(dev, "Invalid size for indirect buffer table");
1517                goto err;
1518            }
1519
1520            /* If we've got too many, that implies a descriptor loop. */
1521            if (num_bufs >= max) {
1522                vu_panic(dev, "Looped descriptor");
1523                goto err;
1524            }
1525
1526            /* loop over the indirect descriptor table */
1527            indirect = 1;
1528            desc_addr = desc[i].addr;
1529            desc_len = desc[i].len;
1530            max = desc_len / sizeof(struct vring_desc);
1531            read_len = desc_len;
1532            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
1533            if (unlikely(desc && read_len != desc_len)) {
1534                /* Failed to use zero copy */
1535                desc = NULL;
1536                if (!virtqueue_read_indirect_desc(dev, desc_buf,
1537                                                  desc_addr,
1538                                                  desc_len)) {
1539                    desc = desc_buf;
1540                }
1541            }
1542            if (!desc) {
1543                vu_panic(dev, "Invalid indirect buffer table");
1544                goto err;
1545            }
1546            num_bufs = i = 0;
1547        }
1548
1549        do {
1550            /* If we've got too many, that implies a descriptor loop. */
1551            if (++num_bufs > max) {
1552                vu_panic(dev, "Looped descriptor");
1553                goto err;
1554            }
1555
1556            if (desc[i].flags & VRING_DESC_F_WRITE) {
1557                in_total += desc[i].len;
1558            } else {
1559                out_total += desc[i].len;
1560            }
1561            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
1562                goto done;
1563            }
1564            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
1565        } while (rc == VIRTQUEUE_READ_DESC_MORE);
1566
1567        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
1568            goto err;
1569        }
1570
1571        if (!indirect) {
1572            total_bufs = num_bufs;
1573        } else {
1574            total_bufs++;
1575        }
1576    }
1577    if (rc < 0) {
1578        goto err;
1579    }
1580done:
1581    if (in_bytes) {
1582        *in_bytes = in_total;
1583    }
1584    if (out_bytes) {
1585        *out_bytes = out_total;
1586    }
1587    return;
1588
1589err:
1590    in_total = out_total = 0;
1591    goto done;
1592}
1593
1594bool
1595vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
1596                     unsigned int out_bytes)
1597{
1598    unsigned int in_total, out_total;
1599
1600    vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
1601                             in_bytes, out_bytes);
1602
1603    return in_bytes <= in_total && out_bytes <= out_total;
1604}
1605
1606/* Fetch avail_idx from VQ memory only when we really need to know if
1607 * guest has added some buffers. */
1608bool
1609vu_queue_empty(VuDev *dev, VuVirtq *vq)
1610{
1611    if (unlikely(dev->broken) ||
1612        unlikely(!vq->vring.avail)) {
1613        return true;
1614    }
1615
1616    if (vq->shadow_avail_idx != vq->last_avail_idx) {
1617        return false;
1618    }
1619
1620    return vring_avail_idx(vq) == vq->last_avail_idx;
1621}
1622
1623static inline
1624bool has_feature(uint64_t features, unsigned int fbit)
1625{
1626    assert(fbit < 64);
1627    return !!(features & (1ULL << fbit));
1628}
1629
1630static inline
1631bool vu_has_feature(VuDev *dev,
1632                    unsigned int fbit)
1633{
1634    return has_feature(dev->features, fbit);
1635}
1636
1637static bool
1638vring_notify(VuDev *dev, VuVirtq *vq)
1639{
1640    uint16_t old, new;
1641    bool v;
1642
1643    /* We need to expose used array entries before checking used event. */
1644    smp_mb();
1645
 1646    /* Always notify when queue is empty (if the feature was negotiated) */
1647    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
1648        !vq->inuse && vu_queue_empty(dev, vq)) {
1649        return true;
1650    }
1651
1652    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1653        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
1654    }
1655
1656    v = vq->signalled_used_valid;
1657    vq->signalled_used_valid = true;
1658    old = vq->signalled_used;
1659    new = vq->signalled_used = vq->used_idx;
1660    return !v || vring_need_event(vring_get_used_event(vq), new, old);
1661}
1662
1663void
1664vu_queue_notify(VuDev *dev, VuVirtq *vq)
1665{
1666    if (unlikely(dev->broken) ||
1667        unlikely(!vq->vring.avail)) {
1668        return;
1669    }
1670
1671    if (!vring_notify(dev, vq)) {
1672        DPRINT("skipped notify...\n");
1673        return;
1674    }
1675
1676    if (eventfd_write(vq->call_fd, 1) < 0) {
1677        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
1678    }
1679}
1680
1681static inline void
1682vring_used_flags_set_bit(VuVirtq *vq, int mask)
1683{
1684    uint16_t *flags;
1685
1686    flags = (uint16_t *)((char*)vq->vring.used +
1687                         offsetof(struct vring_used, flags));
1688    *flags |= mask;
1689}
1690
1691static inline void
1692vring_used_flags_unset_bit(VuVirtq *vq, int mask)
1693{
1694    uint16_t *flags;
1695
1696    flags = (uint16_t *)((char*)vq->vring.used +
1697                         offsetof(struct vring_used, flags));
1698    *flags &= ~mask;
1699}
1700
1701static inline void
1702vring_set_avail_event(VuVirtq *vq, uint16_t val)
1703{
1704    if (!vq->notification) {
1705        return;
1706    }
1707
1708    *((uint16_t *) &vq->vring.used->ring[vq->vring.num]) = val;
1709}
1710
1711void
1712vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
1713{
1714    vq->notification = enable;
1715    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1716        vring_set_avail_event(vq, vring_avail_idx(vq));
1717    } else if (enable) {
1718        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
1719    } else {
1720        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
1721    }
1722    if (enable) {
1723        /* Expose avail event/used flags before caller checks the avail idx. */
1724        smp_mb();
1725    }
1726}
1727
1728static void
1729virtqueue_map_desc(VuDev *dev,
1730                   unsigned int *p_num_sg, struct iovec *iov,
1731                   unsigned int max_num_sg, bool is_write,
1732                   uint64_t pa, size_t sz)
1733{
1734    unsigned num_sg = *p_num_sg;
1735
1736    assert(num_sg <= max_num_sg);
1737
1738    if (!sz) {
1739        vu_panic(dev, "virtio: zero sized buffers are not allowed");
1740        return;
1741    }
1742
1743    while (sz) {
1744        uint64_t len = sz;
1745
1746        if (num_sg == max_num_sg) {
1747            vu_panic(dev, "virtio: too many descriptors in indirect table");
1748            return;
1749        }
1750
1751        iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
1752        if (iov[num_sg].iov_base == NULL) {
1753            vu_panic(dev, "virtio: invalid address for buffers");
1754            return;
1755        }
1756        iov[num_sg].iov_len = len;
1757        num_sg++;
1758        sz -= len;
1759        pa += len;
1760    }
1761
1762    *p_num_sg = num_sg;
1763}
1764
1765/* Round number down to multiple */
1766#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
1767
1768/* Round number up to multiple */
1769#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
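/*
 * For example, ALIGN_DOWN(13, 8) evaluates to 8 and ALIGN_UP(13, 8) to 16;
 * virtqueue_alloc_element() below uses ALIGN_UP to place the in_sg array at
 * a suitably aligned offset past the caller-requested element size.
 */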
1770
1771static void *
1772virtqueue_alloc_element(size_t sz,
1773                                     unsigned out_num, unsigned in_num)
1774{
1775    VuVirtqElement *elem;
1776    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
1777    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
1778    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
1779
1780    assert(sz >= sizeof(VuVirtqElement));
1781    elem = malloc(out_sg_end);
1782    elem->out_num = out_num;
1783    elem->in_num = in_num;
1784    elem->in_sg = (void *)elem + in_sg_ofs;
1785    elem->out_sg = (void *)elem + out_sg_ofs;
1786    return elem;
1787}
1788
1789void *
1790vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
1791{
1792    unsigned int i, head, max, desc_len;
1793    uint64_t desc_addr, read_len;
1794    VuVirtqElement *elem;
1795    unsigned out_num, in_num;
1796    struct iovec iov[VIRTQUEUE_MAX_SIZE];
1797    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
1798    struct vring_desc *desc;
1799    int rc;
1800
1801    if (unlikely(dev->broken) ||
1802        unlikely(!vq->vring.avail)) {
1803        return NULL;
1804    }
1805
1806    if (vu_queue_empty(dev, vq)) {
1807        return NULL;
1808    }
 1809    /* Needed after vu_queue_empty(), see the comment in
 1810     * virtqueue_num_heads(). */
1811    smp_rmb();
1812
 1813    /* When we start there are neither input nor output buffers. */
1814    out_num = in_num = 0;
1815
1816    max = vq->vring.num;
1817    if (vq->inuse >= vq->vring.num) {
1818        vu_panic(dev, "Virtqueue size exceeded");
1819        return NULL;
1820    }
1821
1822    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
1823        return NULL;
1824    }
1825
1826    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1827        vring_set_avail_event(vq, vq->last_avail_idx);
1828    }
1829
1830    i = head;
1831    desc = vq->vring.desc;
1832    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1833        if (desc[i].len % sizeof(struct vring_desc)) {
1834            vu_panic(dev, "Invalid size for indirect buffer table");
1835        }
1836
1837        /* loop over the indirect descriptor table */
1838        desc_addr = desc[i].addr;
1839        desc_len = desc[i].len;
1840        max = desc_len / sizeof(struct vring_desc);
1841        read_len = desc_len;
1842        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
1843        if (unlikely(desc && read_len != desc_len)) {
1844            /* Failed to use zero copy */
1845            desc = NULL;
1846            if (!virtqueue_read_indirect_desc(dev, desc_buf,
1847                                              desc_addr,
1848                                              desc_len)) {
1849                desc = desc_buf;
1850            }
1851        }
1852        if (!desc) {
1853            vu_panic(dev, "Invalid indirect buffer table");
1854            return NULL;
1855        }
1856        i = 0;
1857    }
1858
1859    /* Collect all the descriptors */
1860    do {
1861        if (desc[i].flags & VRING_DESC_F_WRITE) {
1862            virtqueue_map_desc(dev, &in_num, iov + out_num,
1863                               VIRTQUEUE_MAX_SIZE - out_num, true,
1864                               desc[i].addr, desc[i].len);
1865        } else {
1866            if (in_num) {
1867                vu_panic(dev, "Incorrect order for descriptors");
1868                return NULL;
1869            }
1870            virtqueue_map_desc(dev, &out_num, iov,
1871                               VIRTQUEUE_MAX_SIZE, false,
1872                               desc[i].addr, desc[i].len);
1873        }
1874
1875        /* If we've got too many, that implies a descriptor loop. */
1876        if ((in_num + out_num) > max) {
1877            vu_panic(dev, "Looped descriptor");
1878        }
1879        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
1880    } while (rc == VIRTQUEUE_READ_DESC_MORE);
1881
1882    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
1883        return NULL;
1884    }
1885
1886    /* Now copy what we have collected and mapped */
1887    elem = virtqueue_alloc_element(sz, out_num, in_num);
1888    elem->index = head;
1889    for (i = 0; i < out_num; i++) {
1890        elem->out_sg[i] = iov[i];
1891    }
1892    for (i = 0; i < in_num; i++) {
1893        elem->in_sg[i] = iov[out_num + i];
1894    }
1895
1896    vq->inuse++;
1897
1898    return elem;
1899}
1900
1901bool
1902vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
1903{
1904    if (num > vq->inuse) {
1905        return false;
1906    }
1907    vq->last_avail_idx -= num;
1908    vq->inuse -= num;
1909    return true;
1910}
1911
1912static inline
1913void vring_used_write(VuDev *dev, VuVirtq *vq,
1914                      struct vring_used_elem *uelem, int i)
1915{
1916    struct vring_used *used = vq->vring.used;
1917
1918    used->ring[i] = *uelem;
1919    vu_log_write(dev, vq->vring.log_guest_addr +
1920                 offsetof(struct vring_used, ring[i]),
1921                 sizeof(used->ring[i]));
1922}
1923
1924
1925static void
1926vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
1927                  const VuVirtqElement *elem,
1928                  unsigned int len)
1929{
1930    struct vring_desc *desc = vq->vring.desc;
1931    unsigned int i, max, min, desc_len;
1932    uint64_t desc_addr, read_len;
1933    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
1934    unsigned num_bufs = 0;
1935
1936    max = vq->vring.num;
1937    i = elem->index;
1938
1939    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1940        if (desc[i].len % sizeof(struct vring_desc)) {
1941            vu_panic(dev, "Invalid size for indirect buffer table");
1942        }
1943
1944        /* loop over the indirect descriptor table */
1945        desc_addr = desc[i].addr;
1946        desc_len = desc[i].len;
1947        max = desc_len / sizeof(struct vring_desc);
1948        read_len = desc_len;
1949        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
1950        if (unlikely(desc && read_len != desc_len)) {
1951            /* Failed to use zero copy */
1952            desc = NULL;
1953            if (!virtqueue_read_indirect_desc(dev, desc_buf,
1954                                              desc_addr,
1955                                              desc_len)) {
1956                desc = desc_buf;
1957            }
1958        }
1959        if (!desc) {
1960            vu_panic(dev, "Invalid indirect buffer table");
1961            return;
1962        }
1963        i = 0;
1964    }
1965
1966    do {
1967        if (++num_bufs > max) {
1968            vu_panic(dev, "Looped descriptor");
1969            return;
1970        }
1971
1972        if (desc[i].flags & VRING_DESC_F_WRITE) {
1973            min = MIN(desc[i].len, len);
1974            vu_log_write(dev, desc[i].addr, min);
1975            len -= min;
1976        }
1977
1978    } while (len > 0 &&
1979             (virtqueue_read_next_desc(dev, desc, i, max, &i)
1980              == VIRTQUEUE_READ_DESC_MORE));
1981}
1982
1983void
1984vu_queue_fill(VuDev *dev, VuVirtq *vq,
1985              const VuVirtqElement *elem,
1986              unsigned int len, unsigned int idx)
1987{
1988    struct vring_used_elem uelem;
1989
1990    if (unlikely(dev->broken) ||
1991        unlikely(!vq->vring.avail)) {
1992        return;
1993    }
1994
1995    vu_log_queue_fill(dev, vq, elem, len);
1996
1997    idx = (idx + vq->used_idx) % vq->vring.num;
1998
1999    uelem.id = elem->index;
2000    uelem.len = len;
2001    vring_used_write(dev, vq, &uelem, idx);
2002}
2003
2004static inline
2005void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2006{
2007    vq->vring.used->idx = val;
2008    vu_log_write(dev,
2009                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2010                 sizeof(vq->vring.used->idx));
2011
2012    vq->used_idx = val;
2013}
2014
2015void
2016vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
2017{
2018    uint16_t old, new;
2019
2020    if (unlikely(dev->broken) ||
2021        unlikely(!vq->vring.avail)) {
2022        return;
2023    }
2024
2025    /* Make sure buffer is written before we update index. */
2026    smp_wmb();
2027
2028    old = vq->used_idx;
2029    new = old + count;
2030    vring_used_idx_set(dev, vq, new);
2031    vq->inuse -= count;
2032    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
2033        vq->signalled_used_valid = false;
2034    }
2035}
2036
2037void
2038vu_queue_push(VuDev *dev, VuVirtq *vq,
2039              const VuVirtqElement *elem, unsigned int len)
2040{
2041    vu_queue_fill(dev, vq, elem, len, 0);
2042    vu_queue_flush(dev, vq, 1);
2043}
2044
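/*
 * Usage sketch (illustrative, not part of the library): a typical queue
 * handler drains the ring with vu_queue_pop(), reads elem->out_sg and/or
 * fills elem->in_sg, completes each element with vu_queue_push() and then
 * sends one interrupt with vu_queue_notify().  Elements are allocated by
 * the library and must be freed by the caller.
 */
#if 0
static void
example_drain_queue(VuDev *dev, int qidx)
{
    VuVirtq *vq = vu_get_queue(dev, qidx);
    VuVirtqElement *elem;

    while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
        size_t written = 0;
        unsigned int i;

        /* device-specific work: here we just zero the writable buffers */
        for (i = 0; i < elem->in_num; i++) {
            memset(elem->in_sg[i].iov_base, 0, elem->in_sg[i].iov_len);
            written += elem->in_sg[i].iov_len;
        }

        vu_queue_push(dev, vq, elem, written);  /* bytes written to in_sg */
        free(elem);
    }
    vu_queue_notify(dev, vq);
}
#endif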