qemu/contrib/libvhost-user/libvhost-user.c
   1/*
   2 * Vhost User library
   3 *
   4 * Copyright IBM, Corp. 2007
   5 * Copyright (c) 2016 Red Hat, Inc.
   6 *
   7 * Authors:
   8 *  Anthony Liguori <aliguori@us.ibm.com>
   9 *  Marc-André Lureau <mlureau@redhat.com>
  10 *  Victor Kaplansky <victork@redhat.com>
  11 *
  12 * This work is licensed under the terms of the GNU GPL, version 2 or
  13 * later.  See the COPYING file in the top-level directory.
  14 */
  15
  16/* this code avoids GLib dependency */
  17#include <stdlib.h>
  18#include <stdio.h>
  19#include <unistd.h>
  20#include <stdarg.h>
  21#include <errno.h>
  22#include <string.h>
  23#include <assert.h>
  24#include <inttypes.h>
  25#include <sys/types.h>
  26#include <sys/socket.h>
  27#include <sys/eventfd.h>
  28#include <sys/mman.h>
  29#include "qemu/compiler.h"
  30
  31#if defined(__linux__)
  32#include <sys/syscall.h>
  33#include <fcntl.h>
  34#include <sys/ioctl.h>
  35#include <linux/vhost.h>
  36
  37#ifdef __NR_userfaultfd
  38#include <linux/userfaultfd.h>
  39#endif
  40
  41#endif
  42
  43#include "qemu/atomic.h"
  44#include "qemu/osdep.h"
  45#include "qemu/memfd.h"
  46
  47#include "libvhost-user.h"
  48
  49/* usually provided by GLib */
  50#ifndef MIN
  51#define MIN(x, y) ({                            \
  52            typeof(x) _min1 = (x);              \
  53            typeof(y) _min2 = (y);              \
  54            (void) (&_min1 == &_min2);          \
  55            _min1 < _min2 ? _min1 : _min2; })
  56#endif
  57
  58/* Round number down to multiple */
  59#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
  60
  61/* Round number up to multiple */
  62#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
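/*
 * For example, ALIGN_DOWN(13, 4) == 12 and ALIGN_UP(13, 4) == 16, while
 * ALIGN_UP(12, 4) == 12; the multiple does not need to be a power of two.
 */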
  63
  64/* Align each region to cache line size in inflight buffer */
  65#define INFLIGHT_ALIGNMENT 64
  66
  67/* The version of inflight buffer */
  68#define INFLIGHT_VERSION 1
  69
  70#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
  71
  72/* The version of the protocol we support */
  73#define VHOST_USER_VERSION 1
  74#define LIBVHOST_USER_DEBUG 0
  75
  76#define DPRINT(...)                             \
  77    do {                                        \
  78        if (LIBVHOST_USER_DEBUG) {              \
  79            fprintf(stderr, __VA_ARGS__);        \
  80        }                                       \
  81    } while (0)
  82
  83static inline
  84bool has_feature(uint64_t features, unsigned int fbit)
  85{
  86    assert(fbit < 64);
  87    return !!(features & (1ULL << fbit));
  88}
  89
  90static inline
  91bool vu_has_feature(VuDev *dev,
  92                    unsigned int fbit)
  93{
  94    return has_feature(dev->features, fbit);
  95}
  96
  97static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
  98{
  99    return has_feature(dev->protocol_features, fbit);
 100}
 101
 102static const char *
 103vu_request_to_string(unsigned int req)
 104{
 105#define REQ(req) [req] = #req
 106    static const char *vu_request_str[] = {
 107        REQ(VHOST_USER_NONE),
 108        REQ(VHOST_USER_GET_FEATURES),
 109        REQ(VHOST_USER_SET_FEATURES),
 110        REQ(VHOST_USER_SET_OWNER),
 111        REQ(VHOST_USER_RESET_OWNER),
 112        REQ(VHOST_USER_SET_MEM_TABLE),
 113        REQ(VHOST_USER_SET_LOG_BASE),
 114        REQ(VHOST_USER_SET_LOG_FD),
 115        REQ(VHOST_USER_SET_VRING_NUM),
 116        REQ(VHOST_USER_SET_VRING_ADDR),
 117        REQ(VHOST_USER_SET_VRING_BASE),
 118        REQ(VHOST_USER_GET_VRING_BASE),
 119        REQ(VHOST_USER_SET_VRING_KICK),
 120        REQ(VHOST_USER_SET_VRING_CALL),
 121        REQ(VHOST_USER_SET_VRING_ERR),
 122        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
 123        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
 124        REQ(VHOST_USER_GET_QUEUE_NUM),
 125        REQ(VHOST_USER_SET_VRING_ENABLE),
 126        REQ(VHOST_USER_SEND_RARP),
 127        REQ(VHOST_USER_NET_SET_MTU),
 128        REQ(VHOST_USER_SET_SLAVE_REQ_FD),
 129        REQ(VHOST_USER_IOTLB_MSG),
 130        REQ(VHOST_USER_SET_VRING_ENDIAN),
 131        REQ(VHOST_USER_GET_CONFIG),
 132        REQ(VHOST_USER_SET_CONFIG),
 133        REQ(VHOST_USER_POSTCOPY_ADVISE),
 134        REQ(VHOST_USER_POSTCOPY_LISTEN),
 135        REQ(VHOST_USER_POSTCOPY_END),
 136        REQ(VHOST_USER_GET_INFLIGHT_FD),
 137        REQ(VHOST_USER_SET_INFLIGHT_FD),
 138        REQ(VHOST_USER_GPU_SET_SOCKET),
 139        REQ(VHOST_USER_VRING_KICK),
 140        REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
 141        REQ(VHOST_USER_ADD_MEM_REG),
 142        REQ(VHOST_USER_REM_MEM_REG),
 143        REQ(VHOST_USER_MAX),
 144    };
 145#undef REQ
 146
 147    if (req < VHOST_USER_MAX) {
 148        return vu_request_str[req];
 149    } else {
 150        return "unknown";
 151    }
 152}
 153
 154static void
 155vu_panic(VuDev *dev, const char *msg, ...)
 156{
 157    char *buf = NULL;
 158    va_list ap;
 159
 160    va_start(ap, msg);
 161    if (vasprintf(&buf, msg, ap) < 0) {
 162        buf = NULL;
 163    }
 164    va_end(ap);
 165
 166    dev->broken = true;
 167    dev->panic(dev, buf);
 168    free(buf);
 169
 170    /*
 171     * FIXME:
 172     * find a way to call virtio_error, or perhaps close the connection?
 173     */
 174}
 175
 176/* Translate guest physical address to our virtual address.  */
 177void *
 178vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
 179{
 180    int i;
 181
 182    if (*plen == 0) {
 183        return NULL;
 184    }
 185
 186    /* Find matching memory region.  */
 187    for (i = 0; i < dev->nregions; i++) {
 188        VuDevRegion *r = &dev->regions[i];
 189
 190        if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
 191            if ((guest_addr + *plen) > (r->gpa + r->size)) {
 192                *plen = r->gpa + r->size - guest_addr;
 193            }
 194            return (void *)(uintptr_t)
 195                guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
 196        }
 197    }
 198
 199    return NULL;
 200}
 201
 202/* Translate qemu virtual address to our virtual address.  */
 203static void *
 204qva_to_va(VuDev *dev, uint64_t qemu_addr)
 205{
 206    int i;
 207
 208    /* Find matching memory region.  */
 209    for (i = 0; i < dev->nregions; i++) {
 210        VuDevRegion *r = &dev->regions[i];
 211
 212        if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
 213            return (void *)(uintptr_t)
 214                qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
 215        }
 216    }
 217
 218    return NULL;
 219}
 220
 221static void
 222vmsg_close_fds(VhostUserMsg *vmsg)
 223{
 224    int i;
 225
 226    for (i = 0; i < vmsg->fd_num; i++) {
 227        close(vmsg->fds[i]);
 228    }
 229}
 230
 231/* Set reply payload.u64 and clear request flags and fd_num */
 232static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
 233{
 234    vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
 235    vmsg->size = sizeof(vmsg->payload.u64);
 236    vmsg->payload.u64 = val;
 237    vmsg->fd_num = 0;
 238}
 239
 240/* A test to see if we have userfault available */
 241static bool
 242have_userfault(void)
 243{
 244#if defined(__linux__) && defined(__NR_userfaultfd) &&\
 245        defined(UFFD_FEATURE_MISSING_SHMEM) &&\
 246        defined(UFFD_FEATURE_MISSING_HUGETLBFS)
  247    /* Now test that the kernel we're running on really has the features */
 248    int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
 249    struct uffdio_api api_struct;
 250    if (ufd < 0) {
 251        return false;
 252    }
 253
 254    api_struct.api = UFFD_API;
 255    api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
 256                          UFFD_FEATURE_MISSING_HUGETLBFS;
 257    if (ioctl(ufd, UFFDIO_API, &api_struct)) {
 258        close(ufd);
 259        return false;
 260    }
 261    close(ufd);
 262    return true;
 263
 264#else
 265    return false;
 266#endif
 267}
 268
 269static bool
 270vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 271{
 272    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
 273    struct iovec iov = {
 274        .iov_base = (char *)vmsg,
 275        .iov_len = VHOST_USER_HDR_SIZE,
 276    };
 277    struct msghdr msg = {
 278        .msg_iov = &iov,
 279        .msg_iovlen = 1,
 280        .msg_control = control,
 281        .msg_controllen = sizeof(control),
 282    };
 283    size_t fd_size;
 284    struct cmsghdr *cmsg;
 285    int rc;
 286
 287    do {
 288        rc = recvmsg(conn_fd, &msg, 0);
 289    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 290
 291    if (rc < 0) {
 292        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
 293        return false;
 294    }
 295
 296    vmsg->fd_num = 0;
 297    for (cmsg = CMSG_FIRSTHDR(&msg);
 298         cmsg != NULL;
 299         cmsg = CMSG_NXTHDR(&msg, cmsg))
 300    {
 301        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
 302            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
 303            vmsg->fd_num = fd_size / sizeof(int);
 304            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
 305            break;
 306        }
 307    }
 308
 309    if (vmsg->size > sizeof(vmsg->payload)) {
 310        vu_panic(dev,
 311                 "Error: too big message request: %d, size: vmsg->size: %u, "
 312                 "while sizeof(vmsg->payload) = %zu\n",
 313                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
 314        goto fail;
 315    }
 316
 317    if (vmsg->size) {
 318        do {
 319            rc = read(conn_fd, &vmsg->payload, vmsg->size);
 320        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 321
 322        if (rc <= 0) {
 323            vu_panic(dev, "Error while reading: %s", strerror(errno));
 324            goto fail;
 325        }
 326
 327        assert(rc == vmsg->size);
 328    }
 329
 330    return true;
 331
 332fail:
 333    vmsg_close_fds(vmsg);
 334
 335    return false;
 336}
 337
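/*
 * Send a vhost-user message on conn_fd: the header goes out via sendmsg()
 * so that vmsg->fds (if any) can be attached as SCM_RIGHTS ancillary data;
 * the payload, if present, follows with write(), taken from vmsg->data
 * when set or from the message body otherwise.
 */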
 338static bool
 339vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 340{
 341    int rc;
 342    uint8_t *p = (uint8_t *)vmsg;
 343    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
 344    struct iovec iov = {
 345        .iov_base = (char *)vmsg,
 346        .iov_len = VHOST_USER_HDR_SIZE,
 347    };
 348    struct msghdr msg = {
 349        .msg_iov = &iov,
 350        .msg_iovlen = 1,
 351        .msg_control = control,
 352    };
 353    struct cmsghdr *cmsg;
 354
 355    memset(control, 0, sizeof(control));
 356    assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
 357    if (vmsg->fd_num > 0) {
 358        size_t fdsize = vmsg->fd_num * sizeof(int);
 359        msg.msg_controllen = CMSG_SPACE(fdsize);
 360        cmsg = CMSG_FIRSTHDR(&msg);
 361        cmsg->cmsg_len = CMSG_LEN(fdsize);
 362        cmsg->cmsg_level = SOL_SOCKET;
 363        cmsg->cmsg_type = SCM_RIGHTS;
 364        memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
 365    } else {
 366        msg.msg_controllen = 0;
 367    }
 368
 369    do {
 370        rc = sendmsg(conn_fd, &msg, 0);
 371    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 372
 373    if (vmsg->size) {
 374        do {
 375            if (vmsg->data) {
 376                rc = write(conn_fd, vmsg->data, vmsg->size);
 377            } else {
 378                rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
 379            }
 380        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 381    }
 382
 383    if (rc <= 0) {
 384        vu_panic(dev, "Error while writing: %s", strerror(errno));
 385        return false;
 386    }
 387
 388    return true;
 389}
 390
 391static bool
 392vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 393{
 394    /* Set the version in the flags when sending the reply */
 395    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
 396    vmsg->flags |= VHOST_USER_VERSION;
 397    vmsg->flags |= VHOST_USER_REPLY_MASK;
 398
 399    return vu_message_write(dev, conn_fd, vmsg);
 400}
 401
 402/*
 403 * Processes a reply on the slave channel.
 404 * Entered with slave_mutex held and releases it before exit.
 405 * Returns true on success.
 406 */
 407static bool
 408vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
 409{
 410    VhostUserMsg msg_reply;
 411    bool result = false;
 412
 413    if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
 414        result = true;
 415        goto out;
 416    }
 417
 418    if (!vu_message_read(dev, dev->slave_fd, &msg_reply)) {
 419        goto out;
 420    }
 421
 422    if (msg_reply.request != vmsg->request) {
 423        DPRINT("Received unexpected msg type. Expected %d received %d",
 424               vmsg->request, msg_reply.request);
 425        goto out;
 426    }
 427
 428    result = msg_reply.payload.u64 == 0;
 429
 430out:
 431    pthread_mutex_unlock(&dev->slave_mutex);
 432    return result;
 433}
 434
 435/* Kick the log_call_fd if required. */
 436static void
 437vu_log_kick(VuDev *dev)
 438{
 439    if (dev->log_call_fd != -1) {
 440        DPRINT("Kicking the QEMU's log...\n");
 441        if (eventfd_write(dev->log_call_fd, 1) < 0) {
 442            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
 443        }
 444    }
 445}
 446
 447static void
 448vu_log_page(uint8_t *log_table, uint64_t page)
 449{
 450    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
 451    atomic_or(&log_table[page / 8], 1 << (page % 8));
 452}
 453
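/*
 * Mark the guest physical range [address, address + length) dirty in the
 * shared log bitmap (one bit per VHOST_LOG_PAGE) and notify QEMU via the
 * log eventfd.  Does nothing unless VHOST_F_LOG_ALL has been negotiated
 * and a log table is mapped.
 */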
 454static void
 455vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
 456{
 457    uint64_t page;
 458
 459    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
 460        !dev->log_table || !length) {
 461        return;
 462    }
 463
 464    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
 465
 466    page = address / VHOST_LOG_PAGE;
 467    while (page * VHOST_LOG_PAGE < address + length) {
 468        vu_log_page(dev->log_table, page);
 469        page += 1;
 470    }
 471
 472    vu_log_kick(dev);
 473}
 474
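/*
 * Watch callback for a virtqueue's kick eventfd: drain the eventfd and
 * invoke the queue's handler, removing the watch if the read fails.
 */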
 475static void
 476vu_kick_cb(VuDev *dev, int condition, void *data)
 477{
 478    int index = (intptr_t)data;
 479    VuVirtq *vq = &dev->vq[index];
 480    int sock = vq->kick_fd;
 481    eventfd_t kick_data;
 482    ssize_t rc;
 483
 484    rc = eventfd_read(sock, &kick_data);
 485    if (rc == -1) {
 486        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
 487        dev->remove_watch(dev, dev->vq[index].kick_fd);
 488    } else {
 489        DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
 490               kick_data, vq->handler, index);
 491        if (vq->handler) {
 492            vq->handler(dev, index);
 493        }
 494    }
 495}
 496
 497static bool
 498vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 499{
 500    vmsg->payload.u64 =
 501        /*
 502         * The following VIRTIO feature bits are supported by our virtqueue
 503         * implementation:
 504         */
 505        1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
 506        1ULL << VIRTIO_RING_F_INDIRECT_DESC |
 507        1ULL << VIRTIO_RING_F_EVENT_IDX |
 508        1ULL << VIRTIO_F_VERSION_1 |
 509
 510        /* vhost-user feature bits */
 511        1ULL << VHOST_F_LOG_ALL |
 512        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
 513
 514    if (dev->iface->get_features) {
 515        vmsg->payload.u64 |= dev->iface->get_features(dev);
 516    }
 517
 518    vmsg->size = sizeof(vmsg->payload.u64);
 519    vmsg->fd_num = 0;
 520
 521    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 522
 523    return true;
 524}
 525
 526static void
 527vu_set_enable_all_rings(VuDev *dev, bool enabled)
 528{
 529    uint16_t i;
 530
 531    for (i = 0; i < dev->max_queues; i++) {
 532        dev->vq[i].enable = enabled;
 533    }
 534}
 535
 536static bool
 537vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 538{
 539    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 540
 541    dev->features = vmsg->payload.u64;
 542
  543    if (!vu_has_feature(dev, VHOST_USER_F_PROTOCOL_FEATURES)) {
 544        vu_set_enable_all_rings(dev, true);
 545    }
 546
 547    if (dev->iface->set_features) {
 548        dev->iface->set_features(dev, dev->features);
 549    }
 550
 551    return false;
 552}
 553
 554static bool
 555vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
 556{
 557    return false;
 558}
 559
 560static void
 561vu_close_log(VuDev *dev)
 562{
 563    if (dev->log_table) {
 564        if (munmap(dev->log_table, dev->log_size) != 0) {
 565            perror("close log munmap() error");
 566        }
 567
 568        dev->log_table = NULL;
 569    }
 570    if (dev->log_call_fd != -1) {
 571        close(dev->log_call_fd);
 572        dev->log_call_fd = -1;
 573    }
 574}
 575
 576static bool
 577vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
 578{
 579    vu_set_enable_all_rings(dev, false);
 580
 581    return false;
 582}
 583
 584static bool
 585map_ring(VuDev *dev, VuVirtq *vq)
 586{
 587    vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
 588    vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
 589    vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
 590
 591    DPRINT("Setting virtq addresses:\n");
 592    DPRINT("    vring_desc  at %p\n", vq->vring.desc);
 593    DPRINT("    vring_used  at %p\n", vq->vring.used);
 594    DPRINT("    vring_avail at %p\n", vq->vring.avail);
 595
 596    return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
 597}
 598
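/*
 * Register every mapped memory region with the postcopy userfaultfd so
 * that accesses to missing pages generate userfault events, then re-enable
 * read/write access with mprotect() (the regions were mapped PROT_NONE
 * until now).  Returns false if any region cannot be registered.
 */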
 599static bool
 600generate_faults(VuDev *dev) {
 601    int i;
 602    for (i = 0; i < dev->nregions; i++) {
 603        VuDevRegion *dev_region = &dev->regions[i];
 604        int ret;
 605#ifdef UFFDIO_REGISTER
 606        /*
  607         * We should already have an open ufd. Mark each memory
  608         * range for userfault handling.
  609         * Discard any mapping we have here; note I can't use MADV_REMOVE
  610         * or fallocate to make the hole since I don't want to lose
  611         * data that's already arrived in the shared process.
  612         * TODO: How to handle hugepages?
 613         */
 614        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
 615                      dev_region->size + dev_region->mmap_offset,
 616                      MADV_DONTNEED);
 617        if (ret) {
 618            fprintf(stderr,
 619                    "%s: Failed to madvise(DONTNEED) region %d: %s\n",
 620                    __func__, i, strerror(errno));
 621        }
 622        /*
  623         * Turn off transparent hugepages so we don't lose wakeups
  624         * in neighbouring pages.
  625         * TODO: Turn this back on later.
 626         */
 627        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
 628                      dev_region->size + dev_region->mmap_offset,
 629                      MADV_NOHUGEPAGE);
 630        if (ret) {
 631            /*
 632             * Note: This can happen legally on kernels that are configured
 633             * without madvise'able hugepages
 634             */
 635            fprintf(stderr,
 636                    "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
 637                    __func__, i, strerror(errno));
 638        }
 639        struct uffdio_register reg_struct;
 640        reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
 641        reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
 642        reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
 643
 644        if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
 645            vu_panic(dev, "%s: Failed to userfault region %d "
 646                          "@%p + size:%zx offset: %zx: (ufd=%d)%s\n",
 647                     __func__, i,
 648                     dev_region->mmap_addr,
 649                     dev_region->size, dev_region->mmap_offset,
 650                     dev->postcopy_ufd, strerror(errno));
 651            return false;
 652        }
 653        if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
 654            vu_panic(dev, "%s Region (%d) doesn't support COPY",
 655                     __func__, i);
 656            return false;
 657        }
 658        DPRINT("%s: region %d: Registered userfault for %"
 659               PRIx64 " + %" PRIx64 "\n", __func__, i,
 660               (uint64_t)reg_struct.range.start,
 661               (uint64_t)reg_struct.range.len);
  662        /* Now that it's registered, we can let the client at it */
 663        if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
 664                     dev_region->size + dev_region->mmap_offset,
 665                     PROT_READ | PROT_WRITE)) {
 666            vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
 667                     i, strerror(errno));
 668            return false;
 669        }
 670        /* TODO: Stash 'zero' support flags somewhere */
 671#endif
 672    }
 673
 674    return true;
 675}
 676
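/*
 * Handle VHOST_USER_ADD_MEM_REG: mmap the region described in the message
 * and append it to dev->regions.  In postcopy mode the region is mapped
 * PROT_NONE and the mmap address is reported back to QEMU so that
 * userfault addresses can be translated; otherwise any existing rings are
 * remapped and a zero reply is requested.
 */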
 677static bool
 678vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
 679    int i;
 680    bool track_ramblocks = dev->postcopy_listening;
 681    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
 682    VuDevRegion *dev_region = &dev->regions[dev->nregions];
 683    void *mmap_addr;
 684
 685    /*
 686     * If we are in postcopy mode and we receive a u64 payload with a 0 value
  687     * we know all the postcopy client bases have been received, and we
 688     * should start generating faults.
 689     */
 690    if (track_ramblocks &&
 691        vmsg->size == sizeof(vmsg->payload.u64) &&
 692        vmsg->payload.u64 == 0) {
 693        (void)generate_faults(dev);
 694        return false;
 695    }
 696
 697    DPRINT("Adding region: %d\n", dev->nregions);
 698    DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 699           msg_region->guest_phys_addr);
 700    DPRINT("    memory_size:     0x%016"PRIx64"\n",
 701           msg_region->memory_size);
 702    DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 703           msg_region->userspace_addr);
 704    DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 705           msg_region->mmap_offset);
 706
 707    dev_region->gpa = msg_region->guest_phys_addr;
 708    dev_region->size = msg_region->memory_size;
 709    dev_region->qva = msg_region->userspace_addr;
 710    dev_region->mmap_offset = msg_region->mmap_offset;
 711
 712    /*
 713     * We don't use offset argument of mmap() since the
 714     * mapped address has to be page aligned, and we use huge
 715     * pages.
 716     */
 717    if (track_ramblocks) {
 718        /*
 719         * In postcopy we're using PROT_NONE here to catch anyone
 720         * accessing it before we userfault.
 721         */
 722        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 723                         PROT_NONE, MAP_SHARED,
 724                         vmsg->fds[0], 0);
 725    } else {
 726        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 727                         PROT_READ | PROT_WRITE, MAP_SHARED, vmsg->fds[0],
 728                         0);
 729    }
 730
 731    if (mmap_addr == MAP_FAILED) {
 732        vu_panic(dev, "region mmap error: %s", strerror(errno));
 733    } else {
 734        dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 735        DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 736               dev_region->mmap_addr);
 737    }
 738
 739    close(vmsg->fds[0]);
 740
 741    if (track_ramblocks) {
 742        /*
 743         * Return the address to QEMU so that it can translate the ufd
 744         * fault addresses back.
 745         */
 746        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
 747                                                 dev_region->mmap_offset);
 748
 749        /* Send the message back to qemu with the addresses filled in. */
 750        vmsg->fd_num = 0;
 751        if (!vu_send_reply(dev, dev->sock, vmsg)) {
 752            vu_panic(dev, "failed to respond to add-mem-region for postcopy");
 753            return false;
 754        }
 755
 756        DPRINT("Successfully added new region in postcopy\n");
 757        dev->nregions++;
 758        return false;
 759
 760    } else {
 761        for (i = 0; i < dev->max_queues; i++) {
 762            if (dev->vq[i].vring.desc) {
 763                if (map_ring(dev, &dev->vq[i])) {
 764                    vu_panic(dev, "remapping queue %d for new memory region",
 765                             i);
 766                }
 767            }
 768        }
 769
 770        DPRINT("Successfully added new region\n");
 771        dev->nregions++;
 772        vmsg_set_reply_u64(vmsg, 0);
 773        return true;
 774    }
 775}
 776
 777static inline bool reg_equal(VuDevRegion *vudev_reg,
 778                             VhostUserMemoryRegion *msg_reg)
 779{
 780    if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
 781        vudev_reg->qva == msg_reg->userspace_addr &&
 782        vudev_reg->size == msg_reg->memory_size) {
 783        return true;
 784    }
 785
 786    return false;
 787}
 788
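/*
 * Handle VHOST_USER_REM_MEM_REG: find the region matching the message by
 * GPA, QVA and size, munmap it, and compact the remaining regions back
 * into dev->regions.
 */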
 789static bool
 790vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
 791    int i, j;
 792    bool found = false;
 793    VuDevRegion shadow_regions[VHOST_USER_MAX_RAM_SLOTS] = {};
 794    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
 795
 796    DPRINT("Removing region:\n");
 797    DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 798           msg_region->guest_phys_addr);
 799    DPRINT("    memory_size:     0x%016"PRIx64"\n",
 800           msg_region->memory_size);
 801    DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 802           msg_region->userspace_addr);
 803    DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 804           msg_region->mmap_offset);
 805
 806    for (i = 0, j = 0; i < dev->nregions; i++) {
 807        if (!reg_equal(&dev->regions[i], msg_region)) {
 808            shadow_regions[j].gpa = dev->regions[i].gpa;
 809            shadow_regions[j].size = dev->regions[i].size;
 810            shadow_regions[j].qva = dev->regions[i].qva;
 811            shadow_regions[j].mmap_offset = dev->regions[i].mmap_offset;
 812            j++;
 813        } else {
 814            found = true;
 815            VuDevRegion *r = &dev->regions[i];
 816            void *m = (void *) (uintptr_t) r->mmap_addr;
 817
 818            if (m) {
 819                munmap(m, r->size + r->mmap_offset);
 820            }
 821        }
 822    }
 823
 824    if (found) {
 825        memcpy(dev->regions, shadow_regions,
 826               sizeof(VuDevRegion) * VHOST_USER_MAX_RAM_SLOTS);
 827        DPRINT("Successfully removed a region\n");
 828        dev->nregions--;
 829        vmsg_set_reply_u64(vmsg, 0);
 830    } else {
 831        vu_panic(dev, "Specified region not found\n");
 832    }
 833
 834    return true;
 835}
 836
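/*
 * Postcopy variant of SET_MEM_TABLE: map each region PROT_NONE, report the
 * mmap addresses back to QEMU, wait for its acknowledgement that the fault
 * handler is in place, and only then register the regions with the
 * userfaultfd via generate_faults().
 */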
 837static bool
 838vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
 839{
 840    int i;
 841    VhostUserMemory m = vmsg->payload.memory, *memory = &m;
 842    dev->nregions = memory->nregions;
 843
 844    DPRINT("Nregions: %d\n", memory->nregions);
 845    for (i = 0; i < dev->nregions; i++) {
 846        void *mmap_addr;
 847        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 848        VuDevRegion *dev_region = &dev->regions[i];
 849
 850        DPRINT("Region %d\n", i);
 851        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 852               msg_region->guest_phys_addr);
 853        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 854               msg_region->memory_size);
 855        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 856               msg_region->userspace_addr);
 857        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 858               msg_region->mmap_offset);
 859
 860        dev_region->gpa = msg_region->guest_phys_addr;
 861        dev_region->size = msg_region->memory_size;
 862        dev_region->qva = msg_region->userspace_addr;
 863        dev_region->mmap_offset = msg_region->mmap_offset;
 864
 865        /* We don't use offset argument of mmap() since the
 866         * mapped address has to be page aligned, and we use huge
 867         * pages.
 868         * In postcopy we're using PROT_NONE here to catch anyone
 869         * accessing it before we userfault
 870         */
 871        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 872                         PROT_NONE, MAP_SHARED,
 873                         vmsg->fds[i], 0);
 874
 875        if (mmap_addr == MAP_FAILED) {
 876            vu_panic(dev, "region mmap error: %s", strerror(errno));
 877        } else {
 878            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 879            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 880                   dev_region->mmap_addr);
 881        }
 882
 883        /* Return the address to QEMU so that it can translate the ufd
 884         * fault addresses back.
 885         */
 886        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
 887                                                 dev_region->mmap_offset);
 888        close(vmsg->fds[i]);
 889    }
 890
 891    /* Send the message back to qemu with the addresses filled in */
 892    vmsg->fd_num = 0;
 893    if (!vu_send_reply(dev, dev->sock, vmsg)) {
 894        vu_panic(dev, "failed to respond to set-mem-table for postcopy");
 895        return false;
 896    }
 897
 898    /* Wait for QEMU to confirm that it's registered the handler for the
 899     * faults.
 900     */
 901    if (!vu_message_read(dev, dev->sock, vmsg) ||
 902        vmsg->size != sizeof(vmsg->payload.u64) ||
 903        vmsg->payload.u64 != 0) {
 904        vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
 905        return false;
 906    }
 907
 908    /* OK, now we can go and register the memory and generate faults */
 909    (void)generate_faults(dev);
 910
 911    return false;
 912}
 913
 914static bool
 915vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
 916{
 917    int i;
 918    VhostUserMemory m = vmsg->payload.memory, *memory = &m;
 919
 920    for (i = 0; i < dev->nregions; i++) {
 921        VuDevRegion *r = &dev->regions[i];
 922        void *m = (void *) (uintptr_t) r->mmap_addr;
 923
 924        if (m) {
 925            munmap(m, r->size + r->mmap_offset);
 926        }
 927    }
 928    dev->nregions = memory->nregions;
 929
 930    if (dev->postcopy_listening) {
 931        return vu_set_mem_table_exec_postcopy(dev, vmsg);
 932    }
 933
 934    DPRINT("Nregions: %d\n", memory->nregions);
 935    for (i = 0; i < dev->nregions; i++) {
 936        void *mmap_addr;
 937        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 938        VuDevRegion *dev_region = &dev->regions[i];
 939
 940        DPRINT("Region %d\n", i);
 941        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 942               msg_region->guest_phys_addr);
 943        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 944               msg_region->memory_size);
 945        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 946               msg_region->userspace_addr);
 947        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 948               msg_region->mmap_offset);
 949
 950        dev_region->gpa = msg_region->guest_phys_addr;
 951        dev_region->size = msg_region->memory_size;
 952        dev_region->qva = msg_region->userspace_addr;
 953        dev_region->mmap_offset = msg_region->mmap_offset;
 954
 955        /* We don't use offset argument of mmap() since the
 956         * mapped address has to be page aligned, and we use huge
 957         * pages.  */
 958        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 959                         PROT_READ | PROT_WRITE, MAP_SHARED,
 960                         vmsg->fds[i], 0);
 961
 962        if (mmap_addr == MAP_FAILED) {
 963            vu_panic(dev, "region mmap error: %s", strerror(errno));
 964        } else {
 965            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 966            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 967                   dev_region->mmap_addr);
 968        }
 969
 970        close(vmsg->fds[i]);
 971    }
 972
 973    for (i = 0; i < dev->max_queues; i++) {
 974        if (dev->vq[i].vring.desc) {
 975            if (map_ring(dev, &dev->vq[i])) {
  976                vu_panic(dev, "remapping queue %d during setmemtable", i);
 977            }
 978        }
 979    }
 980
 981    return false;
 982}
 983
 984static bool
 985vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 986{
 987    int fd;
 988    uint64_t log_mmap_size, log_mmap_offset;
 989    void *rc;
 990
 991    if (vmsg->fd_num != 1 ||
 992        vmsg->size != sizeof(vmsg->payload.log)) {
 993        vu_panic(dev, "Invalid log_base message");
 994        return true;
 995    }
 996
 997    fd = vmsg->fds[0];
 998    log_mmap_offset = vmsg->payload.log.mmap_offset;
 999    log_mmap_size = vmsg->payload.log.mmap_size;
1000    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1001    DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1002
1003    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1004              log_mmap_offset);
1005    close(fd);
1006    if (rc == MAP_FAILED) {
1007        perror("log mmap error");
1008    }
1009
1010    if (dev->log_table) {
1011        munmap(dev->log_table, dev->log_size);
1012    }
1013    dev->log_table = rc;
1014    dev->log_size = log_mmap_size;
1015
1016    vmsg->size = sizeof(vmsg->payload.u64);
1017    vmsg->fd_num = 0;
1018
1019    return true;
1020}
1021
1022static bool
1023vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1024{
1025    if (vmsg->fd_num != 1) {
1026        vu_panic(dev, "Invalid log_fd message");
1027        return false;
1028    }
1029
1030    if (dev->log_call_fd != -1) {
1031        close(dev->log_call_fd);
1032    }
1033    dev->log_call_fd = vmsg->fds[0];
1034    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1035
1036    return false;
1037}
1038
1039static bool
1040vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1041{
1042    unsigned int index = vmsg->payload.state.index;
1043    unsigned int num = vmsg->payload.state.num;
1044
1045    DPRINT("State.index: %d\n", index);
1046    DPRINT("State.num:   %d\n", num);
1047    dev->vq[index].vring.num = num;
1048
1049    return false;
1050}
1051
1052static bool
1053vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1054{
1055    struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1056    unsigned int index = vra->index;
1057    VuVirtq *vq = &dev->vq[index];
1058
1059    DPRINT("vhost_vring_addr:\n");
1060    DPRINT("    index:  %d\n", vra->index);
1061    DPRINT("    flags:  %d\n", vra->flags);
1062    DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", vra->desc_user_addr);
1063    DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", vra->used_user_addr);
1064    DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", vra->avail_user_addr);
1065    DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", vra->log_guest_addr);
1066
1067    vq->vra = *vra;
1068    vq->vring.flags = vra->flags;
1069    vq->vring.log_guest_addr = vra->log_guest_addr;
1070
1071
1072    if (map_ring(dev, vq)) {
1073        vu_panic(dev, "Invalid vring_addr message");
1074        return false;
1075    }
1076
1077    vq->used_idx = vq->vring.used->idx;
1078
1079    if (vq->last_avail_idx != vq->used_idx) {
1080        bool resume = dev->iface->queue_is_processed_in_order &&
1081            dev->iface->queue_is_processed_in_order(dev, index);
1082
1083        DPRINT("Last avail index != used index: %u != %u%s\n",
1084               vq->last_avail_idx, vq->used_idx,
1085               resume ? ", resuming" : "");
1086
1087        if (resume) {
1088            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1089        }
1090    }
1091
1092    return false;
1093}
1094
1095static bool
1096vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1097{
1098    unsigned int index = vmsg->payload.state.index;
1099    unsigned int num = vmsg->payload.state.num;
1100
1101    DPRINT("State.index: %d\n", index);
1102    DPRINT("State.num:   %d\n", num);
1103    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1104
1105    return false;
1106}
1107
1108static bool
1109vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1110{
1111    unsigned int index = vmsg->payload.state.index;
1112
1113    DPRINT("State.index: %d\n", index);
1114    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1115    vmsg->size = sizeof(vmsg->payload.state);
1116
1117    dev->vq[index].started = false;
1118    if (dev->iface->queue_set_started) {
1119        dev->iface->queue_set_started(dev, index, false);
1120    }
1121
1122    if (dev->vq[index].call_fd != -1) {
1123        close(dev->vq[index].call_fd);
1124        dev->vq[index].call_fd = -1;
1125    }
1126    if (dev->vq[index].kick_fd != -1) {
1127        dev->remove_watch(dev, dev->vq[index].kick_fd);
1128        close(dev->vq[index].kick_fd);
1129        dev->vq[index].kick_fd = -1;
1130    }
1131
1132    return true;
1133}
1134
1135static bool
1136vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1137{
1138    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1139    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1140
1141    if (index >= dev->max_queues) {
1142        vmsg_close_fds(vmsg);
1143        vu_panic(dev, "Invalid queue index: %u", index);
1144        return false;
1145    }
1146
1147    if (nofd) {
1148        vmsg_close_fds(vmsg);
1149        return true;
1150    }
1151
1152    if (vmsg->fd_num != 1) {
1153        vmsg_close_fds(vmsg);
1154        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1155        return false;
1156    }
1157
1158    return true;
1159}
1160
1161static int
1162inflight_desc_compare(const void *a, const void *b)
1163{
1164    VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1165                        *desc1 = (VuVirtqInflightDesc *)b;
1166
1167    if (desc1->counter > desc0->counter &&
1168        (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1169        return 1;
1170    }
1171
1172    return -1;
1173}
1174
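/*
 * After a reconnect, scan the shared inflight buffer for descriptors that
 * were submitted but never completed, build vq->resubmit_list sorted by
 * submission counter, and kick the queue so processing resumes.  Only
 * active when VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD has been negotiated.
 * Returns 0 on success, -1 on failure.
 */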
1175static int
1176vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1177{
1178    int i = 0;
1179
1180    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1181        return 0;
1182    }
1183
1184    if (unlikely(!vq->inflight)) {
1185        return -1;
1186    }
1187
1188    if (unlikely(!vq->inflight->version)) {
1189        /* initialize the buffer */
1190        vq->inflight->version = INFLIGHT_VERSION;
1191        return 0;
1192    }
1193
1194    vq->used_idx = vq->vring.used->idx;
1195    vq->resubmit_num = 0;
1196    vq->resubmit_list = NULL;
1197    vq->counter = 0;
1198
1199    if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1200        vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1201
1202        barrier();
1203
1204        vq->inflight->used_idx = vq->used_idx;
1205    }
1206
1207    for (i = 0; i < vq->inflight->desc_num; i++) {
1208        if (vq->inflight->desc[i].inflight == 1) {
1209            vq->inuse++;
1210        }
1211    }
1212
1213    vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1214
1215    if (vq->inuse) {
1216        vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1217        if (!vq->resubmit_list) {
1218            return -1;
1219        }
1220
1221        for (i = 0; i < vq->inflight->desc_num; i++) {
1222            if (vq->inflight->desc[i].inflight) {
1223                vq->resubmit_list[vq->resubmit_num].index = i;
1224                vq->resubmit_list[vq->resubmit_num].counter =
1225                                        vq->inflight->desc[i].counter;
1226                vq->resubmit_num++;
1227            }
1228        }
1229
1230        if (vq->resubmit_num > 1) {
1231            qsort(vq->resubmit_list, vq->resubmit_num,
1232                  sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1233        }
1234        vq->counter = vq->resubmit_list[0].counter + 1;
1235    }
1236
1237    /* in case of I/O hang after reconnecting */
1238    if (eventfd_write(vq->kick_fd, 1)) {
1239        return -1;
1240    }
1241
1242    return 0;
1243}
1244
1245static bool
1246vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1247{
1248    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1249    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1250
1251    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1252
1253    if (!vu_check_queue_msg_file(dev, vmsg)) {
1254        return false;
1255    }
1256
1257    if (dev->vq[index].kick_fd != -1) {
1258        dev->remove_watch(dev, dev->vq[index].kick_fd);
1259        close(dev->vq[index].kick_fd);
1260        dev->vq[index].kick_fd = -1;
1261    }
1262
1263    dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1264    DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1265
1266    dev->vq[index].started = true;
1267    if (dev->iface->queue_set_started) {
1268        dev->iface->queue_set_started(dev, index, true);
1269    }
1270
1271    if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1272        dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1273                       vu_kick_cb, (void *)(long)index);
1274
1275        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1276               dev->vq[index].kick_fd, index);
1277    }
1278
1279    if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1280        vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1281    }
1282
1283    return false;
1284}
1285
1286void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1287                          vu_queue_handler_cb handler)
1288{
1289    int qidx = vq - dev->vq;
1290
1291    vq->handler = handler;
1292    if (vq->kick_fd >= 0) {
1293        if (handler) {
1294            dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1295                           vu_kick_cb, (void *)(long)qidx);
1296        } else {
1297            dev->remove_watch(dev, vq->kick_fd);
1298        }
1299    }
1300}
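/*
 * Illustrative use of vu_set_queue_handler() from a device backend (the
 * handler below is a hypothetical example, not part of the library):
 *
 *     static void my_queue_handler(VuDev *dev, int qidx)
 *     {
 *         // pop available elements of queue qidx and process them
 *     }
 *
 *     vu_set_queue_handler(dev, vu_get_queue(dev, 0), my_queue_handler);
 */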
1301
1302bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1303                                int size, int offset)
1304{
1305    int qidx = vq - dev->vq;
1306    int fd_num = 0;
1307    VhostUserMsg vmsg = {
1308        .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1309        .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1310        .size = sizeof(vmsg.payload.area),
1311        .payload.area = {
1312            .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1313            .size = size,
1314            .offset = offset,
1315        },
1316    };
1317
1318    if (fd == -1) {
1319        vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1320    } else {
1321        vmsg.fds[fd_num++] = fd;
1322    }
1323
1324    vmsg.fd_num = fd_num;
1325
1326    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
1327        return false;
1328    }
1329
1330    pthread_mutex_lock(&dev->slave_mutex);
1331    if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
1332        pthread_mutex_unlock(&dev->slave_mutex);
1333        return false;
1334    }
1335
1336    /* Also unlocks the slave_mutex */
1337    return vu_process_message_reply(dev, &vmsg);
1338}
1339
1340static bool
1341vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1342{
1343    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1344    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1345
1346    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1347
1348    if (!vu_check_queue_msg_file(dev, vmsg)) {
1349        return false;
1350    }
1351
1352    if (dev->vq[index].call_fd != -1) {
1353        close(dev->vq[index].call_fd);
1354        dev->vq[index].call_fd = -1;
1355    }
1356
1357    dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1358
1359    /* in case of I/O hang after reconnecting */
1360    if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
 1361        return false;
1362    }
1363
1364    DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1365
1366    return false;
1367}
1368
1369static bool
1370vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1371{
1372    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1373    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1374
1375    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1376
1377    if (!vu_check_queue_msg_file(dev, vmsg)) {
1378        return false;
1379    }
1380
1381    if (dev->vq[index].err_fd != -1) {
1382        close(dev->vq[index].err_fd);
1383        dev->vq[index].err_fd = -1;
1384    }
1385
1386    dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1387
1388    return false;
1389}
1390
1391static bool
1392vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1393{
1394    /*
1395     * Note that we support, but intentionally do not set,
1396     * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1397     * a device implementation can return it in its callback
1398     * (get_protocol_features) if it wants to use this for
1399     * simulation, but it is otherwise not desirable (if even
1400     * implemented by the master.)
1401     */
1402    uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1403                        1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1404                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
1405                        1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1406                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
1407                        1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1408                        1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1409
1410    if (have_userfault()) {
1411        features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1412    }
1413
1414    if (dev->iface->get_config && dev->iface->set_config) {
1415        features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1416    }
1417
1418    if (dev->iface->get_protocol_features) {
1419        features |= dev->iface->get_protocol_features(dev);
1420    }
1421
1422    vmsg_set_reply_u64(vmsg, features);
1423    return true;
1424}
1425
1426static bool
1427vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1428{
1429    uint64_t features = vmsg->payload.u64;
1430
1431    DPRINT("u64: 0x%016"PRIx64"\n", features);
1432
1433    dev->protocol_features = vmsg->payload.u64;
1434
1435    if (vu_has_protocol_feature(dev,
1436                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1437        (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
1438         !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1439        /*
1440         * The use case for using messages for kick/call is simulation, to make
1441         * the kick and call synchronous. To actually get that behaviour, both
1442         * of the other features are required.
1443         * Theoretically, one could use only kick messages, or do them without
1444         * having F_REPLY_ACK, but too many (possibly pending) messages on the
 1445         * socket will eventually cause the master to hang. To avoid this in
 1446         * scenarios where that is not desired, enforce that the settings
 1447         * actually enable the simulation case.
1448         */
1449        vu_panic(dev,
1450                 "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
1451        return false;
1452    }
1453
1454    if (dev->iface->set_protocol_features) {
1455        dev->iface->set_protocol_features(dev, features);
1456    }
1457
1458    return false;
1459}
1460
1461static bool
1462vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1463{
1464    vmsg_set_reply_u64(vmsg, dev->max_queues);
1465    return true;
1466}
1467
1468static bool
1469vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1470{
1471    unsigned int index = vmsg->payload.state.index;
1472    unsigned int enable = vmsg->payload.state.num;
1473
1474    DPRINT("State.index: %d\n", index);
1475    DPRINT("State.enable:   %d\n", enable);
1476
1477    if (index >= dev->max_queues) {
1478        vu_panic(dev, "Invalid vring_enable index: %u", index);
1479        return false;
1480    }
1481
1482    dev->vq[index].enable = enable;
1483    return false;
1484}
1485
1486static bool
1487vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1488{
1489    if (vmsg->fd_num != 1) {
1490        vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
1491        return false;
1492    }
1493
1494    if (dev->slave_fd != -1) {
1495        close(dev->slave_fd);
1496    }
1497    dev->slave_fd = vmsg->fds[0];
1498    DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1499
1500    return false;
1501}
1502
1503static bool
1504vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1505{
1506    int ret = -1;
1507
1508    if (dev->iface->get_config) {
1509        ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1510                                     vmsg->payload.config.size);
1511    }
1512
1513    if (ret) {
1514        /* resize to zero to indicate an error to master */
1515        vmsg->size = 0;
1516    }
1517
1518    return true;
1519}
1520
1521static bool
1522vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1523{
1524    int ret = -1;
1525
1526    if (dev->iface->set_config) {
1527        ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1528                                     vmsg->payload.config.offset,
1529                                     vmsg->payload.config.size,
1530                                     vmsg->payload.config.flags);
1531        if (ret) {
1532            vu_panic(dev, "Set virtio configuration space failed");
1533        }
1534    }
1535
1536    return false;
1537}
1538
1539static bool
1540vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1541{
1542    dev->postcopy_ufd = -1;
1543#ifdef UFFDIO_API
1544    struct uffdio_api api_struct;
1545
1546    dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1547    vmsg->size = 0;
1548#endif
1549
1550    if (dev->postcopy_ufd == -1) {
1551        vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1552        goto out;
1553    }
1554
1555#ifdef UFFDIO_API
1556    api_struct.api = UFFD_API;
1557    api_struct.features = 0;
1558    if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1559        vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1560        close(dev->postcopy_ufd);
1561        dev->postcopy_ufd = -1;
1562        goto out;
1563    }
1564    /* TODO: Stash feature flags somewhere */
1565#endif
1566
1567out:
1568    /* Return a ufd to the QEMU */
1569    vmsg->fd_num = 1;
1570    vmsg->fds[0] = dev->postcopy_ufd;
1571    return true; /* = send a reply */
1572}
1573
1574static bool
1575vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1576{
1577    if (dev->nregions) {
1578        vu_panic(dev, "Regions already registered at postcopy-listen");
1579        vmsg_set_reply_u64(vmsg, -1);
1580        return true;
1581    }
1582    dev->postcopy_listening = true;
1583
1584    vmsg_set_reply_u64(vmsg, 0);
1585    return true;
1586}
1587
1588static bool
1589vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1590{
1591    DPRINT("%s: Entry\n", __func__);
1592    dev->postcopy_listening = false;
1593    if (dev->postcopy_ufd > 0) {
1594        close(dev->postcopy_ufd);
1595        dev->postcopy_ufd = -1;
1596        DPRINT("%s: Done close\n", __func__);
1597    }
1598
1599    vmsg_set_reply_u64(vmsg, 0);
1600    DPRINT("%s: exit\n", __func__);
1601    return true;
1602}
1603
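/*
 * Per-queue size of the shared inflight area: one VuDescStateSplit per
 * descriptor plus a trailing 16-bit field, rounded up to
 * INFLIGHT_ALIGNMENT so each queue's slice starts on its own cache line.
 */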
1604static inline uint64_t
1605vu_inflight_queue_size(uint16_t queue_size)
1606{
1607    return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1608           sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1609}
1610
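/*
 * Handle VHOST_USER_GET_INFLIGHT_FD: allocate a zeroed shared memory area
 * (via qemu_memfd_alloc) large enough to track in-flight descriptors for
 * all queues and return the fd and size to QEMU.
 */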
1611static bool
1612vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1613{
1614    int fd;
1615    void *addr;
1616    uint64_t mmap_size;
1617    uint16_t num_queues, queue_size;
1618
1619    if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1620        vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1621        vmsg->payload.inflight.mmap_size = 0;
1622        return true;
1623    }
1624
1625    num_queues = vmsg->payload.inflight.num_queues;
1626    queue_size = vmsg->payload.inflight.queue_size;
1627
 1628    DPRINT("get_inflight_fd num_queues: %"PRId16"\n", num_queues);
 1629    DPRINT("get_inflight_fd queue_size: %"PRId16"\n", queue_size);
1630
1631    mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1632
1633    addr = qemu_memfd_alloc("vhost-inflight", mmap_size,
1634                            F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1635                            &fd, NULL);
1636
1637    if (!addr) {
1638        vu_panic(dev, "Failed to alloc vhost inflight area");
1639        vmsg->payload.inflight.mmap_size = 0;
1640        return true;
1641    }
1642
1643    memset(addr, 0, mmap_size);
1644
1645    dev->inflight_info.addr = addr;
1646    dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1647    dev->inflight_info.fd = vmsg->fds[0] = fd;
1648    vmsg->fd_num = 1;
1649    vmsg->payload.inflight.mmap_offset = 0;
1650
1651    DPRINT("send inflight mmap_size: %"PRId64"\n",
1652           vmsg->payload.inflight.mmap_size);
1653    DPRINT("send inflight mmap offset: %"PRId64"\n",
1654           vmsg->payload.inflight.mmap_offset);
1655
1656    return true;
1657}
1658
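/*
 * Handle VHOST_USER_SET_INFLIGHT_FD: map the inflight buffer passed by
 * QEMU (e.g. after a reconnect) and carve it into per-queue
 * VuVirtqInflight slices using vu_inflight_queue_size().
 */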
1659static bool
1660vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1661{
1662    int fd, i;
1663    uint64_t mmap_size, mmap_offset;
1664    uint16_t num_queues, queue_size;
1665    void *rc;
1666
1667    if (vmsg->fd_num != 1 ||
1668        vmsg->size != sizeof(vmsg->payload.inflight)) {
1669        vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1670                 vmsg->size, vmsg->fd_num);
1671        return false;
1672    }
1673
1674    fd = vmsg->fds[0];
1675    mmap_size = vmsg->payload.inflight.mmap_size;
1676    mmap_offset = vmsg->payload.inflight.mmap_offset;
1677    num_queues = vmsg->payload.inflight.num_queues;
1678    queue_size = vmsg->payload.inflight.queue_size;
1679
1680    DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1681    DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1682    DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1683    DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1684
1685    rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1686              fd, mmap_offset);
1687
1688    if (rc == MAP_FAILED) {
1689        vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1690        return false;
1691    }
1692
1693    if (dev->inflight_info.fd) {
1694        close(dev->inflight_info.fd);
1695    }
1696
1697    if (dev->inflight_info.addr) {
1698        munmap(dev->inflight_info.addr, dev->inflight_info.size);
1699    }
1700
1701    dev->inflight_info.fd = fd;
1702    dev->inflight_info.addr = rc;
1703    dev->inflight_info.size = mmap_size;
1704
1705    for (i = 0; i < num_queues; i++) {
1706        dev->vq[i].inflight = (VuVirtqInflight *)rc;
1707        dev->vq[i].inflight->desc_num = queue_size;
1708        rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1709    }
1710
1711    return false;
1712}
1713
1714static bool
1715vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1716{
1717    unsigned int index = vmsg->payload.state.index;
1718
1719    if (index >= dev->max_queues) {
1720        vu_panic(dev, "Invalid queue index: %u", index);
1721        return false;
1722    }
1723
1724    DPRINT("Got kick message: handler:%p idx:%d\n",
1725           dev->vq[index].handler, index);
1726
1727    if (!dev->vq[index].started) {
1728        dev->vq[index].started = true;
1729
1730        if (dev->iface->queue_set_started) {
1731            dev->iface->queue_set_started(dev, index, true);
1732        }
1733    }
1734
1735    if (dev->vq[index].handler) {
1736        dev->vq[index].handler(dev, index);
1737    }
1738
1739    return false;
1740}
1741
1742static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1743{
1744    vmsg->flags = VHOST_USER_REPLY_MASK | VHOST_USER_VERSION;
1745    vmsg->size  = sizeof(vmsg->payload.u64);
1746    vmsg->payload.u64 = VHOST_USER_MAX_RAM_SLOTS;
1747    vmsg->fd_num = 0;
1748
1749    if (!vu_message_write(dev, dev->sock, vmsg)) {
1750        vu_panic(dev, "Failed to send max ram slots: %s", strerror(errno));
1751    }
1752
1753    DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1754
1755    return false;
1756}
1757
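    /*
     * Dispatch a single vhost-user request to its handler.  The optional
     * iface->process_msg callback gets first refusal; each handler returns
     * true if vmsg has been turned into a reply that should be sent back.
     */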
1758static bool
1759vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1760{
1761    int do_reply = 0;
1762
1763    /* Print out generic part of the request. */
1764    DPRINT("================ Vhost user message ================\n");
1765    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1766           vmsg->request);
1767    DPRINT("Flags:   0x%x\n", vmsg->flags);
1768    DPRINT("Size:    %d\n", vmsg->size);
1769
1770    if (vmsg->fd_num) {
1771        int i;
1772        DPRINT("Fds:");
1773        for (i = 0; i < vmsg->fd_num; i++) {
1774            DPRINT(" %d", vmsg->fds[i]);
1775        }
1776        DPRINT("\n");
1777    }
1778
1779    if (dev->iface->process_msg &&
1780        dev->iface->process_msg(dev, vmsg, &do_reply)) {
1781        return do_reply;
1782    }
1783
1784    switch (vmsg->request) {
1785    case VHOST_USER_GET_FEATURES:
1786        return vu_get_features_exec(dev, vmsg);
1787    case VHOST_USER_SET_FEATURES:
1788        return vu_set_features_exec(dev, vmsg);
1789    case VHOST_USER_GET_PROTOCOL_FEATURES:
1790        return vu_get_protocol_features_exec(dev, vmsg);
1791    case VHOST_USER_SET_PROTOCOL_FEATURES:
1792        return vu_set_protocol_features_exec(dev, vmsg);
1793    case VHOST_USER_SET_OWNER:
1794        return vu_set_owner_exec(dev, vmsg);
1795    case VHOST_USER_RESET_OWNER:
1796        return vu_reset_device_exec(dev, vmsg);
1797    case VHOST_USER_SET_MEM_TABLE:
1798        return vu_set_mem_table_exec(dev, vmsg);
1799    case VHOST_USER_SET_LOG_BASE:
1800        return vu_set_log_base_exec(dev, vmsg);
1801    case VHOST_USER_SET_LOG_FD:
1802        return vu_set_log_fd_exec(dev, vmsg);
1803    case VHOST_USER_SET_VRING_NUM:
1804        return vu_set_vring_num_exec(dev, vmsg);
1805    case VHOST_USER_SET_VRING_ADDR:
1806        return vu_set_vring_addr_exec(dev, vmsg);
1807    case VHOST_USER_SET_VRING_BASE:
1808        return vu_set_vring_base_exec(dev, vmsg);
1809    case VHOST_USER_GET_VRING_BASE:
1810        return vu_get_vring_base_exec(dev, vmsg);
1811    case VHOST_USER_SET_VRING_KICK:
1812        return vu_set_vring_kick_exec(dev, vmsg);
1813    case VHOST_USER_SET_VRING_CALL:
1814        return vu_set_vring_call_exec(dev, vmsg);
1815    case VHOST_USER_SET_VRING_ERR:
1816        return vu_set_vring_err_exec(dev, vmsg);
1817    case VHOST_USER_GET_QUEUE_NUM:
1818        return vu_get_queue_num_exec(dev, vmsg);
1819    case VHOST_USER_SET_VRING_ENABLE:
1820        return vu_set_vring_enable_exec(dev, vmsg);
1821    case VHOST_USER_SET_SLAVE_REQ_FD:
1822        return vu_set_slave_req_fd(dev, vmsg);
1823    case VHOST_USER_GET_CONFIG:
1824        return vu_get_config(dev, vmsg);
1825    case VHOST_USER_SET_CONFIG:
1826        return vu_set_config(dev, vmsg);
1827    case VHOST_USER_NONE:
1828        /* if you need processing before exit, override iface->process_msg */
1829        exit(0);
1830    case VHOST_USER_POSTCOPY_ADVISE:
1831        return vu_set_postcopy_advise(dev, vmsg);
1832    case VHOST_USER_POSTCOPY_LISTEN:
1833        return vu_set_postcopy_listen(dev, vmsg);
1834    case VHOST_USER_POSTCOPY_END:
1835        return vu_set_postcopy_end(dev, vmsg);
1836    case VHOST_USER_GET_INFLIGHT_FD:
1837        return vu_get_inflight_fd(dev, vmsg);
1838    case VHOST_USER_SET_INFLIGHT_FD:
1839        return vu_set_inflight_fd(dev, vmsg);
1840    case VHOST_USER_VRING_KICK:
1841        return vu_handle_vring_kick(dev, vmsg);
1842    case VHOST_USER_GET_MAX_MEM_SLOTS:
1843        return vu_handle_get_max_memslots(dev, vmsg);
1844    case VHOST_USER_ADD_MEM_REG:
1845        return vu_add_mem_reg(dev, vmsg);
1846    case VHOST_USER_REM_MEM_REG:
1847        return vu_rem_mem_reg(dev, vmsg);
1848    default:
1849        vmsg_close_fds(vmsg);
1850        vu_panic(dev, "Unhandled request: %d", vmsg->request);
1851    }
1852
1853    return false;
1854}
1855
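    /*
     * Read and process one message from the vhost-user socket, sending a
     * reply when the handler requests one or when the master set
     * VHOST_USER_NEED_REPLY_MASK.  Returns false on a read or write failure;
     * applications typically call this from their watch callback whenever
     * the socket becomes readable.
     */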
1856bool
1857vu_dispatch(VuDev *dev)
1858{
1859    VhostUserMsg vmsg = { 0, };
1860    int reply_requested;
1861    bool need_reply, success = false;
1862
1863    if (!vu_message_read(dev, dev->sock, &vmsg)) {
1864        goto end;
1865    }
1866
1867    need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1868
1869    reply_requested = vu_process_message(dev, &vmsg);
1870    if (!reply_requested && need_reply) {
1871        vmsg_set_reply_u64(&vmsg, 0);
1872        reply_requested = 1;
1873    }
1874
1875    if (!reply_requested) {
1876        success = true;
1877        goto end;
1878    }
1879
1880    if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1881        goto end;
1882    }
1883
1884    success = true;
1885
1886end:
1887    free(vmsg.data);
1888    return success;
1889}
1890
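    /*
     * Tear down a VuDev: unmap guest memory regions and the inflight buffer,
     * close all per-queue eventfds, the log, the slave channel and the
     * vhost-user socket, and free the virtqueue array.
     */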
1891void
1892vu_deinit(VuDev *dev)
1893{
1894    int i;
1895
1896    for (i = 0; i < dev->nregions; i++) {
1897        VuDevRegion *r = &dev->regions[i];
1898        void *m = (void *) (uintptr_t) r->mmap_addr;
1899        if (m != MAP_FAILED) {
1900            munmap(m, r->size + r->mmap_offset);
1901        }
1902    }
1903    dev->nregions = 0;
1904
1905    for (i = 0; i < dev->max_queues; i++) {
1906        VuVirtq *vq = &dev->vq[i];
1907
1908        if (vq->call_fd != -1) {
1909            close(vq->call_fd);
1910            vq->call_fd = -1;
1911        }
1912
1913        if (vq->kick_fd != -1) {
1914            close(vq->kick_fd);
1915            vq->kick_fd = -1;
1916        }
1917
1918        if (vq->err_fd != -1) {
1919            close(vq->err_fd);
1920            vq->err_fd = -1;
1921        }
1922
1923        if (vq->resubmit_list) {
1924            free(vq->resubmit_list);
1925            vq->resubmit_list = NULL;
1926        }
1927
1928        vq->inflight = NULL;
1929    }
1930
1931    if (dev->inflight_info.addr) {
1932        munmap(dev->inflight_info.addr, dev->inflight_info.size);
1933        dev->inflight_info.addr = NULL;
1934    }
1935
1936    if (dev->inflight_info.fd > 0) {
1937        close(dev->inflight_info.fd);
1938        dev->inflight_info.fd = -1;
1939    }
1940
1941    vu_close_log(dev);
1942    if (dev->slave_fd != -1) {
1943        close(dev->slave_fd);
1944        dev->slave_fd = -1;
1945    }
1946    pthread_mutex_destroy(&dev->slave_mutex);
1947
1948    if (dev->sock != -1) {
1949        close(dev->sock);
1950    }
1951
1952    free(dev->vq);
1953    dev->vq = NULL;
1954}
1955
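    /*
     * Initialise a VuDev for an already connected vhost-user socket.
     * Returns false only if the virtqueue array cannot be allocated.
     *
     * A minimal usage sketch (callback and variable names are illustrative,
     * not part of this library):
     *
     *     VuDev dev;
     *     if (!vu_init(&dev, 1, sock_fd, my_panic, my_set_watch,
     *                  my_remove_watch, &my_iface)) {
     *         return -1;
     *     }
     *
     * The application then watches sock_fd itself and calls vu_dispatch(&dev)
     * whenever the socket becomes readable.
     */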
1956bool
1957vu_init(VuDev *dev,
1958        uint16_t max_queues,
1959        int socket,
1960        vu_panic_cb panic,
1961        vu_set_watch_cb set_watch,
1962        vu_remove_watch_cb remove_watch,
1963        const VuDevIface *iface)
1964{
1965    uint16_t i;
1966
1967    assert(max_queues > 0);
1968    assert(socket >= 0);
1969    assert(set_watch);
1970    assert(remove_watch);
1971    assert(iface);
1972    assert(panic);
1973
1974    memset(dev, 0, sizeof(*dev));
1975
1976    dev->sock = socket;
1977    dev->panic = panic;
1978    dev->set_watch = set_watch;
1979    dev->remove_watch = remove_watch;
1980    dev->iface = iface;
1981    dev->log_call_fd = -1;
1982    pthread_mutex_init(&dev->slave_mutex, NULL);
1983    dev->slave_fd = -1;
1984    dev->max_queues = max_queues;
1985
1986    dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
1987    if (!dev->vq) {
1988        DPRINT("%s: failed to malloc virtqueues\n", __func__);
1989        return false;
1990    }
1991
1992    for (i = 0; i < max_queues; i++) {
1993        dev->vq[i] = (VuVirtq) {
1994            .call_fd = -1, .kick_fd = -1, .err_fd = -1,
1995            .notification = true,
1996        };
1997    }
1998
1999    return true;
2000}
2001
2002VuVirtq *
2003vu_get_queue(VuDev *dev, int qidx)
2004{
2005    assert(qidx < dev->max_queues);
2006    return &dev->vq[qidx];
2007}
2008
2009bool
2010vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2011{
2012    return vq->enable;
2013}
2014
2015bool
2016vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2017{
2018    return vq->started;
2019}
2020
2021static inline uint16_t
2022vring_avail_flags(VuVirtq *vq)
2023{
2024    return vq->vring.avail->flags;
2025}
2026
2027static inline uint16_t
2028vring_avail_idx(VuVirtq *vq)
2029{
2030    vq->shadow_avail_idx = vq->vring.avail->idx;
2031
2032    return vq->shadow_avail_idx;
2033}
2034
2035static inline uint16_t
2036vring_avail_ring(VuVirtq *vq, int i)
2037{
2038    return vq->vring.avail->ring[i];
2039}
2040
2041static inline uint16_t
2042vring_get_used_event(VuVirtq *vq)
2043{
2044    return vring_avail_ring(vq, vq->vring.num);
2045}
2046
2047static int
2048virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2049{
2050    uint16_t num_heads = vring_avail_idx(vq) - idx;
2051
2052    /* Check it isn't doing very strange things with descriptor numbers. */
2053    if (num_heads > vq->vring.num) {
2054        vu_panic(dev, "Guest moved avail index from %u to %u",
2055                 idx, vq->shadow_avail_idx);
2056        return -1;
2057    }
2058    if (num_heads) {
2059        /* On success, callers read a descriptor at vq->last_avail_idx.
2060         * Make sure descriptor read does not bypass avail index read. */
2061        smp_rmb();
2062    }
2063
2064    return num_heads;
2065}
2066
2067static bool
2068virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2069                   unsigned int idx, unsigned int *head)
2070{
2071    /* Grab the next descriptor number they're advertising, and increment
2072     * the index we've seen. */
2073    *head = vring_avail_ring(vq, idx % vq->vring.num);
2074
2075    /* If their number is silly, that's a fatal mistake. */
2076    if (*head >= vq->vring.num) {
2077        vu_panic(dev, "Guest says index %u is available", *head);
2078        return false;
2079    }
2080
2081    return true;
2082}
2083
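    /*
     * Copy an indirect descriptor table out of guest memory into desc,
     * chunk by chunk, so that a table crossing a memory-region boundary is
     * still read completely.  Returns 0 on success, -1 on a bad length or
     * an unmappable address.
     */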
2084static int
2085virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2086                             uint64_t addr, size_t len)
2087{
2088    struct vring_desc *ori_desc;
2089    uint64_t read_len;
2090
2091    if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2092        return -1;
2093    }
2094
2095    if (len == 0) {
2096        return -1;
2097    }
2098
2099    while (len) {
2100        read_len = len;
2101        ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2102        if (!ori_desc) {
2103            return -1;
2104        }
2105
2106        memcpy(desc, ori_desc, read_len);
2107        len -= read_len;
2108        addr += read_len;
2109        desc = (struct vring_desc *)((char *)desc + read_len);
2110    }
2111
2112    return 0;
2113}
2114
2115enum {
2116    VIRTQUEUE_READ_DESC_ERROR = -1,
2117    VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
2118    VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
2119};
2120
2121static int
2122virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2123                         int i, unsigned int max, unsigned int *next)
2124{
2125    /* If this descriptor says it doesn't chain, we're done. */
2126    if (!(desc[i].flags & VRING_DESC_F_NEXT)) {
2127        return VIRTQUEUE_READ_DESC_DONE;
2128    }
2129
2130    /* Check they're not leading us off end of descriptors. */
2131    *next = desc[i].next;
2132    /* Make sure compiler knows to grab that: we don't want it changing! */
2133    smp_wmb();
2134
2135    if (*next >= max) {
2136        vu_panic(dev, "Desc next is %u", *next);
2137        return VIRTQUEUE_READ_DESC_ERROR;
2138    }
2139
2140    return VIRTQUEUE_READ_DESC_MORE;
2141}
2142
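    /*
     * Walk the available ring (including indirect tables) and report how
     * many device-writable (in) and device-readable (out) bytes are queued,
     * stopping early once both max_in_bytes and max_out_bytes are reached.
     */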
2143void
2144vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
2145                         unsigned int *out_bytes,
2146                         unsigned max_in_bytes, unsigned max_out_bytes)
2147{
2148    unsigned int idx;
2149    unsigned int total_bufs, in_total, out_total;
2150    int rc;
2151
2152    idx = vq->last_avail_idx;
2153
2154    total_bufs = in_total = out_total = 0;
2155    if (unlikely(dev->broken) ||
2156        unlikely(!vq->vring.avail)) {
2157        goto done;
2158    }
2159
2160    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
2161        unsigned int max, desc_len, num_bufs, indirect = 0;
2162        uint64_t desc_addr, read_len;
2163        struct vring_desc *desc;
2164        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2165        unsigned int i;
2166
2167        max = vq->vring.num;
2168        num_bufs = total_bufs;
2169        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
2170            goto err;
2171        }
2172        desc = vq->vring.desc;
2173
2174        if (desc[i].flags & VRING_DESC_F_INDIRECT) {
2175            if (desc[i].len % sizeof(struct vring_desc)) {
2176                vu_panic(dev, "Invalid size for indirect buffer table");
2177                goto err;
2178            }
2179
2180            /* If we've got too many, that implies a descriptor loop. */
2181            if (num_bufs >= max) {
2182                vu_panic(dev, "Looped descriptor");
2183                goto err;
2184            }
2185
2186            /* loop over the indirect descriptor table */
2187            indirect = 1;
2188            desc_addr = desc[i].addr;
2189            desc_len = desc[i].len;
2190            max = desc_len / sizeof(struct vring_desc);
2191            read_len = desc_len;
2192            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2193            if (unlikely(desc && read_len != desc_len)) {
2194                /* Failed to use zero copy */
2195                desc = NULL;
2196                if (!virtqueue_read_indirect_desc(dev, desc_buf,
2197                                                  desc_addr,
2198                                                  desc_len)) {
2199                    desc = desc_buf;
2200                }
2201            }
2202            if (!desc) {
2203                vu_panic(dev, "Invalid indirect buffer table");
2204                goto err;
2205            }
2206            num_bufs = i = 0;
2207        }
2208
2209        do {
2210            /* If we've got too many, that implies a descriptor loop. */
2211            if (++num_bufs > max) {
2212                vu_panic(dev, "Looped descriptor");
2213                goto err;
2214            }
2215
2216            if (desc[i].flags & VRING_DESC_F_WRITE) {
2217                in_total += desc[i].len;
2218            } else {
2219                out_total += desc[i].len;
2220            }
2221            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
2222                goto done;
2223            }
2224            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2225        } while (rc == VIRTQUEUE_READ_DESC_MORE);
2226
2227        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2228            goto err;
2229        }
2230
2231        if (!indirect) {
2232            total_bufs = num_bufs;
2233        } else {
2234            total_bufs++;
2235        }
2236    }
2237    if (rc < 0) {
2238        goto err;
2239    }
2240done:
2241    if (in_bytes) {
2242        *in_bytes = in_total;
2243    }
2244    if (out_bytes) {
2245        *out_bytes = out_total;
2246    }
2247    return;
2248
2249err:
2250    in_total = out_total = 0;
2251    goto done;
2252}
2253
2254bool
2255vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2256                     unsigned int out_bytes)
2257{
2258    unsigned int in_total, out_total;
2259
2260    vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2261                             in_bytes, out_bytes);
2262
2263    return in_bytes <= in_total && out_bytes <= out_total;
2264}
2265
2266/* Fetch avail_idx from VQ memory only when we really need to know if
2267 * guest has added some buffers. */
2268bool
2269vu_queue_empty(VuDev *dev, VuVirtq *vq)
2270{
2271    if (unlikely(dev->broken) ||
2272        unlikely(!vq->vring.avail)) {
2273        return true;
2274    }
2275
2276    if (vq->shadow_avail_idx != vq->last_avail_idx) {
2277        return false;
2278    }
2279
2280    return vring_avail_idx(vq) == vq->last_avail_idx;
2281}
2282
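    /*
     * Decide whether the driver actually needs an interrupt, honouring
     * VIRTIO_F_NOTIFY_ON_EMPTY, VRING_AVAIL_F_NO_INTERRUPT and, when
     * VIRTIO_RING_F_EVENT_IDX is negotiated, the used-event index.
     */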
2283static bool
2284vring_notify(VuDev *dev, VuVirtq *vq)
2285{
2286    uint16_t old, new;
2287    bool v;
2288
2289    /* We need to expose used array entries before checking used event. */
2290    smp_mb();
2291
2292    /* Always notify when queue is empty (if the feature was negotiated) */
2293    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2294        !vq->inuse && vu_queue_empty(dev, vq)) {
2295        return true;
2296    }
2297
2298    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2299        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
2300    }
2301
2302    v = vq->signalled_used_valid;
2303    vq->signalled_used_valid = true;
2304    old = vq->signalled_used;
2305    new = vq->signalled_used = vq->used_idx;
2306    return !v || vring_need_event(vring_get_used_event(vq), new, old);
2307}
2308
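    /*
     * Notify the driver about used buffers: either through the queue's call
     * eventfd or, when no call fd is set and both the in-band notification
     * and slave-request protocol features were negotiated, via a
     * VHOST_USER_SLAVE_VRING_CALL message (optionally waiting for the ack).
     */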
2309static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
2310{
2311    if (unlikely(dev->broken) ||
2312        unlikely(!vq->vring.avail)) {
2313        return;
2314    }
2315
2316    if (!vring_notify(dev, vq)) {
2317        DPRINT("skipped notify...\n");
2318        return;
2319    }
2320
2321    if (vq->call_fd < 0 &&
2322        vu_has_protocol_feature(dev,
2323                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
2324        vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
2325        VhostUserMsg vmsg = {
2326            .request = VHOST_USER_SLAVE_VRING_CALL,
2327            .flags = VHOST_USER_VERSION,
2328            .size = sizeof(vmsg.payload.state),
2329            .payload.state = {
2330                .index = vq - dev->vq,
2331            },
2332        };
2333        bool ack = sync &&
2334                   vu_has_protocol_feature(dev,
2335                                           VHOST_USER_PROTOCOL_F_REPLY_ACK);
2336
2337        if (ack) {
2338            vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
2339        }
2340
2341        vu_message_write(dev, dev->slave_fd, &vmsg);
2342        if (ack) {
2343            vu_message_read(dev, dev->slave_fd, &vmsg);
2344        }
2345        return;
2346    }
2347
2348    if (eventfd_write(vq->call_fd, 1) < 0) {
2349        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
2350    }
2351}
2352
2353void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2354{
2355    _vu_queue_notify(dev, vq, false);
2356}
2357
2358void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2359{
2360    _vu_queue_notify(dev, vq, true);
2361}
2362
2363static inline void
2364vring_used_flags_set_bit(VuVirtq *vq, int mask)
2365{
2366    uint16_t *flags;
2367
2368    flags = (uint16_t *)((char *)vq->vring.used +
2369                         offsetof(struct vring_used, flags));
2370    *flags |= mask;
2371}
2372
2373static inline void
2374vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2375{
2376    uint16_t *flags;
2377
2378    flags = (uint16_t *)((char *)vq->vring.used +
2379                         offsetof(struct vring_used, flags));
2380    *flags &= ~mask;
2381}
2382
2383static inline void
2384vring_set_avail_event(VuVirtq *vq, uint16_t val)
2385{
2386    if (!vq->notification) {
2387        return;
2388    }
2389
2390    *((uint16_t *) &vq->vring.used->ring[vq->vring.num]) = val;
2391}
2392
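    /*
     * Enable or disable guest->device notifications for this queue, using
     * the avail-event index when VIRTIO_RING_F_EVENT_IDX is negotiated and
     * VRING_USED_F_NO_NOTIFY otherwise.
     */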
2393void
2394vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
2395{
2396    vq->notification = enable;
2397    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2398        vring_set_avail_event(vq, vring_avail_idx(vq));
2399    } else if (enable) {
2400        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
2401    } else {
2402        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
2403    }
2404    if (enable) {
2405        /* Expose avail event/used flags before caller checks the avail idx. */
2406        smp_mb();
2407    }
2408}
2409
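    /*
     * Translate one descriptor's guest-physical buffer into host iovecs,
     * splitting it where it crosses memory regions.  Panics the device on a
     * zero-sized buffer, an unmappable address or iovec overflow.
     */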
2410static void
2411virtqueue_map_desc(VuDev *dev,
2412                   unsigned int *p_num_sg, struct iovec *iov,
2413                   unsigned int max_num_sg, bool is_write,
2414                   uint64_t pa, size_t sz)
2415{
2416    unsigned num_sg = *p_num_sg;
2417
2418    assert(num_sg <= max_num_sg);
2419
2420    if (!sz) {
2421        vu_panic(dev, "virtio: zero sized buffers are not allowed");
2422        return;
2423    }
2424
2425    while (sz) {
2426        uint64_t len = sz;
2427
2428        if (num_sg == max_num_sg) {
2429            vu_panic(dev, "virtio: too many descriptors in indirect table");
2430            return;
2431        }
2432
2433        iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2434        if (iov[num_sg].iov_base == NULL) {
2435            vu_panic(dev, "virtio: invalid address for buffers");
2436            return;
2437        }
2438        iov[num_sg].iov_len = len;
2439        num_sg++;
2440        sz -= len;
2441        pa += len;
2442    }
2443
2444    *p_num_sg = num_sg;
2445}
2446
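    /*
     * Allocate a VuVirtqElement (or a larger caller-defined structure that
     * embeds it, hence sz) with the in/out scatter-gather arrays placed
     * right after it in the same allocation.
     */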
2447static void *
2448virtqueue_alloc_element(size_t sz,
2449                        unsigned out_num, unsigned in_num)
2450{
2451    VuVirtqElement *elem;
2452    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2453    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2454    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2455
2456    assert(sz >= sizeof(VuVirtqElement));
2457    elem = malloc(out_sg_end);
2458    elem->out_num = out_num;
2459    elem->in_num = in_num;
2460    elem->in_sg = (void *)elem + in_sg_ofs;
2461    elem->out_sg = (void *)elem + out_sg_ofs;
2462    return elem;
2463}
2464
2465static void *
2466vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2467{
2468    struct vring_desc *desc = vq->vring.desc;
2469    uint64_t desc_addr, read_len;
2470    unsigned int desc_len;
2471    unsigned int max = vq->vring.num;
2472    unsigned int i = idx;
2473    VuVirtqElement *elem;
2474    unsigned int out_num = 0, in_num = 0;
2475    struct iovec iov[VIRTQUEUE_MAX_SIZE];
2476    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2477    int rc;
2478
2479    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
2480        if (desc[i].len % sizeof(struct vring_desc)) {
2481            vu_panic(dev, "Invalid size for indirect buffer table");
2482        }
2483
2484        /* loop over the indirect descriptor table */
2485        desc_addr = desc[i].addr;
2486        desc_len = desc[i].len;
2487        max = desc_len / sizeof(struct vring_desc);
2488        read_len = desc_len;
2489        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2490        if (unlikely(desc && read_len != desc_len)) {
2491            /* Failed to use zero copy */
2492            desc = NULL;
2493            if (!virtqueue_read_indirect_desc(dev, desc_buf,
2494                                              desc_addr,
2495                                              desc_len)) {
2496                desc = desc_buf;
2497            }
2498        }
2499        if (!desc) {
2500            vu_panic(dev, "Invalid indirect buffer table");
2501            return NULL;
2502        }
2503        i = 0;
2504    }
2505
2506    /* Collect all the descriptors */
2507    do {
2508        if (desc[i].flags & VRING_DESC_F_WRITE) {
2509            virtqueue_map_desc(dev, &in_num, iov + out_num,
2510                               VIRTQUEUE_MAX_SIZE - out_num, true,
2511                               desc[i].addr, desc[i].len);
2512        } else {
2513            if (in_num) {
2514                vu_panic(dev, "Incorrect order for descriptors");
2515                return NULL;
2516            }
2517            virtqueue_map_desc(dev, &out_num, iov,
2518                               VIRTQUEUE_MAX_SIZE, false,
2519                               desc[i].addr, desc[i].len);
2520        }
2521
2522        /* If we've got too many, that implies a descriptor loop. */
2523        if ((in_num + out_num) > max) {
2524            vu_panic(dev, "Looped descriptor");
2525        }
2526        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2527    } while (rc == VIRTQUEUE_READ_DESC_MORE);
2528
2529    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2530        vu_panic(dev, "read descriptor error");
2531        return NULL;
2532    }
2533
2534    /* Now copy what we have collected and mapped */
2535    elem = virtqueue_alloc_element(sz, out_num, in_num);
2536    elem->index = idx;
2537    for (i = 0; i < out_num; i++) {
2538        elem->out_sg[i] = iov[i];
2539    }
2540    for (i = 0; i < in_num; i++) {
2541        elem->in_sg[i] = iov[out_num + i];
2542    }
2543
2544    return elem;
2545}
2546
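    /*
     * Inflight tracking helpers: when VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD
     * is negotiated, record a descriptor as in-flight when it is popped,
     * remember the last batch head before the used index is bumped, and
     * clear the entry and mirror the new used index after the flush.
     */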
2547static int
2548vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2549{
2550    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2551        return 0;
2552    }
2553
2554    if (unlikely(!vq->inflight)) {
2555        return -1;
2556    }
2557
2558    vq->inflight->desc[desc_idx].counter = vq->counter++;
2559    vq->inflight->desc[desc_idx].inflight = 1;
2560
2561    return 0;
2562}
2563
2564static int
2565vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2566{
2567    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2568        return 0;
2569    }
2570
2571    if (unlikely(!vq->inflight)) {
2572        return -1;
2573    }
2574
2575    vq->inflight->last_batch_head = desc_idx;
2576
2577    return 0;
2578}
2579
2580static int
2581vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2582{
2583    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2584        return 0;
2585    }
2586
2587    if (unlikely(!vq->inflight)) {
2588        return -1;
2589    }
2590
2591    barrier();
2592
2593    vq->inflight->desc[desc_idx].inflight = 0;
2594
2595    barrier();
2596
2597    vq->inflight->used_idx = vq->used_idx;
2598
2599    return 0;
2600}
2601
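    /*
     * Pop the next available element from the queue, resubmitting any
     * pending in-flight descriptors first after a reconnect.  sz allows the
     * caller to allocate a larger structure that embeds VuVirtqElement; the
     * returned pointer must be free()d once the element has been pushed.
     *
     * Typical processing loop (a sketch; my_process_request is illustrative):
     *
     *     while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
     *         unsigned int len = my_process_request(dev, elem);
     *         vu_queue_push(dev, vq, elem, len);
     *         free(elem);
     *     }
     *     vu_queue_notify(dev, vq);
     */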
2602void *
2603vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
2604{
2605    int i;
2606    unsigned int head;
2607    VuVirtqElement *elem;
2608
2609    if (unlikely(dev->broken) ||
2610        unlikely(!vq->vring.avail)) {
2611        return NULL;
2612    }
2613
2614    if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
2615        i = (--vq->resubmit_num);
2616        elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);
2617
2618        if (!vq->resubmit_num) {
2619            free(vq->resubmit_list);
2620            vq->resubmit_list = NULL;
2621        }
2622
2623        return elem;
2624    }
2625
2626    if (vu_queue_empty(dev, vq)) {
2627        return NULL;
2628    }
2629    /*
2630     * Needed after vu_queue_empty(), see comment in
2631     * virtqueue_num_heads().
2632     */
2633    smp_rmb();
2634
2635    if (vq->inuse >= vq->vring.num) {
2636        vu_panic(dev, "Virtqueue size exceeded");
2637        return NULL;
2638    }
2639
2640    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
2641        return NULL;
2642    }
2643
2644    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2645        vring_set_avail_event(vq, vq->last_avail_idx);
2646    }
2647
2648    elem = vu_queue_map_desc(dev, vq, head, sz);
2649
2650    if (!elem) {
2651        return NULL;
2652    }
2653
2654    vq->inuse++;
2655
2656    vu_queue_inflight_get(dev, vq, head);
2657
2658    return elem;
2659}
2660
2661static void
2662vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2663                        size_t len)
2664{
2665    vq->inuse--;
2666    /* unmap, when DMA support is added */
2667}
2668
2669void
2670vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2671               size_t len)
2672{
2673    vq->last_avail_idx--;
2674    vu_queue_detach_element(dev, vq, elem, len);
2675}
2676
2677bool
2678vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2679{
2680    if (num > vq->inuse) {
2681        return false;
2682    }
2683    vq->last_avail_idx -= num;
2684    vq->inuse -= num;
2685    return true;
2686}
2687
2688static inline
2689void vring_used_write(VuDev *dev, VuVirtq *vq,
2690                      struct vring_used_elem *uelem, int i)
2691{
2692    struct vring_used *used = vq->vring.used;
2693
2694    used->ring[i] = *uelem;
2695    vu_log_write(dev, vq->vring.log_guest_addr +
2696                 offsetof(struct vring_used, ring[i]),
2697                 sizeof(used->ring[i]));
2698}
2699
2700
2701static void
2702vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
2703                  const VuVirtqElement *elem,
2704                  unsigned int len)
2705{
2706    struct vring_desc *desc = vq->vring.desc;
2707    unsigned int i, max, min, desc_len;
2708    uint64_t desc_addr, read_len;
2709    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2710    unsigned num_bufs = 0;
2711
2712    max = vq->vring.num;
2713    i = elem->index;
2714
2715    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
2716        if (desc[i].len % sizeof(struct vring_desc)) {
2717            vu_panic(dev, "Invalid size for indirect buffer table");
2718        }
2719
2720        /* loop over the indirect descriptor table */
2721        desc_addr = desc[i].addr;
2722        desc_len = desc[i].len;
2723        max = desc_len / sizeof(struct vring_desc);
2724        read_len = desc_len;
2725        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2726        if (unlikely(desc && read_len != desc_len)) {
2727            /* Failed to use zero copy */
2728            desc = NULL;
2729            if (!virtqueue_read_indirect_desc(dev, desc_buf,
2730                                              desc_addr,
2731                                              desc_len)) {
2732                desc = desc_buf;
2733            }
2734        }
2735        if (!desc) {
2736            vu_panic(dev, "Invalid indirect buffer table");
2737            return;
2738        }
2739        i = 0;
2740    }
2741
2742    do {
2743        if (++num_bufs > max) {
2744            vu_panic(dev, "Looped descriptor");
2745            return;
2746        }
2747
2748        if (desc[i].flags & VRING_DESC_F_WRITE) {
2749            min = MIN(desc[i].len, len);
2750            vu_log_write(dev, desc[i].addr, min);
2751            len -= min;
2752        }
2753
2754    } while (len > 0 &&
2755             (virtqueue_read_next_desc(dev, desc, i, max, &i)
2756              == VIRTQUEUE_READ_DESC_MORE));
2757}
2758
2759void
2760vu_queue_fill(VuDev *dev, VuVirtq *vq,
2761              const VuVirtqElement *elem,
2762              unsigned int len, unsigned int idx)
2763{
2764    struct vring_used_elem uelem;
2765
2766    if (unlikely(dev->broken) ||
2767        unlikely(!vq->vring.avail)) {
2768        return;
2769    }
2770
2771    vu_log_queue_fill(dev, vq, elem, len);
2772
2773    idx = (idx + vq->used_idx) % vq->vring.num;
2774
2775    uelem.id = elem->index;
2776    uelem.len = len;
2777    vring_used_write(dev, vq, &uelem, idx);
2778}
2779
2780static inline
2781void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2782{
2783    vq->vring.used->idx = val;
2784    vu_log_write(dev,
2785                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2786                 sizeof(vq->vring.used->idx));
2787
2788    vq->used_idx = val;
2789}
2790
2791void
2792vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
2793{
2794    uint16_t old, new;
2795
2796    if (unlikely(dev->broken) ||
2797        unlikely(!vq->vring.avail)) {
2798        return;
2799    }
2800
2801    /* Make sure buffer is written before we update index. */
2802    smp_wmb();
2803
2804    old = vq->used_idx;
2805    new = old + count;
2806    vring_used_idx_set(dev, vq, new);
2807    vq->inuse -= count;
2808    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
2809        vq->signalled_used_valid = false;
2810    }
2811}
2812
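    /*
     * Return one element to the driver: fill the used ring, record it in the
     * inflight region, flush the used index, then mark it completed, so that
     * a crash between these steps can be resolved on reconnect.
     */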
2813void
2814vu_queue_push(VuDev *dev, VuVirtq *vq,
2815              const VuVirtqElement *elem, unsigned int len)
2816{
2817    vu_queue_fill(dev, vq, elem, len, 0);
2818    vu_queue_inflight_pre_put(dev, vq, elem->index);
2819    vu_queue_flush(dev, vq, 1);
2820    vu_queue_inflight_post_put(dev, vq, elem->index);
2821}
2822