qemu/subprojects/libvhost-user/libvhost-user.c
<<
>>
Prefs
   1/*
   2 * Vhost User library
   3 *
   4 * Copyright IBM, Corp. 2007
   5 * Copyright (c) 2016 Red Hat, Inc.
   6 *
   7 * Authors:
   8 *  Anthony Liguori <aliguori@us.ibm.com>
   9 *  Marc-André Lureau <mlureau@redhat.com>
  10 *  Victor Kaplansky <victork@redhat.com>
  11 *
  12 * This work is licensed under the terms of the GNU GPL, version 2 or
  13 * later.  See the COPYING file in the top-level directory.
  14 */
  15
  16/* this code avoids GLib dependency */
  17#include <stdlib.h>
  18#include <stdio.h>
  19#include <unistd.h>
  20#include <stdarg.h>
  21#include <errno.h>
  22#include <string.h>
  23#include <assert.h>
  24#include <inttypes.h>
  25#include <sys/types.h>
  26#include <sys/socket.h>
  27#include <sys/eventfd.h>
  28#include <sys/mman.h>
  29#include <endian.h>
  30
  31#if defined(__linux__)
  32#include <sys/syscall.h>
  33#include <fcntl.h>
  34#include <sys/ioctl.h>
  35#include <linux/vhost.h>
  36
  37#ifdef __NR_userfaultfd
  38#include <linux/userfaultfd.h>
  39#endif
  40
  41#endif
  42
  43#include "include/atomic.h"
  44
  45#include "libvhost-user.h"
  46
/* usually provided by GLib */
/*
 * G_GNUC_PRINTF(format_idx, arg_idx): mark a function as printf-like so the
 * compiler can check its format string against the variadic arguments
 * (1-based parameter indexes). GCC 4.4 (but not clang) uses the gnu_printf
 * archetype, other GNU-compatible compilers use __printf__, and any other
 * compiler gets an empty expansion.
 */
#if     __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
#if !defined(__clang__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 4)
#define G_GNUC_PRINTF(format_idx, arg_idx) \
  __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
#else
#define G_GNUC_PRINTF(format_idx, arg_idx) \
  __attribute__((__format__(__printf__, format_idx, arg_idx)))
#endif
#else   /* !__GNUC__ */
#define G_GNUC_PRINTF(format_idx, arg_idx)
#endif  /* !__GNUC__ */
/*
 * Minimum of two values (GNU statement expression). Each argument is
 * evaluated exactly once; comparing the addresses of the two temporaries
 * makes the compiler warn when x and y have different types.
 */
#ifndef MIN
#define MIN(x, y) ({                            \
            typeof(x) _min1 = (x);              \
            typeof(y) _min2 = (y);              \
            (void) (&_min1 == &_min2);          \
            _min1 < _min2 ? _min1 : _min2; })
#endif

/* Round number down to multiple (m is evaluated more than once) */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))

/* Round number up to multiple (m is evaluated more than once) */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))

/* Branch-prediction hint: x is expected to be false. */
#ifndef unlikely
#define unlikely(x)   __builtin_expect(!!(x), 0)
#endif

/* Align each region to cache line size in inflight buffer */
#define INFLIGHT_ALIGNMENT 64

/* The version of inflight buffer */
#define INFLIGHT_VERSION 1

/* The version of the protocol we support */
#define VHOST_USER_VERSION 1
/* Set to non-zero to enable DPRINT() output on stderr. */
#define LIBVHOST_USER_DEBUG 0

/*
 * Debug print to stderr. The fprintf() call is always compiled (so its
 * arguments must stay valid) but is dead code unless LIBVHOST_USER_DEBUG
 * is non-zero.
 */
#define DPRINT(...)                             \
    do {                                        \
        if (LIBVHOST_USER_DEBUG) {              \
            fprintf(stderr, __VA_ARGS__);        \
        }                                       \
    } while (0)
  93
  94static inline
  95bool has_feature(uint64_t features, unsigned int fbit)
  96{
  97    assert(fbit < 64);
  98    return !!(features & (1ULL << fbit));
  99}
 100
 101static inline
 102bool vu_has_feature(VuDev *dev,
 103                    unsigned int fbit)
 104{
 105    return has_feature(dev->features, fbit);
 106}
 107
 108static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
 109{
 110    return has_feature(dev->protocol_features, fbit);
 111}
 112
 113const char *
 114vu_request_to_string(unsigned int req)
 115{
 116#define REQ(req) [req] = #req
 117    static const char *vu_request_str[] = {
 118        REQ(VHOST_USER_NONE),
 119        REQ(VHOST_USER_GET_FEATURES),
 120        REQ(VHOST_USER_SET_FEATURES),
 121        REQ(VHOST_USER_SET_OWNER),
 122        REQ(VHOST_USER_RESET_OWNER),
 123        REQ(VHOST_USER_SET_MEM_TABLE),
 124        REQ(VHOST_USER_SET_LOG_BASE),
 125        REQ(VHOST_USER_SET_LOG_FD),
 126        REQ(VHOST_USER_SET_VRING_NUM),
 127        REQ(VHOST_USER_SET_VRING_ADDR),
 128        REQ(VHOST_USER_SET_VRING_BASE),
 129        REQ(VHOST_USER_GET_VRING_BASE),
 130        REQ(VHOST_USER_SET_VRING_KICK),
 131        REQ(VHOST_USER_SET_VRING_CALL),
 132        REQ(VHOST_USER_SET_VRING_ERR),
 133        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
 134        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
 135        REQ(VHOST_USER_GET_QUEUE_NUM),
 136        REQ(VHOST_USER_SET_VRING_ENABLE),
 137        REQ(VHOST_USER_SEND_RARP),
 138        REQ(VHOST_USER_NET_SET_MTU),
 139        REQ(VHOST_USER_SET_SLAVE_REQ_FD),
 140        REQ(VHOST_USER_IOTLB_MSG),
 141        REQ(VHOST_USER_SET_VRING_ENDIAN),
 142        REQ(VHOST_USER_GET_CONFIG),
 143        REQ(VHOST_USER_SET_CONFIG),
 144        REQ(VHOST_USER_POSTCOPY_ADVISE),
 145        REQ(VHOST_USER_POSTCOPY_LISTEN),
 146        REQ(VHOST_USER_POSTCOPY_END),
 147        REQ(VHOST_USER_GET_INFLIGHT_FD),
 148        REQ(VHOST_USER_SET_INFLIGHT_FD),
 149        REQ(VHOST_USER_GPU_SET_SOCKET),
 150        REQ(VHOST_USER_VRING_KICK),
 151        REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
 152        REQ(VHOST_USER_ADD_MEM_REG),
 153        REQ(VHOST_USER_REM_MEM_REG),
 154        REQ(VHOST_USER_MAX),
 155    };
 156#undef REQ
 157
 158    if (req < VHOST_USER_MAX) {
 159        return vu_request_str[req];
 160    } else {
 161        return "unknown";
 162    }
 163}
 164
 165static void G_GNUC_PRINTF(2, 3)
 166vu_panic(VuDev *dev, const char *msg, ...)
 167{
 168    char *buf = NULL;
 169    va_list ap;
 170
 171    va_start(ap, msg);
 172    if (vasprintf(&buf, msg, ap) < 0) {
 173        buf = NULL;
 174    }
 175    va_end(ap);
 176
 177    dev->broken = true;
 178    dev->panic(dev, buf);
 179    free(buf);
 180
 181    /*
 182     * FIXME:
 183     * find a way to call virtio_error, or perhaps close the connection?
 184     */
 185}
 186
 187/* Translate guest physical address to our virtual address.  */
 188void *
 189vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
 190{
 191    int i;
 192
 193    if (*plen == 0) {
 194        return NULL;
 195    }
 196
 197    /* Find matching memory region.  */
 198    for (i = 0; i < dev->nregions; i++) {
 199        VuDevRegion *r = &dev->regions[i];
 200
 201        if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
 202            if ((guest_addr + *plen) > (r->gpa + r->size)) {
 203                *plen = r->gpa + r->size - guest_addr;
 204            }
 205            return (void *)(uintptr_t)
 206                guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
 207        }
 208    }
 209
 210    return NULL;
 211}
 212
 213/* Translate qemu virtual address to our virtual address.  */
 214static void *
 215qva_to_va(VuDev *dev, uint64_t qemu_addr)
 216{
 217    int i;
 218
 219    /* Find matching memory region.  */
 220    for (i = 0; i < dev->nregions; i++) {
 221        VuDevRegion *r = &dev->regions[i];
 222
 223        if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
 224            return (void *)(uintptr_t)
 225                qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
 226        }
 227    }
 228
 229    return NULL;
 230}
 231
 232static void
 233vmsg_close_fds(VhostUserMsg *vmsg)
 234{
 235    int i;
 236
 237    for (i = 0; i < vmsg->fd_num; i++) {
 238        close(vmsg->fds[i]);
 239    }
 240}
 241
 242/* Set reply payload.u64 and clear request flags and fd_num */
 243static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
 244{
 245    vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
 246    vmsg->size = sizeof(vmsg->payload.u64);
 247    vmsg->payload.u64 = val;
 248    vmsg->fd_num = 0;
 249}
 250
 251/* A test to see if we have userfault available */
 252static bool
 253have_userfault(void)
 254{
 255#if defined(__linux__) && defined(__NR_userfaultfd) &&\
 256        defined(UFFD_FEATURE_MISSING_SHMEM) &&\
 257        defined(UFFD_FEATURE_MISSING_HUGETLBFS)
 258    /* Now test the kernel we're running on really has the features */
 259    int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
 260    struct uffdio_api api_struct;
 261    if (ufd < 0) {
 262        return false;
 263    }
 264
 265    api_struct.api = UFFD_API;
 266    api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
 267                          UFFD_FEATURE_MISSING_HUGETLBFS;
 268    if (ioctl(ufd, UFFDIO_API, &api_struct)) {
 269        close(ufd);
 270        return false;
 271    }
 272    close(ufd);
 273    return true;
 274
 275#else
 276    return false;
 277#endif
 278}
 279
 280static bool
 281vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 282{
 283    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
 284    struct iovec iov = {
 285        .iov_base = (char *)vmsg,
 286        .iov_len = VHOST_USER_HDR_SIZE,
 287    };
 288    struct msghdr msg = {
 289        .msg_iov = &iov,
 290        .msg_iovlen = 1,
 291        .msg_control = control,
 292        .msg_controllen = sizeof(control),
 293    };
 294    size_t fd_size;
 295    struct cmsghdr *cmsg;
 296    int rc;
 297
 298    do {
 299        rc = recvmsg(conn_fd, &msg, 0);
 300    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 301
 302    if (rc < 0) {
 303        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
 304        return false;
 305    }
 306
 307    vmsg->fd_num = 0;
 308    for (cmsg = CMSG_FIRSTHDR(&msg);
 309         cmsg != NULL;
 310         cmsg = CMSG_NXTHDR(&msg, cmsg))
 311    {
 312        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
 313            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
 314            vmsg->fd_num = fd_size / sizeof(int);
 315            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
 316            break;
 317        }
 318    }
 319
 320    if (vmsg->size > sizeof(vmsg->payload)) {
 321        vu_panic(dev,
 322                 "Error: too big message request: %d, size: vmsg->size: %u, "
 323                 "while sizeof(vmsg->payload) = %zu\n",
 324                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
 325        goto fail;
 326    }
 327
 328    if (vmsg->size) {
 329        do {
 330            rc = read(conn_fd, &vmsg->payload, vmsg->size);
 331        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 332
 333        if (rc <= 0) {
 334            vu_panic(dev, "Error while reading: %s", strerror(errno));
 335            goto fail;
 336        }
 337
 338        assert(rc == vmsg->size);
 339    }
 340
 341    return true;
 342
 343fail:
 344    vmsg_close_fds(vmsg);
 345
 346    return false;
 347}
 348
 349static bool
 350vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 351{
 352    int rc;
 353    uint8_t *p = (uint8_t *)vmsg;
 354    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
 355    struct iovec iov = {
 356        .iov_base = (char *)vmsg,
 357        .iov_len = VHOST_USER_HDR_SIZE,
 358    };
 359    struct msghdr msg = {
 360        .msg_iov = &iov,
 361        .msg_iovlen = 1,
 362        .msg_control = control,
 363    };
 364    struct cmsghdr *cmsg;
 365
 366    memset(control, 0, sizeof(control));
 367    assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
 368    if (vmsg->fd_num > 0) {
 369        size_t fdsize = vmsg->fd_num * sizeof(int);
 370        msg.msg_controllen = CMSG_SPACE(fdsize);
 371        cmsg = CMSG_FIRSTHDR(&msg);
 372        cmsg->cmsg_len = CMSG_LEN(fdsize);
 373        cmsg->cmsg_level = SOL_SOCKET;
 374        cmsg->cmsg_type = SCM_RIGHTS;
 375        memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
 376    } else {
 377        msg.msg_controllen = 0;
 378    }
 379
 380    do {
 381        rc = sendmsg(conn_fd, &msg, 0);
 382    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 383
 384    if (vmsg->size) {
 385        do {
 386            if (vmsg->data) {
 387                rc = write(conn_fd, vmsg->data, vmsg->size);
 388            } else {
 389                rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
 390            }
 391        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 392    }
 393
 394    if (rc <= 0) {
 395        vu_panic(dev, "Error while writing: %s", strerror(errno));
 396        return false;
 397    }
 398
 399    return true;
 400}
 401
 402static bool
 403vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 404{
 405    /* Set the version in the flags when sending the reply */
 406    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
 407    vmsg->flags |= VHOST_USER_VERSION;
 408    vmsg->flags |= VHOST_USER_REPLY_MASK;
 409
 410    return vu_message_write(dev, conn_fd, vmsg);
 411}
 412
 413/*
 414 * Processes a reply on the slave channel.
 415 * Entered with slave_mutex held and releases it before exit.
 416 * Returns true on success.
 417 */
 418static bool
 419vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
 420{
 421    VhostUserMsg msg_reply;
 422    bool result = false;
 423
 424    if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
 425        result = true;
 426        goto out;
 427    }
 428
 429    if (!vu_message_read_default(dev, dev->slave_fd, &msg_reply)) {
 430        goto out;
 431    }
 432
 433    if (msg_reply.request != vmsg->request) {
 434        DPRINT("Received unexpected msg type. Expected %d received %d",
 435               vmsg->request, msg_reply.request);
 436        goto out;
 437    }
 438
 439    result = msg_reply.payload.u64 == 0;
 440
 441out:
 442    pthread_mutex_unlock(&dev->slave_mutex);
 443    return result;
 444}
 445
 446/* Kick the log_call_fd if required. */
 447static void
 448vu_log_kick(VuDev *dev)
 449{
 450    if (dev->log_call_fd != -1) {
 451        DPRINT("Kicking the QEMU's log...\n");
 452        if (eventfd_write(dev->log_call_fd, 1) < 0) {
 453            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
 454        }
 455    }
 456}
 457
 458static void
 459vu_log_page(uint8_t *log_table, uint64_t page)
 460{
 461    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
 462    qatomic_or(&log_table[page / 8], 1 << (page % 8));
 463}
 464
 465static void
 466vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
 467{
 468    uint64_t page;
 469
 470    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
 471        !dev->log_table || !length) {
 472        return;
 473    }
 474
 475    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
 476
 477    page = address / VHOST_LOG_PAGE;
 478    while (page * VHOST_LOG_PAGE < address + length) {
 479        vu_log_page(dev->log_table, page);
 480        page += 1;
 481    }
 482
 483    vu_log_kick(dev);
 484}
 485
 486static void
 487vu_kick_cb(VuDev *dev, int condition, void *data)
 488{
 489    int index = (intptr_t)data;
 490    VuVirtq *vq = &dev->vq[index];
 491    int sock = vq->kick_fd;
 492    eventfd_t kick_data;
 493    ssize_t rc;
 494
 495    rc = eventfd_read(sock, &kick_data);
 496    if (rc == -1) {
 497        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
 498        dev->remove_watch(dev, dev->vq[index].kick_fd);
 499    } else {
 500        DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
 501               kick_data, vq->handler, index);
 502        if (vq->handler) {
 503            vq->handler(dev, index);
 504        }
 505    }
 506}
 507
 508static bool
 509vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 510{
 511    vmsg->payload.u64 =
 512        /*
 513         * The following VIRTIO feature bits are supported by our virtqueue
 514         * implementation:
 515         */
 516        1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
 517        1ULL << VIRTIO_RING_F_INDIRECT_DESC |
 518        1ULL << VIRTIO_RING_F_EVENT_IDX |
 519        1ULL << VIRTIO_F_VERSION_1 |
 520
 521        /* vhost-user feature bits */
 522        1ULL << VHOST_F_LOG_ALL |
 523        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
 524
 525    if (dev->iface->get_features) {
 526        vmsg->payload.u64 |= dev->iface->get_features(dev);
 527    }
 528
 529    vmsg->size = sizeof(vmsg->payload.u64);
 530    vmsg->fd_num = 0;
 531
 532    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 533
 534    return true;
 535}
 536
 537static void
 538vu_set_enable_all_rings(VuDev *dev, bool enabled)
 539{
 540    uint16_t i;
 541
 542    for (i = 0; i < dev->max_queues; i++) {
 543        dev->vq[i].enable = enabled;
 544    }
 545}
 546
 547static bool
 548vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 549{
 550    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 551
 552    dev->features = vmsg->payload.u64;
 553    if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
 554        /*
 555         * We only support devices conforming to VIRTIO 1.0 or
 556         * later
 557         */
 558        vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
 559        return false;
 560    }
 561
 562    if (!(dev->features & VHOST_USER_F_PROTOCOL_FEATURES)) {
 563        vu_set_enable_all_rings(dev, true);
 564    }
 565
 566    if (dev->iface->set_features) {
 567        dev->iface->set_features(dev, dev->features);
 568    }
 569
 570    return false;
 571}
 572
 573static bool
 574vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
 575{
 576    return false;
 577}
 578
 579static void
 580vu_close_log(VuDev *dev)
 581{
 582    if (dev->log_table) {
 583        if (munmap(dev->log_table, dev->log_size) != 0) {
 584            perror("close log munmap() error");
 585        }
 586
 587        dev->log_table = NULL;
 588    }
 589    if (dev->log_call_fd != -1) {
 590        close(dev->log_call_fd);
 591        dev->log_call_fd = -1;
 592    }
 593}
 594
 595static bool
 596vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
 597{
 598    vu_set_enable_all_rings(dev, false);
 599
 600    return false;
 601}
 602
 603static bool
 604map_ring(VuDev *dev, VuVirtq *vq)
 605{
 606    vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
 607    vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
 608    vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
 609
 610    DPRINT("Setting virtq addresses:\n");
 611    DPRINT("    vring_desc  at %p\n", vq->vring.desc);
 612    DPRINT("    vring_used  at %p\n", vq->vring.used);
 613    DPRINT("    vring_avail at %p\n", vq->vring.avail);
 614
 615    return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
 616}
 617
/*
 * Arm postcopy userfault handling for every region in dev->regions.
 *
 * For each region, whose mapping spans size + mmap_offset bytes from
 * mmap_addr, this:
 *   - drops the currently mapped pages (MADV_DONTNEED);
 *   - disables transparent hugepages (MADV_NOHUGEPAGE);
 *   - registers the range with dev->postcopy_ufd in
 *     UFFDIO_REGISTER_MODE_MISSING mode;
 *   - requires that the kernel offers UFFDIO_COPY for it;
 *   - only then makes the range readable/writable with mprotect().
 *
 * madvise() failures are logged to stderr and otherwise ignored. A failed
 * registration, missing UFFDIO_COPY support or a failed mprotect() calls
 * vu_panic() and returns false. Otherwise the function returns true.
 *
 * When UFFDIO_REGISTER is not defined at build time the per-region work is
 * compiled out and the function just returns true.
 * NOTE(review): in that build dev_region and ret are unused, which may
 * trigger -Wunused-variable warnings.
 */
static bool
generate_faults(VuDev *dev) {
    int i;
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *dev_region = &dev->regions[i];
        int ret;
#ifdef UFFDIO_REGISTER
        /*
         * We should already have an open ufd. Mark each memory
         * range as ufd.
         * Discard any mapping we have here; note I can't use MADV_REMOVE
         * or fallocate to make the hole since I don't want to lose
         * data that's already arrived in the shared process.
         * TODO: How to do hugepage
         */
        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
                      dev_region->size + dev_region->mmap_offset,
                      MADV_DONTNEED);
        if (ret) {
            fprintf(stderr,
                    "%s: Failed to madvise(DONTNEED) region %d: %s\n",
                    __func__, i, strerror(errno));
        }
        /*
         * Turn off transparent hugepages so we don't get lost wakeups
         * in neighbouring pages.
         * TODO: Turn this back on later.
         */
        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
                      dev_region->size + dev_region->mmap_offset,
                      MADV_NOHUGEPAGE);
        if (ret) {
            /*
             * Note: This can happen legally on kernels that are configured
             * without madvise'able hugepages
             */
            fprintf(stderr,
                    "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
                    __func__, i, strerror(errno));
        }
        /* Register the whole mapping for "missing page" notifications. */
        struct uffdio_register reg_struct;
        reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
        reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
        reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;

        if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
            vu_panic(dev, "%s: Failed to userfault region %d "
                          "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
                          ": (ufd=%d)%s\n",
                     __func__, i,
                     dev_region->mmap_addr,
                     dev_region->size, dev_region->mmap_offset,
                     dev->postcopy_ufd, strerror(errno));
            return false;
        }
        /*
         * reg_struct.ioctls is filled in by the kernel on registration.
         * Refuse regions where UFFDIO_COPY is not available (presumably
         * needed to resolve faults — confirm against the fault handler).
         */
        if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
            vu_panic(dev, "%s Region (%d) doesn't support COPY",
                     __func__, i);
            return false;
        }
        DPRINT("%s: region %d: Registered userfault for %"
               PRIx64 " + %" PRIx64 "\n", __func__, i,
               (uint64_t)reg_struct.range.start,
               (uint64_t)reg_struct.range.len);
        /*
         * Now it's registered we can let the client at it. The postcopy
         * mapping paths in this file map regions PROT_NONE until this
         * point.
         */
        if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
                     dev_region->size + dev_region->mmap_offset,
                     PROT_READ | PROT_WRITE)) {
            vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
                     i, strerror(errno));
            return false;
        }
        /* TODO: Stash 'zero' support flags somewhere */
#endif
    }

    return true;
}
 696
 697static bool
 698vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
 699    int i;
 700    bool track_ramblocks = dev->postcopy_listening;
 701    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
 702    VuDevRegion *dev_region = &dev->regions[dev->nregions];
 703    void *mmap_addr;
 704
 705    if (vmsg->fd_num != 1) {
 706        vmsg_close_fds(vmsg);
 707        vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
 708                      "should be sent for this message type", vmsg->fd_num);
 709        return false;
 710    }
 711
 712    if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
 713        close(vmsg->fds[0]);
 714        vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
 715                      "least %zu bytes and only %d bytes were received",
 716                      VHOST_USER_MEM_REG_SIZE, vmsg->size);
 717        return false;
 718    }
 719
 720    if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
 721        close(vmsg->fds[0]);
 722        vu_panic(dev, "failing attempt to hot add memory via "
 723                      "VHOST_USER_ADD_MEM_REG message because the backend has "
 724                      "no free ram slots available");
 725        return false;
 726    }
 727
 728    /*
 729     * If we are in postcopy mode and we receive a u64 payload with a 0 value
 730     * we know all the postcopy client bases have been received, and we
 731     * should start generating faults.
 732     */
 733    if (track_ramblocks &&
 734        vmsg->size == sizeof(vmsg->payload.u64) &&
 735        vmsg->payload.u64 == 0) {
 736        (void)generate_faults(dev);
 737        return false;
 738    }
 739
 740    DPRINT("Adding region: %u\n", dev->nregions);
 741    DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 742           msg_region->guest_phys_addr);
 743    DPRINT("    memory_size:     0x%016"PRIx64"\n",
 744           msg_region->memory_size);
 745    DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 746           msg_region->userspace_addr);
 747    DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 748           msg_region->mmap_offset);
 749
 750    dev_region->gpa = msg_region->guest_phys_addr;
 751    dev_region->size = msg_region->memory_size;
 752    dev_region->qva = msg_region->userspace_addr;
 753    dev_region->mmap_offset = msg_region->mmap_offset;
 754
 755    /*
 756     * We don't use offset argument of mmap() since the
 757     * mapped address has to be page aligned, and we use huge
 758     * pages.
 759     */
 760    if (track_ramblocks) {
 761        /*
 762         * In postcopy we're using PROT_NONE here to catch anyone
 763         * accessing it before we userfault.
 764         */
 765        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 766                         PROT_NONE, MAP_SHARED | MAP_NORESERVE,
 767                         vmsg->fds[0], 0);
 768    } else {
 769        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 770                         PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
 771                         vmsg->fds[0], 0);
 772    }
 773
 774    if (mmap_addr == MAP_FAILED) {
 775        vu_panic(dev, "region mmap error: %s", strerror(errno));
 776    } else {
 777        dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 778        DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 779               dev_region->mmap_addr);
 780    }
 781
 782    close(vmsg->fds[0]);
 783
 784    if (track_ramblocks) {
 785        /*
 786         * Return the address to QEMU so that it can translate the ufd
 787         * fault addresses back.
 788         */
 789        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
 790                                                 dev_region->mmap_offset);
 791
 792        /* Send the message back to qemu with the addresses filled in. */
 793        vmsg->fd_num = 0;
 794        DPRINT("Successfully added new region in postcopy\n");
 795        dev->nregions++;
 796        return true;
 797    } else {
 798        for (i = 0; i < dev->max_queues; i++) {
 799            if (dev->vq[i].vring.desc) {
 800                if (map_ring(dev, &dev->vq[i])) {
 801                    vu_panic(dev, "remapping queue %d for new memory region",
 802                             i);
 803                }
 804            }
 805        }
 806
 807        DPRINT("Successfully added new region\n");
 808        dev->nregions++;
 809        return false;
 810    }
 811}
 812
 813static inline bool reg_equal(VuDevRegion *vudev_reg,
 814                             VhostUserMemoryRegion *msg_reg)
 815{
 816    if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
 817        vudev_reg->qva == msg_reg->userspace_addr &&
 818        vudev_reg->size == msg_reg->memory_size) {
 819        return true;
 820    }
 821
 822    return false;
 823}
 824
 825static bool
 826vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
 827    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
 828    int i;
 829    bool found = false;
 830
 831    if (vmsg->fd_num > 1) {
 832        vmsg_close_fds(vmsg);
 833        vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
 834                      "should be sent for this message type", vmsg->fd_num);
 835        return false;
 836    }
 837
 838    if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
 839        vmsg_close_fds(vmsg);
 840        vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
 841                      "least %zu bytes and only %d bytes were received",
 842                      VHOST_USER_MEM_REG_SIZE, vmsg->size);
 843        return false;
 844    }
 845
 846    DPRINT("Removing region:\n");
 847    DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 848           msg_region->guest_phys_addr);
 849    DPRINT("    memory_size:     0x%016"PRIx64"\n",
 850           msg_region->memory_size);
 851    DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 852           msg_region->userspace_addr);
 853    DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 854           msg_region->mmap_offset);
 855
 856    for (i = 0; i < dev->nregions; i++) {
 857        if (reg_equal(&dev->regions[i], msg_region)) {
 858            VuDevRegion *r = &dev->regions[i];
 859            void *m = (void *) (uintptr_t) r->mmap_addr;
 860
 861            if (m) {
 862                munmap(m, r->size + r->mmap_offset);
 863            }
 864
 865            /*
 866             * Shift all affected entries by 1 to close the hole at index i and
 867             * zero out the last entry.
 868             */
 869            memmove(dev->regions + i, dev->regions + i + 1,
 870                    sizeof(VuDevRegion) * (dev->nregions - i - 1));
 871            memset(dev->regions + dev->nregions - 1, 0, sizeof(VuDevRegion));
 872            DPRINT("Successfully removed a region\n");
 873            dev->nregions--;
 874            i--;
 875
 876            found = true;
 877
 878            /* Continue the search for eventual duplicates. */
 879        }
 880    }
 881
 882    if (!found) {
 883        vu_panic(dev, "Specified region not found\n");
 884    }
 885
 886    vmsg_close_fds(vmsg);
 887
 888    return false;
 889}
 890
 891static bool
 892vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
 893{
 894    int i;
 895    VhostUserMemory m = vmsg->payload.memory, *memory = &m;
 896    dev->nregions = memory->nregions;
 897
 898    DPRINT("Nregions: %u\n", memory->nregions);
 899    for (i = 0; i < dev->nregions; i++) {
 900        void *mmap_addr;
 901        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 902        VuDevRegion *dev_region = &dev->regions[i];
 903
 904        DPRINT("Region %d\n", i);
 905        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 906               msg_region->guest_phys_addr);
 907        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 908               msg_region->memory_size);
 909        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 910               msg_region->userspace_addr);
 911        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 912               msg_region->mmap_offset);
 913
 914        dev_region->gpa = msg_region->guest_phys_addr;
 915        dev_region->size = msg_region->memory_size;
 916        dev_region->qva = msg_region->userspace_addr;
 917        dev_region->mmap_offset = msg_region->mmap_offset;
 918
 919        /* We don't use offset argument of mmap() since the
 920         * mapped address has to be page aligned, and we use huge
 921         * pages.
 922         * In postcopy we're using PROT_NONE here to catch anyone
 923         * accessing it before we userfault
 924         */
 925        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 926                         PROT_NONE, MAP_SHARED | MAP_NORESERVE,
 927                         vmsg->fds[i], 0);
 928
 929        if (mmap_addr == MAP_FAILED) {
 930            vu_panic(dev, "region mmap error: %s", strerror(errno));
 931        } else {
 932            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 933            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 934                   dev_region->mmap_addr);
 935        }
 936
 937        /* Return the address to QEMU so that it can translate the ufd
 938         * fault addresses back.
 939         */
 940        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
 941                                                 dev_region->mmap_offset);
 942        close(vmsg->fds[i]);
 943    }
 944
 945    /* Send the message back to qemu with the addresses filled in */
 946    vmsg->fd_num = 0;
 947    if (!vu_send_reply(dev, dev->sock, vmsg)) {
 948        vu_panic(dev, "failed to respond to set-mem-table for postcopy");
 949        return false;
 950    }
 951
 952    /* Wait for QEMU to confirm that it's registered the handler for the
 953     * faults.
 954     */
 955    if (!dev->read_msg(dev, dev->sock, vmsg) ||
 956        vmsg->size != sizeof(vmsg->payload.u64) ||
 957        vmsg->payload.u64 != 0) {
 958        vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
 959        return false;
 960    }
 961
 962    /* OK, now we can go and register the memory and generate faults */
 963    (void)generate_faults(dev);
 964
 965    return false;
 966}
 967
 968static bool
 969vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
 970{
 971    int i;
 972    VhostUserMemory m = vmsg->payload.memory, *memory = &m;
 973
 974    for (i = 0; i < dev->nregions; i++) {
 975        VuDevRegion *r = &dev->regions[i];
 976        void *m = (void *) (uintptr_t) r->mmap_addr;
 977
 978        if (m) {
 979            munmap(m, r->size + r->mmap_offset);
 980        }
 981    }
 982    dev->nregions = memory->nregions;
 983
 984    if (dev->postcopy_listening) {
 985        return vu_set_mem_table_exec_postcopy(dev, vmsg);
 986    }
 987
 988    DPRINT("Nregions: %u\n", memory->nregions);
 989    for (i = 0; i < dev->nregions; i++) {
 990        void *mmap_addr;
 991        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 992        VuDevRegion *dev_region = &dev->regions[i];
 993
 994        DPRINT("Region %d\n", i);
 995        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 996               msg_region->guest_phys_addr);
 997        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 998               msg_region->memory_size);
 999        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
1000               msg_region->userspace_addr);
1001        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
1002               msg_region->mmap_offset);
1003
1004        dev_region->gpa = msg_region->guest_phys_addr;
1005        dev_region->size = msg_region->memory_size;
1006        dev_region->qva = msg_region->userspace_addr;
1007        dev_region->mmap_offset = msg_region->mmap_offset;
1008
1009        /* We don't use offset argument of mmap() since the
1010         * mapped address has to be page aligned, and we use huge
1011         * pages.  */
1012        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
1013                         PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
1014                         vmsg->fds[i], 0);
1015
1016        if (mmap_addr == MAP_FAILED) {
1017            vu_panic(dev, "region mmap error: %s", strerror(errno));
1018        } else {
1019            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
1020            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
1021                   dev_region->mmap_addr);
1022        }
1023
1024        close(vmsg->fds[i]);
1025    }
1026
1027    for (i = 0; i < dev->max_queues; i++) {
1028        if (dev->vq[i].vring.desc) {
1029            if (map_ring(dev, &dev->vq[i])) {
1030                vu_panic(dev, "remapping queue %d during setmemtable", i);
1031            }
1032        }
1033    }
1034
1035    return false;
1036}
1037
1038static bool
1039vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1040{
1041    int fd;
1042    uint64_t log_mmap_size, log_mmap_offset;
1043    void *rc;
1044
1045    if (vmsg->fd_num != 1 ||
1046        vmsg->size != sizeof(vmsg->payload.log)) {
1047        vu_panic(dev, "Invalid log_base message");
1048        return true;
1049    }
1050
1051    fd = vmsg->fds[0];
1052    log_mmap_offset = vmsg->payload.log.mmap_offset;
1053    log_mmap_size = vmsg->payload.log.mmap_size;
1054    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1055    DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1056
1057    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1058              log_mmap_offset);
1059    close(fd);
1060    if (rc == MAP_FAILED) {
1061        perror("log mmap error");
1062    }
1063
1064    if (dev->log_table) {
1065        munmap(dev->log_table, dev->log_size);
1066    }
1067    dev->log_table = rc;
1068    dev->log_size = log_mmap_size;
1069
1070    vmsg->size = sizeof(vmsg->payload.u64);
1071    vmsg->fd_num = 0;
1072
1073    return true;
1074}
1075
1076static bool
1077vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1078{
1079    if (vmsg->fd_num != 1) {
1080        vu_panic(dev, "Invalid log_fd message");
1081        return false;
1082    }
1083
1084    if (dev->log_call_fd != -1) {
1085        close(dev->log_call_fd);
1086    }
1087    dev->log_call_fd = vmsg->fds[0];
1088    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1089
1090    return false;
1091}
1092
1093static bool
1094vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1095{
1096    unsigned int index = vmsg->payload.state.index;
1097    unsigned int num = vmsg->payload.state.num;
1098
1099    DPRINT("State.index: %u\n", index);
1100    DPRINT("State.num:   %u\n", num);
1101    dev->vq[index].vring.num = num;
1102
1103    return false;
1104}
1105
1106static bool
1107vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1108{
1109    struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1110    unsigned int index = vra->index;
1111    VuVirtq *vq = &dev->vq[index];
1112
1113    DPRINT("vhost_vring_addr:\n");
1114    DPRINT("    index:  %d\n", vra->index);
1115    DPRINT("    flags:  %d\n", vra->flags);
1116    DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
1117    DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
1118    DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
1119    DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);
1120
1121    vq->vra = *vra;
1122    vq->vring.flags = vra->flags;
1123    vq->vring.log_guest_addr = vra->log_guest_addr;
1124
1125
1126    if (map_ring(dev, vq)) {
1127        vu_panic(dev, "Invalid vring_addr message");
1128        return false;
1129    }
1130
1131    vq->used_idx = le16toh(vq->vring.used->idx);
1132
1133    if (vq->last_avail_idx != vq->used_idx) {
1134        bool resume = dev->iface->queue_is_processed_in_order &&
1135            dev->iface->queue_is_processed_in_order(dev, index);
1136
1137        DPRINT("Last avail index != used index: %u != %u%s\n",
1138               vq->last_avail_idx, vq->used_idx,
1139               resume ? ", resuming" : "");
1140
1141        if (resume) {
1142            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1143        }
1144    }
1145
1146    return false;
1147}
1148
1149static bool
1150vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1151{
1152    unsigned int index = vmsg->payload.state.index;
1153    unsigned int num = vmsg->payload.state.num;
1154
1155    DPRINT("State.index: %u\n", index);
1156    DPRINT("State.num:   %u\n", num);
1157    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1158
1159    return false;
1160}
1161
1162static bool
1163vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1164{
1165    unsigned int index = vmsg->payload.state.index;
1166
1167    DPRINT("State.index: %u\n", index);
1168    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1169    vmsg->size = sizeof(vmsg->payload.state);
1170
1171    dev->vq[index].started = false;
1172    if (dev->iface->queue_set_started) {
1173        dev->iface->queue_set_started(dev, index, false);
1174    }
1175
1176    if (dev->vq[index].call_fd != -1) {
1177        close(dev->vq[index].call_fd);
1178        dev->vq[index].call_fd = -1;
1179    }
1180    if (dev->vq[index].kick_fd != -1) {
1181        dev->remove_watch(dev, dev->vq[index].kick_fd);
1182        close(dev->vq[index].kick_fd);
1183        dev->vq[index].kick_fd = -1;
1184    }
1185
1186    return true;
1187}
1188
1189static bool
1190vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1191{
1192    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1193    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1194
1195    if (index >= dev->max_queues) {
1196        vmsg_close_fds(vmsg);
1197        vu_panic(dev, "Invalid queue index: %u", index);
1198        return false;
1199    }
1200
1201    if (nofd) {
1202        vmsg_close_fds(vmsg);
1203        return true;
1204    }
1205
1206    if (vmsg->fd_num != 1) {
1207        vmsg_close_fds(vmsg);
1208        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1209        return false;
1210    }
1211
1212    return true;
1213}
1214
1215static int
1216inflight_desc_compare(const void *a, const void *b)
1217{
1218    VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1219                        *desc1 = (VuVirtqInflightDesc *)b;
1220
1221    if (desc1->counter > desc0->counter &&
1222        (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1223        return 1;
1224    }
1225
1226    return -1;
1227}
1228
1229static int
1230vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1231{
1232    int i = 0;
1233
1234    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1235        return 0;
1236    }
1237
1238    if (unlikely(!vq->inflight)) {
1239        return -1;
1240    }
1241
1242    if (unlikely(!vq->inflight->version)) {
1243        /* initialize the buffer */
1244        vq->inflight->version = INFLIGHT_VERSION;
1245        return 0;
1246    }
1247
1248    vq->used_idx = le16toh(vq->vring.used->idx);
1249    vq->resubmit_num = 0;
1250    vq->resubmit_list = NULL;
1251    vq->counter = 0;
1252
1253    if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1254        vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1255
1256        barrier();
1257
1258        vq->inflight->used_idx = vq->used_idx;
1259    }
1260
1261    for (i = 0; i < vq->inflight->desc_num; i++) {
1262        if (vq->inflight->desc[i].inflight == 1) {
1263            vq->inuse++;
1264        }
1265    }
1266
1267    vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1268
1269    if (vq->inuse) {
1270        vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1271        if (!vq->resubmit_list) {
1272            return -1;
1273        }
1274
1275        for (i = 0; i < vq->inflight->desc_num; i++) {
1276            if (vq->inflight->desc[i].inflight) {
1277                vq->resubmit_list[vq->resubmit_num].index = i;
1278                vq->resubmit_list[vq->resubmit_num].counter =
1279                                        vq->inflight->desc[i].counter;
1280                vq->resubmit_num++;
1281            }
1282        }
1283
1284        if (vq->resubmit_num > 1) {
1285            qsort(vq->resubmit_list, vq->resubmit_num,
1286                  sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1287        }
1288        vq->counter = vq->resubmit_list[0].counter + 1;
1289    }
1290
1291    /* in case of I/O hang after reconnecting */
1292    if (eventfd_write(vq->kick_fd, 1)) {
1293        return -1;
1294    }
1295
1296    return 0;
1297}
1298
1299static bool
1300vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1301{
1302    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1303    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1304
1305    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1306
1307    if (!vu_check_queue_msg_file(dev, vmsg)) {
1308        return false;
1309    }
1310
1311    if (dev->vq[index].kick_fd != -1) {
1312        dev->remove_watch(dev, dev->vq[index].kick_fd);
1313        close(dev->vq[index].kick_fd);
1314        dev->vq[index].kick_fd = -1;
1315    }
1316
1317    dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1318    DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1319
1320    dev->vq[index].started = true;
1321    if (dev->iface->queue_set_started) {
1322        dev->iface->queue_set_started(dev, index, true);
1323    }
1324
1325    if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1326        dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1327                       vu_kick_cb, (void *)(long)index);
1328
1329        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1330               dev->vq[index].kick_fd, index);
1331    }
1332
1333    if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1334        vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1335    }
1336
1337    return false;
1338}
1339
1340void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1341                          vu_queue_handler_cb handler)
1342{
1343    int qidx = vq - dev->vq;
1344
1345    vq->handler = handler;
1346    if (vq->kick_fd >= 0) {
1347        if (handler) {
1348            dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1349                           vu_kick_cb, (void *)(long)qidx);
1350        } else {
1351            dev->remove_watch(dev, vq->kick_fd);
1352        }
1353    }
1354}
1355
1356bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1357                                int size, int offset)
1358{
1359    int qidx = vq - dev->vq;
1360    int fd_num = 0;
1361    VhostUserMsg vmsg = {
1362        .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1363        .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1364        .size = sizeof(vmsg.payload.area),
1365        .payload.area = {
1366            .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1367            .size = size,
1368            .offset = offset,
1369        },
1370    };
1371
1372    if (fd == -1) {
1373        vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1374    } else {
1375        vmsg.fds[fd_num++] = fd;
1376    }
1377
1378    vmsg.fd_num = fd_num;
1379
1380    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
1381        return false;
1382    }
1383
1384    pthread_mutex_lock(&dev->slave_mutex);
1385    if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
1386        pthread_mutex_unlock(&dev->slave_mutex);
1387        return false;
1388    }
1389
1390    /* Also unlocks the slave_mutex */
1391    return vu_process_message_reply(dev, &vmsg);
1392}
1393
1394static bool
1395vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1396{
1397    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1398    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1399
1400    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1401
1402    if (!vu_check_queue_msg_file(dev, vmsg)) {
1403        return false;
1404    }
1405
1406    if (dev->vq[index].call_fd != -1) {
1407        close(dev->vq[index].call_fd);
1408        dev->vq[index].call_fd = -1;
1409    }
1410
1411    dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1412
1413    /* in case of I/O hang after reconnecting */
1414    if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
1415        return -1;
1416    }
1417
1418    DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1419
1420    return false;
1421}
1422
1423static bool
1424vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1425{
1426    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1427    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1428
1429    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1430
1431    if (!vu_check_queue_msg_file(dev, vmsg)) {
1432        return false;
1433    }
1434
1435    if (dev->vq[index].err_fd != -1) {
1436        close(dev->vq[index].err_fd);
1437        dev->vq[index].err_fd = -1;
1438    }
1439
1440    dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1441
1442    return false;
1443}
1444
1445static bool
1446vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1447{
1448    /*
1449     * Note that we support, but intentionally do not set,
1450     * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1451     * a device implementation can return it in its callback
1452     * (get_protocol_features) if it wants to use this for
1453     * simulation, but it is otherwise not desirable (if even
1454     * implemented by the master.)
1455     */
1456    uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1457                        1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1458                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
1459                        1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1460                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
1461                        1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1462                        1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1463
1464    if (have_userfault()) {
1465        features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1466    }
1467
1468    if (dev->iface->get_config && dev->iface->set_config) {
1469        features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1470    }
1471
1472    if (dev->iface->get_protocol_features) {
1473        features |= dev->iface->get_protocol_features(dev);
1474    }
1475
1476    vmsg_set_reply_u64(vmsg, features);
1477    return true;
1478}
1479
1480static bool
1481vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1482{
1483    uint64_t features = vmsg->payload.u64;
1484
1485    DPRINT("u64: 0x%016"PRIx64"\n", features);
1486
1487    dev->protocol_features = vmsg->payload.u64;
1488
1489    if (vu_has_protocol_feature(dev,
1490                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1491        (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
1492         !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1493        /*
1494         * The use case for using messages for kick/call is simulation, to make
1495         * the kick and call synchronous. To actually get that behaviour, both
1496         * of the other features are required.
1497         * Theoretically, one could use only kick messages, or do them without
1498         * having F_REPLY_ACK, but too many (possibly pending) messages on the
1499         * socket will eventually cause the master to hang, to avoid this in
1500         * scenarios where not desired enforce that the settings are in a way
1501         * that actually enables the simulation case.
1502         */
1503        vu_panic(dev,
1504                 "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
1505        return false;
1506    }
1507
1508    if (dev->iface->set_protocol_features) {
1509        dev->iface->set_protocol_features(dev, features);
1510    }
1511
1512    return false;
1513}
1514
1515static bool
1516vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1517{
1518    vmsg_set_reply_u64(vmsg, dev->max_queues);
1519    return true;
1520}
1521
1522static bool
1523vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1524{
1525    unsigned int index = vmsg->payload.state.index;
1526    unsigned int enable = vmsg->payload.state.num;
1527
1528    DPRINT("State.index: %u\n", index);
1529    DPRINT("State.enable:   %u\n", enable);
1530
1531    if (index >= dev->max_queues) {
1532        vu_panic(dev, "Invalid vring_enable index: %u", index);
1533        return false;
1534    }
1535
1536    dev->vq[index].enable = enable;
1537    return false;
1538}
1539
1540static bool
1541vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1542{
1543    if (vmsg->fd_num != 1) {
1544        vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
1545        return false;
1546    }
1547
1548    if (dev->slave_fd != -1) {
1549        close(dev->slave_fd);
1550    }
1551    dev->slave_fd = vmsg->fds[0];
1552    DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1553
1554    return false;
1555}
1556
1557static bool
1558vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1559{
1560    int ret = -1;
1561
1562    if (dev->iface->get_config) {
1563        ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1564                                     vmsg->payload.config.size);
1565    }
1566
1567    if (ret) {
1568        /* resize to zero to indicate an error to master */
1569        vmsg->size = 0;
1570    }
1571
1572    return true;
1573}
1574
1575static bool
1576vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1577{
1578    int ret = -1;
1579
1580    if (dev->iface->set_config) {
1581        ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1582                                     vmsg->payload.config.offset,
1583                                     vmsg->payload.config.size,
1584                                     vmsg->payload.config.flags);
1585        if (ret) {
1586            vu_panic(dev, "Set virtio configuration space failed");
1587        }
1588    }
1589
1590    return false;
1591}
1592
1593static bool
1594vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1595{
1596    dev->postcopy_ufd = -1;
1597#ifdef UFFDIO_API
1598    struct uffdio_api api_struct;
1599
1600    dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1601    vmsg->size = 0;
1602#endif
1603
1604    if (dev->postcopy_ufd == -1) {
1605        vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1606        goto out;
1607    }
1608
1609#ifdef UFFDIO_API
1610    api_struct.api = UFFD_API;
1611    api_struct.features = 0;
1612    if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1613        vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1614        close(dev->postcopy_ufd);
1615        dev->postcopy_ufd = -1;
1616        goto out;
1617    }
1618    /* TODO: Stash feature flags somewhere */
1619#endif
1620
1621out:
1622    /* Return a ufd to the QEMU */
1623    vmsg->fd_num = 1;
1624    vmsg->fds[0] = dev->postcopy_ufd;
1625    return true; /* = send a reply */
1626}
1627
1628static bool
1629vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1630{
1631    if (dev->nregions) {
1632        vu_panic(dev, "Regions already registered at postcopy-listen");
1633        vmsg_set_reply_u64(vmsg, -1);
1634        return true;
1635    }
1636    dev->postcopy_listening = true;
1637
1638    vmsg_set_reply_u64(vmsg, 0);
1639    return true;
1640}
1641
1642static bool
1643vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1644{
1645    DPRINT("%s: Entry\n", __func__);
1646    dev->postcopy_listening = false;
1647    if (dev->postcopy_ufd > 0) {
1648        close(dev->postcopy_ufd);
1649        dev->postcopy_ufd = -1;
1650        DPRINT("%s: Done close\n", __func__);
1651    }
1652
1653    vmsg_set_reply_u64(vmsg, 0);
1654    DPRINT("%s: exit\n", __func__);
1655    return true;
1656}
1657
1658static inline uint64_t
1659vu_inflight_queue_size(uint16_t queue_size)
1660{
1661    return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1662           sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1663}
1664
#ifdef MFD_ALLOW_SEALING
/*
 * Create a sealable memfd named @name and size it to @size bytes.  Apply
 * the F_ADD_SEALS @flags, then map it shared and read/write.
 *
 * On success returns the mapping; the memfd is stored in *fd.  On failure
 * returns NULL, and any memfd that had been created is already closed.
 */
static void *
memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
{
    void *ptr;

    *fd = memfd_create(name, MFD_ALLOW_SEALING);
    if (*fd < 0) {
        return NULL;
    }

    if (ftruncate(*fd, size) < 0 || fcntl(*fd, F_ADD_SEALS, flags) < 0) {
        goto err_close;
    }

    ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
    if (ptr != MAP_FAILED) {
        return ptr;
    }

err_close:
    close(*fd);
    return NULL;
}
#endif
1698
1699static bool
1700vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1701{
1702    int fd = -1;
1703    void *addr = NULL;
1704    uint64_t mmap_size;
1705    uint16_t num_queues, queue_size;
1706
1707    if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1708        vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1709        vmsg->payload.inflight.mmap_size = 0;
1710        return true;
1711    }
1712
1713    num_queues = vmsg->payload.inflight.num_queues;
1714    queue_size = vmsg->payload.inflight.queue_size;
1715
1716    DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1717    DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1718
1719    mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1720
1721#ifdef MFD_ALLOW_SEALING
1722    addr = memfd_alloc("vhost-inflight", mmap_size,
1723                       F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1724                       &fd);
1725#else
1726    vu_panic(dev, "Not implemented: memfd support is missing");
1727#endif
1728
1729    if (!addr) {
1730        vu_panic(dev, "Failed to alloc vhost inflight area");
1731        vmsg->payload.inflight.mmap_size = 0;
1732        return true;
1733    }
1734
1735    memset(addr, 0, mmap_size);
1736
1737    dev->inflight_info.addr = addr;
1738    dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1739    dev->inflight_info.fd = vmsg->fds[0] = fd;
1740    vmsg->fd_num = 1;
1741    vmsg->payload.inflight.mmap_offset = 0;
1742
1743    DPRINT("send inflight mmap_size: %"PRId64"\n",
1744           vmsg->payload.inflight.mmap_size);
1745    DPRINT("send inflight mmap offset: %"PRId64"\n",
1746           vmsg->payload.inflight.mmap_offset);
1747
1748    return true;
1749}
1750
1751static bool
1752vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1753{
1754    int fd, i;
1755    uint64_t mmap_size, mmap_offset;
1756    uint16_t num_queues, queue_size;
1757    void *rc;
1758
1759    if (vmsg->fd_num != 1 ||
1760        vmsg->size != sizeof(vmsg->payload.inflight)) {
1761        vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1762                 vmsg->size, vmsg->fd_num);
1763        return false;
1764    }
1765
1766    fd = vmsg->fds[0];
1767    mmap_size = vmsg->payload.inflight.mmap_size;
1768    mmap_offset = vmsg->payload.inflight.mmap_offset;
1769    num_queues = vmsg->payload.inflight.num_queues;
1770    queue_size = vmsg->payload.inflight.queue_size;
1771
1772    DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1773    DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1774    DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1775    DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1776
1777    rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1778              fd, mmap_offset);
1779
1780    if (rc == MAP_FAILED) {
1781        vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1782        return false;
1783    }
1784
1785    if (dev->inflight_info.fd) {
1786        close(dev->inflight_info.fd);
1787    }
1788
1789    if (dev->inflight_info.addr) {
1790        munmap(dev->inflight_info.addr, dev->inflight_info.size);
1791    }
1792
1793    dev->inflight_info.fd = fd;
1794    dev->inflight_info.addr = rc;
1795    dev->inflight_info.size = mmap_size;
1796
1797    for (i = 0; i < num_queues; i++) {
1798        dev->vq[i].inflight = (VuVirtqInflight *)rc;
1799        dev->vq[i].inflight->desc_num = queue_size;
1800        rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1801    }
1802
1803    return false;
1804}
1805
1806static bool
1807vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1808{
1809    unsigned int index = vmsg->payload.state.index;
1810
1811    if (index >= dev->max_queues) {
1812        vu_panic(dev, "Invalid queue index: %u", index);
1813        return false;
1814    }
1815
1816    DPRINT("Got kick message: handler:%p idx:%u\n",
1817           dev->vq[index].handler, index);
1818
1819    if (!dev->vq[index].started) {
1820        dev->vq[index].started = true;
1821
1822        if (dev->iface->queue_set_started) {
1823            dev->iface->queue_set_started(dev, index, true);
1824        }
1825    }
1826
1827    if (dev->vq[index].handler) {
1828        dev->vq[index].handler(dev, index);
1829    }
1830
1831    return false;
1832}
1833
1834static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1835{
1836    vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);
1837
1838    DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1839
1840    return true;
1841}
1842
1843static bool
1844vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1845{
1846    int do_reply = 0;
1847
1848    /* Print out generic part of the request. */
1849    DPRINT("================ Vhost user message ================\n");
1850    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1851           vmsg->request);
1852    DPRINT("Flags:   0x%x\n", vmsg->flags);
1853    DPRINT("Size:    %u\n", vmsg->size);
1854
1855    if (vmsg->fd_num) {
1856        int i;
1857        DPRINT("Fds:");
1858        for (i = 0; i < vmsg->fd_num; i++) {
1859            DPRINT(" %d", vmsg->fds[i]);
1860        }
1861        DPRINT("\n");
1862    }
1863
1864    if (dev->iface->process_msg &&
1865        dev->iface->process_msg(dev, vmsg, &do_reply)) {
1866        return do_reply;
1867    }
1868
1869    switch (vmsg->request) {
1870    case VHOST_USER_GET_FEATURES:
1871        return vu_get_features_exec(dev, vmsg);
1872    case VHOST_USER_SET_FEATURES:
1873        return vu_set_features_exec(dev, vmsg);
1874    case VHOST_USER_GET_PROTOCOL_FEATURES:
1875        return vu_get_protocol_features_exec(dev, vmsg);
1876    case VHOST_USER_SET_PROTOCOL_FEATURES:
1877        return vu_set_protocol_features_exec(dev, vmsg);
1878    case VHOST_USER_SET_OWNER:
1879        return vu_set_owner_exec(dev, vmsg);
1880    case VHOST_USER_RESET_OWNER:
1881        return vu_reset_device_exec(dev, vmsg);
1882    case VHOST_USER_SET_MEM_TABLE:
1883        return vu_set_mem_table_exec(dev, vmsg);
1884    case VHOST_USER_SET_LOG_BASE:
1885        return vu_set_log_base_exec(dev, vmsg);
1886    case VHOST_USER_SET_LOG_FD:
1887        return vu_set_log_fd_exec(dev, vmsg);
1888    case VHOST_USER_SET_VRING_NUM:
1889        return vu_set_vring_num_exec(dev, vmsg);
1890    case VHOST_USER_SET_VRING_ADDR:
1891        return vu_set_vring_addr_exec(dev, vmsg);
1892    case VHOST_USER_SET_VRING_BASE:
1893        return vu_set_vring_base_exec(dev, vmsg);
1894    case VHOST_USER_GET_VRING_BASE:
1895        return vu_get_vring_base_exec(dev, vmsg);
1896    case VHOST_USER_SET_VRING_KICK:
1897        return vu_set_vring_kick_exec(dev, vmsg);
1898    case VHOST_USER_SET_VRING_CALL:
1899        return vu_set_vring_call_exec(dev, vmsg);
1900    case VHOST_USER_SET_VRING_ERR:
1901        return vu_set_vring_err_exec(dev, vmsg);
1902    case VHOST_USER_GET_QUEUE_NUM:
1903        return vu_get_queue_num_exec(dev, vmsg);
1904    case VHOST_USER_SET_VRING_ENABLE:
1905        return vu_set_vring_enable_exec(dev, vmsg);
1906    case VHOST_USER_SET_SLAVE_REQ_FD:
1907        return vu_set_slave_req_fd(dev, vmsg);
1908    case VHOST_USER_GET_CONFIG:
1909        return vu_get_config(dev, vmsg);
1910    case VHOST_USER_SET_CONFIG:
1911        return vu_set_config(dev, vmsg);
1912    case VHOST_USER_NONE:
1913        /* if you need processing before exit, override iface->process_msg */
1914        exit(0);
1915    case VHOST_USER_POSTCOPY_ADVISE:
1916        return vu_set_postcopy_advise(dev, vmsg);
1917    case VHOST_USER_POSTCOPY_LISTEN:
1918        return vu_set_postcopy_listen(dev, vmsg);
1919    case VHOST_USER_POSTCOPY_END:
1920        return vu_set_postcopy_end(dev, vmsg);
1921    case VHOST_USER_GET_INFLIGHT_FD:
1922        return vu_get_inflight_fd(dev, vmsg);
1923    case VHOST_USER_SET_INFLIGHT_FD:
1924        return vu_set_inflight_fd(dev, vmsg);
1925    case VHOST_USER_VRING_KICK:
1926        return vu_handle_vring_kick(dev, vmsg);
1927    case VHOST_USER_GET_MAX_MEM_SLOTS:
1928        return vu_handle_get_max_memslots(dev, vmsg);
1929    case VHOST_USER_ADD_MEM_REG:
1930        return vu_add_mem_reg(dev, vmsg);
1931    case VHOST_USER_REM_MEM_REG:
1932        return vu_rem_mem_reg(dev, vmsg);
1933    default:
1934        vmsg_close_fds(vmsg);
1935        vu_panic(dev, "Unhandled request: %d", vmsg->request);
1936    }
1937
1938    return false;
1939}
1940
1941bool
1942vu_dispatch(VuDev *dev)
1943{
1944    VhostUserMsg vmsg = { 0, };
1945    int reply_requested;
1946    bool need_reply, success = false;
1947
1948    if (!dev->read_msg(dev, dev->sock, &vmsg)) {
1949        goto end;
1950    }
1951
1952    need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1953
1954    reply_requested = vu_process_message(dev, &vmsg);
1955    if (!reply_requested && need_reply) {
1956        vmsg_set_reply_u64(&vmsg, 0);
1957        reply_requested = 1;
1958    }
1959
1960    if (!reply_requested) {
1961        success = true;
1962        goto end;
1963    }
1964
1965    if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1966        goto end;
1967    }
1968
1969    success = true;
1970
1971end:
1972    free(vmsg.data);
1973    return success;
1974}
1975
/*
 * Release every resource held by @dev.  In order, it:
 * - unmaps the guest memory regions;
 * - closes each virtqueue's call/kick/err fds, removing the kick fd's watch
 *   before closing it;
 * - frees any pending resubmit list;
 * - unmaps and closes the inflight region;
 * - closes the log, the slave channel and the front-end socket;
 * - frees the virtqueue array.
 *
 * NOTE(review): dev->sock is closed but not reset to -1, and slave_mutex is
 * destroyed unconditionally, so this must not be called twice on one device.
 */
void
vu_deinit(VuDev *dev)
{
    int i;

    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];
        void *m = (void *) (uintptr_t) r->mmap_addr;
        if (m != MAP_FAILED) {
            /* Unmap the whole mapping: region size plus its mmap_offset. */
            munmap(m, r->size + r->mmap_offset);
        }
    }
    dev->nregions = 0;

    for (i = 0; i < dev->max_queues; i++) {
        VuVirtq *vq = &dev->vq[i];

        if (vq->call_fd != -1) {
            close(vq->call_fd);
            vq->call_fd = -1;
        }

        if (vq->kick_fd != -1) {
            /* Stop watching the fd before closing it. */
            dev->remove_watch(dev, vq->kick_fd);
            close(vq->kick_fd);
            vq->kick_fd = -1;
        }

        if (vq->err_fd != -1) {
            close(vq->err_fd);
            vq->err_fd = -1;
        }

        if (vq->resubmit_list) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        /*
         * Not freed here; presumably points into the inflight region that is
         * unmapped below — confirm.
         */
        vq->inflight = NULL;
    }

    if (dev->inflight_info.addr) {
        munmap(dev->inflight_info.addr, dev->inflight_info.size);
        dev->inflight_info.addr = NULL;
    }

    /*
     * "> 0" rather than ">= 0": vu_init() memsets the device to zero and
     * does not set this fd to -1, so 0 means no inflight fd was installed.
     */
    if (dev->inflight_info.fd > 0) {
        close(dev->inflight_info.fd);
        dev->inflight_info.fd = -1;
    }

    vu_close_log(dev);
    if (dev->slave_fd != -1) {
        close(dev->slave_fd);
        dev->slave_fd = -1;
    }
    pthread_mutex_destroy(&dev->slave_mutex);

    if (dev->sock != -1) {
        close(dev->sock);
    }

    free(dev->vq);
    dev->vq = NULL;
}
2041
2042bool
2043vu_init(VuDev *dev,
2044        uint16_t max_queues,
2045        int socket,
2046        vu_panic_cb panic,
2047        vu_read_msg_cb read_msg,
2048        vu_set_watch_cb set_watch,
2049        vu_remove_watch_cb remove_watch,
2050        const VuDevIface *iface)
2051{
2052    uint16_t i;
2053
2054    assert(max_queues > 0);
2055    assert(socket >= 0);
2056    assert(set_watch);
2057    assert(remove_watch);
2058    assert(iface);
2059    assert(panic);
2060
2061    memset(dev, 0, sizeof(*dev));
2062
2063    dev->sock = socket;
2064    dev->panic = panic;
2065    dev->read_msg = read_msg ? read_msg : vu_message_read_default;
2066    dev->set_watch = set_watch;
2067    dev->remove_watch = remove_watch;
2068    dev->iface = iface;
2069    dev->log_call_fd = -1;
2070    pthread_mutex_init(&dev->slave_mutex, NULL);
2071    dev->slave_fd = -1;
2072    dev->max_queues = max_queues;
2073
2074    dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
2075    if (!dev->vq) {
2076        DPRINT("%s: failed to malloc virtqueues\n", __func__);
2077        return false;
2078    }
2079
2080    for (i = 0; i < max_queues; i++) {
2081        dev->vq[i] = (VuVirtq) {
2082            .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2083            .notification = true,
2084        };
2085    }
2086
2087    return true;
2088}
2089
2090VuVirtq *
2091vu_get_queue(VuDev *dev, int qidx)
2092{
2093    assert(qidx < dev->max_queues);
2094    return &dev->vq[qidx];
2095}
2096
2097bool
2098vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2099{
2100    return vq->enable;
2101}
2102
2103bool
2104vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2105{
2106    return vq->started;
2107}
2108
2109static inline uint16_t
2110vring_avail_flags(VuVirtq *vq)
2111{
2112    return le16toh(vq->vring.avail->flags);
2113}
2114
2115static inline uint16_t
2116vring_avail_idx(VuVirtq *vq)
2117{
2118    vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
2119
2120    return vq->shadow_avail_idx;
2121}
2122
2123static inline uint16_t
2124vring_avail_ring(VuVirtq *vq, int i)
2125{
2126    return le16toh(vq->vring.avail->ring[i]);
2127}
2128
2129static inline uint16_t
2130vring_get_used_event(VuVirtq *vq)
2131{
2132    return vring_avail_ring(vq, vq->vring.num);
2133}
2134
2135static int
2136virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2137{
2138    uint16_t num_heads = vring_avail_idx(vq) - idx;
2139
2140    /* Check it isn't doing very strange things with descriptor numbers. */
2141    if (num_heads > vq->vring.num) {
2142        vu_panic(dev, "Guest moved used index from %u to %u",
2143                 idx, vq->shadow_avail_idx);
2144        return -1;
2145    }
2146    if (num_heads) {
2147        /* On success, callers read a descriptor at vq->last_avail_idx.
2148         * Make sure descriptor read does not bypass avail index read. */
2149        smp_rmb();
2150    }
2151
2152    return num_heads;
2153}
2154
2155static bool
2156virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2157                   unsigned int idx, unsigned int *head)
2158{
2159    /* Grab the next descriptor number they're advertising, and increment
2160     * the index we've seen. */
2161    *head = vring_avail_ring(vq, idx % vq->vring.num);
2162
2163    /* If their number is silly, that's a fatal mistake. */
2164    if (*head >= vq->vring.num) {
2165        vu_panic(dev, "Guest says index %u is available", *head);
2166        return false;
2167    }
2168
2169    return true;
2170}
2171
2172static int
2173virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2174                             uint64_t addr, size_t len)
2175{
2176    struct vring_desc *ori_desc;
2177    uint64_t read_len;
2178
2179    if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2180        return -1;
2181    }
2182
2183    if (len == 0) {
2184        return -1;
2185    }
2186
2187    while (len) {
2188        read_len = len;
2189        ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2190        if (!ori_desc) {
2191            return -1;
2192        }
2193
2194        memcpy(desc, ori_desc, read_len);
2195        len -= read_len;
2196        addr += read_len;
2197        desc += read_len;
2198    }
2199
2200    return 0;
2201}
2202
/* Result codes of virtqueue_read_next_desc(). */
enum {
    VIRTQUEUE_READ_DESC_ERROR = -1, /* bad next link; vu_panic() called */
    VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
    VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
};
2208
/*
 * Follow the VRING_DESC_F_NEXT link of descriptor @i in table @desc.
 *
 * Returns one of:
 * - VIRTQUEUE_READ_DESC_MORE: *@next holds the next descriptor index,
 *   which is below @max;
 * - VIRTQUEUE_READ_DESC_DONE: descriptor @i ends the chain;
 * - VIRTQUEUE_READ_DESC_ERROR: the link points at or beyond @max
 *   (vu_panic() has been called).
 */
static int
virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
                         int i, unsigned int max, unsigned int *next)
{
    /* If this descriptor says it doesn't chain, we're done. */
    if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
        return VIRTQUEUE_READ_DESC_DONE;
    }

    /* Check they're not leading us off end of descriptors. */
    *next = le16toh(desc[i].next);
    /* Make sure compiler knows to grab that: we don't want it changing! */
    /*
     * NOTE(review): the intent is that the guest-writable next field is
     * read once and the validated copy is used; confirm smp_wmb() provides
     * the compiler barrier relied on here.
     */
    smp_wmb();

    if (*next >= max) {
        vu_panic(dev, "Desc next is %u", *next);
        return VIRTQUEUE_READ_DESC_ERROR;
    }

    return VIRTQUEUE_READ_DESC_MORE;
}
2230
/*
 * Sum the bytes in the descriptor chains the guest has made available,
 * starting at vq->last_avail_idx.  Device-writable ("in") and
 * device-readable ("out") descriptors are counted separately, and indirect
 * tables are followed.
 *
 * The walk stops early once in_total >= @max_in_bytes and
 * out_total >= @max_out_bytes.  Nothing is consumed: a local copy of
 * last_avail_idx is advanced.  The only side effect on @vq is the refresh
 * of shadow_avail_idx done by virtqueue_num_heads().
 *
 * Both totals are reported as 0 in these cases:
 * - a malformed ring (bad head, bad indirect table, descriptor loop, bad
 *   next link); every such path calls vu_panic();
 * - a broken device;
 * - a queue without an avail ring.
 *
 * @in_bytes and @out_bytes may be NULL.
 */
void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
                         unsigned int *out_bytes,
                         unsigned max_in_bytes, unsigned max_out_bytes)
{
    unsigned int idx;
    unsigned int total_bufs, in_total, out_total;
    int rc;

    idx = vq->last_avail_idx;

    total_bufs = in_total = out_total = 0;
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        goto done;
    }

    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
        unsigned int max, desc_len, num_bufs, indirect = 0;
        uint64_t desc_addr, read_len;
        struct vring_desc *desc;
        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
        unsigned int i;

        max = vq->vring.num;
        /*
         * Direct descriptors of all chains share one loop-detection budget
         * of vring.num; total_bufs carries the count across chains.
         */
        num_bufs = total_bufs;
        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
            goto err;
        }
        desc = vq->vring.desc;

        if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
            if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
                vu_panic(dev, "Invalid size for indirect buffer table");
                goto err;
            }

            /* If we've got too many, that implies a descriptor loop. */
            if (num_bufs >= max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            /* loop over the indirect descriptor table */
            indirect = 1;
            desc_addr = le64toh(desc[i].addr);
            desc_len = le32toh(desc[i].len);
            max = desc_len / sizeof(struct vring_desc);
            read_len = desc_len;
            /* Use the table in place if it maps contiguously, else copy. */
            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
            if (unlikely(desc && read_len != desc_len)) {
                /* Failed to use zero copy */
                desc = NULL;
                if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                                  desc_addr,
                                                  desc_len)) {
                    desc = desc_buf;
                }
            }
            if (!desc) {
                vu_panic(dev, "Invalid indirect buffer table");
                goto err;
            }
            /* Walk the indirect table from its first entry, fresh budget. */
            num_bufs = i = 0;
        }

        do {
            /* If we've got too many, that implies a descriptor loop. */
            if (++num_bufs > max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
                in_total += le32toh(desc[i].len);
            } else {
                out_total += le32toh(desc[i].len);
            }
            /* Both requested limits reached: no need to look further. */
            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
                goto done;
            }
            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
        } while (rc == VIRTQUEUE_READ_DESC_MORE);

        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
            goto err;
        }

        /* An indirect chain counts as a single ring descriptor. */
        if (!indirect) {
            total_bufs = num_bufs;
        } else {
            total_bufs++;
        }
    }
    /* Negative rc: virtqueue_num_heads() found a bogus avail index. */
    if (rc < 0) {
        goto err;
    }
done:
    if (in_bytes) {
        *in_bytes = in_total;
    }
    if (out_bytes) {
        *out_bytes = out_total;
    }
    return;

err:
    in_total = out_total = 0;
    goto done;
}
2341
2342bool
2343vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2344                     unsigned int out_bytes)
2345{
2346    unsigned int in_total, out_total;
2347
2348    vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2349                             in_bytes, out_bytes);
2350
2351    return in_bytes <= in_total && out_bytes <= out_total;
2352}
2353
2354/* Fetch avail_idx from VQ memory only when we really need to know if
2355 * guest has added some buffers. */
2356bool
2357vu_queue_empty(VuDev *dev, VuVirtq *vq)
2358{
2359    if (unlikely(dev->broken) ||
2360        unlikely(!vq->vring.avail)) {
2361        return true;
2362    }
2363
2364    if (vq->shadow_avail_idx != vq->last_avail_idx) {
2365        return false;
2366    }
2367
2368    return vring_avail_idx(vq) == vq->last_avail_idx;
2369}
2370
/*
 * Decide whether the guest must be notified about newly used buffers.
 *
 * The decision follows these rules, in order:
 * - Always true when VIRTIO_F_NOTIFY_ON_EMPTY is negotiated, nothing is in
 *   use and the queue is empty.
 * - Without VIRTIO_RING_F_EVENT_IDX, the guest's VRING_AVAIL_F_NO_INTERRUPT
 *   flag decides.
 * - With event index, vring_need_event() is consulted with the guest's
 *   used_event, the current used index and the previously signalled index.
 *   The first call after signalled_used_valid was cleared always notifies.
 *
 * Side effect in the event-index case: vq->used_idx is recorded as the
 * last signalled index.
 */
static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t old, new;
    bool v;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when queue is empty (when the feature is negotiated) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    v = vq->signalled_used_valid;
    vq->signalled_used_valid = true;
    old = vq->signalled_used;
    new = vq->signalled_used = vq->used_idx;
    return !v || vring_need_event(vring_get_used_event(vq), new, old);
}
2396
/*
 * Notify the guest that @vq has new used buffers, if vring_notify() says a
 * notification is wanted.
 *
 * Normally this writes the queue's call eventfd.  A
 * VHOST_USER_SLAVE_VRING_CALL message is sent on dev->slave_fd instead when
 * all of these hold:
 * - no call fd is set;
 * - in-band notifications are negotiated;
 * - the slave request channel is negotiated.
 * With @sync and REPLY_ACK negotiated, the function then blocks reading the
 * front-end's reply.
 *
 * Does nothing for a broken device or a queue with no avail ring.
 */
static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
{
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    if (!vring_notify(dev, vq)) {
        DPRINT("skipped notify...\n");
        return;
    }

    if (vq->call_fd < 0 &&
        vu_has_protocol_feature(dev,
                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
        vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
        VhostUserMsg vmsg = {
            .request = VHOST_USER_SLAVE_VRING_CALL,
            .flags = VHOST_USER_VERSION,
            .size = sizeof(vmsg.payload.state),
            .payload.state = {
                /* Queue number = position of vq in dev->vq. */
                .index = vq - dev->vq,
            },
        };
        bool ack = sync &&
                   vu_has_protocol_feature(dev,
                                           VHOST_USER_PROTOCOL_F_REPLY_ACK);

        if (ack) {
            vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
        }

        /*
         * NOTE(review): the write/read results and the reply contents are
         * not checked.  The exchange is also not done under
         * dev->slave_mutex; confirm no other slave-channel request can
         * interleave with it.
         */
        vu_message_write(dev, dev->slave_fd, &vmsg);
        if (ack) {
            vu_message_read_default(dev, dev->slave_fd, &vmsg);
        }
        return;
    }

    /* With call_fd unset (-1) this write fails and the device panics. */
    if (eventfd_write(vq->call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}
2440
2441void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2442{
2443    _vu_queue_notify(dev, vq, false);
2444}
2445
2446void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2447{
2448    _vu_queue_notify(dev, vq, true);
2449}
2450
2451static inline void
2452vring_used_flags_set_bit(VuVirtq *vq, int mask)
2453{
2454    uint16_t *flags;
2455
2456    flags = (uint16_t *)((char*)vq->vring.used +
2457                         offsetof(struct vring_used, flags));
2458    *flags = htole16(le16toh(*flags) | mask);
2459}
2460
2461static inline void
2462vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2463{
2464    uint16_t *flags;
2465
2466    flags = (uint16_t *)((char*)vq->vring.used +
2467                         offsetof(struct vring_used, flags));
2468    *flags = htole16(le16toh(*flags) & ~mask);
2469}
2470
2471static inline void
2472vring_set_avail_event(VuVirtq *vq, uint16_t val)
2473{
2474    uint16_t *avail;
2475
2476    if (!vq->notification) {
2477        return;
2478    }
2479
2480    avail = (uint16_t *)&vq->vring.used->ring[vq->vring.num];
2481    *avail = htole16(val);
2482}
2483
/*
 * Enable (@enable != 0) or disable notifications from the guest for @vq.
 *
 * With VIRTIO_RING_F_EVENT_IDX, the current avail index is published as
 * avail_event; vring_set_avail_event() skips the write when disabling.
 * Without it, VRING_USED_F_NO_NOTIFY in the used flags is cleared or set.
 */
void
vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
{
    vq->notification = enable;
    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vring_avail_idx(vq));
    } else if (enable) {
        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
    } else {
        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
    }
    if (enable) {
        /* Expose avail event/used flags before caller checks the avail idx. */
        smp_mb();
    }
}
2500
2501static bool
2502virtqueue_map_desc(VuDev *dev,
2503                   unsigned int *p_num_sg, struct iovec *iov,
2504                   unsigned int max_num_sg, bool is_write,
2505                   uint64_t pa, size_t sz)
2506{
2507    unsigned num_sg = *p_num_sg;
2508
2509    assert(num_sg <= max_num_sg);
2510
2511    if (!sz) {
2512        vu_panic(dev, "virtio: zero sized buffers are not allowed");
2513        return false;
2514    }
2515
2516    while (sz) {
2517        uint64_t len = sz;
2518
2519        if (num_sg == max_num_sg) {
2520            vu_panic(dev, "virtio: too many descriptors in indirect table");
2521            return false;
2522        }
2523
2524        iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2525        if (iov[num_sg].iov_base == NULL) {
2526            vu_panic(dev, "virtio: invalid address for buffers");
2527            return false;
2528        }
2529        iov[num_sg].iov_len = len;
2530        num_sg++;
2531        sz -= len;
2532        pa += len;
2533    }
2534
2535    *p_num_sg = num_sg;
2536    return true;
2537}
2538
2539static void *
2540virtqueue_alloc_element(size_t sz,
2541                                     unsigned out_num, unsigned in_num)
2542{
2543    VuVirtqElement *elem;
2544    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2545    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2546    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2547
2548    assert(sz >= sizeof(VuVirtqElement));
2549    elem = malloc(out_sg_end);
2550    if (!elem) {
2551        DPRINT("%s: failed to malloc virtqueue element\n", __func__);
2552        return NULL;
2553    }
2554    elem->out_num = out_num;
2555    elem->in_num = in_num;
2556    elem->in_sg = (void *)elem + in_sg_ofs;
2557    elem->out_sg = (void *)elem + out_sg_ofs;
2558    return elem;
2559}
2560
/*
 * Build a VuVirtqElement, allocated with size @sz by
 * virtqueue_alloc_element(), for the descriptor chain whose head is @idx.
 * An indirect table is followed if the head has VRING_DESC_F_INDIRECT.
 *
 * Device-readable descriptors go to out_sg and device-writable ones to
 * in_sg.  A readable descriptor after a writable one is rejected.  A
 * descriptor may occupy several iovec entries if vu_gpa_to_va() cannot
 * map it in one piece.
 *
 * Returns NULL on a malformed chain (after vu_panic()) or on allocation
 * failure.
 */
static void *
vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
{
    struct vring_desc *desc = vq->vring.desc;
    uint64_t desc_addr, read_len;
    unsigned int desc_len;
    unsigned int max = vq->vring.num;
    unsigned int i = idx;
    VuVirtqElement *elem;
    unsigned int out_num = 0, in_num = 0;
    struct iovec iov[VIRTQUEUE_MAX_SIZE];
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    int rc;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return NULL;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        /* Use the table in place if it maps contiguously, else copy it. */
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return NULL;
        }
        i = 0;
    }

    /* Collect all the descriptors */
    /* Layout in iov: out entries at [0, out_num), in entries after them. */
    do {
        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
                               VIRTQUEUE_MAX_SIZE - out_num, true,
                               le64toh(desc[i].addr),
                               le32toh(desc[i].len))) {
                return NULL;
            }
        } else {
            if (in_num) {
                vu_panic(dev, "Incorrect order for descriptors");
                return NULL;
            }
            if (!virtqueue_map_desc(dev, &out_num, iov,
                               VIRTQUEUE_MAX_SIZE, false,
                               le64toh(desc[i].addr),
                               le32toh(desc[i].len))) {
                return NULL;
            }
        }

        /* If we've got too many, that implies a descriptor loop. */
        /*
         * NOTE(review): this counts iovec entries, not descriptors, so a
         * descriptor split across mappings also counts toward max.
         */
        if ((in_num + out_num) > max) {
            vu_panic(dev, "Looped descriptor");
            return NULL;
        }
        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
    } while (rc == VIRTQUEUE_READ_DESC_MORE);

    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
        vu_panic(dev, "read descriptor error");
        return NULL;
    }

    /* Now copy what we have collected and mapped */
    elem = virtqueue_alloc_element(sz, out_num, in_num);
    if (!elem) {
        return NULL;
    }
    elem->index = idx;
    for (i = 0; i < out_num; i++) {
        elem->out_sg[i] = iov[i];
    }
    for (i = 0; i < in_num; i++) {
        elem->in_sg[i] = iov[out_num + i];
    }

    return elem;
}
2653
2654static int
2655vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2656{
2657    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2658        return 0;
2659    }
2660
2661    if (unlikely(!vq->inflight)) {
2662        return -1;
2663    }
2664
2665    vq->inflight->desc[desc_idx].counter = vq->counter++;
2666    vq->inflight->desc[desc_idx].inflight = 1;
2667
2668    return 0;
2669}
2670
2671static int
2672vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2673{
2674    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2675        return 0;
2676    }
2677
2678    if (unlikely(!vq->inflight)) {
2679        return -1;
2680    }
2681
2682    vq->inflight->last_batch_head = desc_idx;
2683
2684    return 0;
2685}
2686
/*
 * Inflight tracking after the used index has been published.  It clears the
 * inflight mark of @desc_idx, then records vq->used_idx in the inflight
 * region.
 *
 * Does nothing and returns 0 unless VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD
 * is negotiated.  Returns -1 if it is negotiated but @vq has no inflight
 * region.
 */
static int
vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    /* barrier() keeps these two stores in program order. */
    barrier();

    vq->inflight->desc[desc_idx].inflight = 0;

    barrier();

    vq->inflight->used_idx = vq->used_idx;

    return 0;
}
2708
/*
 * Take the next element from @vq.  It is returned as a VuVirtqElement,
 * allocated with size @sz (see vu_queue_map_desc()).
 *
 * Returns NULL in any of these cases:
 * - the device is broken;
 * - the queue has no avail ring;
 * - the queue is empty;
 * - the ring is malformed.
 *
 * Entries on vq->resubmit_list are returned first, taken from the end of
 * the list, and the list is freed once drained.  Otherwise the next avail
 * head is consumed.  avail_event is then advanced when
 * VIRTIO_RING_F_EVENT_IDX is negotiated, and the head is marked inflight.
 */
void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
    int i;
    unsigned int head;
    VuVirtqElement *elem;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return NULL;
    }

    /*
     * NOTE(review): vq->inuse is not incremented on this path; presumably
     * resubmitted entries were already counted when the list was built.
     * Confirm this.
     */
    if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
        i = (--vq->resubmit_num);
        elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);

        if (!vq->resubmit_num) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        return elem;
    }

    if (vu_queue_empty(dev, vq)) {
        return NULL;
    }
    /*
     * Needed after vu_queue_empty(), see comment in
     * virtqueue_num_heads().
     */
    smp_rmb();

    if (vq->inuse >= vq->vring.num) {
        vu_panic(dev, "Virtqueue size exceeded");
        return NULL;
    }

    /* Note: last_avail_idx stays advanced even if mapping fails below. */
    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
        return NULL;
    }

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vq->last_avail_idx);
    }

    elem = vu_queue_map_desc(dev, vq, head, sz);

    if (!elem) {
        return NULL;
    }

    vq->inuse++;

    vu_queue_inflight_get(dev, vq, head);

    return elem;
}
2767
2768static void
2769vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2770                        size_t len)
2771{
2772    vq->inuse--;
2773    /* unmap, when DMA support is added */
2774}
2775
2776void
2777vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2778               size_t len)
2779{
2780    vq->last_avail_idx--;
2781    vu_queue_detach_element(dev, vq, elem, len);
2782}
2783
2784bool
2785vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2786{
2787    if (num > vq->inuse) {
2788        return false;
2789    }
2790    vq->last_avail_idx -= num;
2791    vq->inuse -= num;
2792    return true;
2793}
2794
2795static inline
2796void vring_used_write(VuDev *dev, VuVirtq *vq,
2797                      struct vring_used_elem *uelem, int i)
2798{
2799    struct vring_used *used = vq->vring.used;
2800
2801    used->ring[i] = *uelem;
2802    vu_log_write(dev, vq->vring.log_guest_addr +
2803                 offsetof(struct vring_used, ring[i]),
2804                 sizeof(used->ring[i]));
2805}
2806
2807
/*
 * Report to vu_log_write() the guest memory the device wrote for @elem.
 * It walks the element's descriptor chain, following an indirect table if
 * present, and logs up to @len bytes across the device-writable
 * descriptors, in chain order.  A malformed chain calls vu_panic() and
 * stops the walk.
 *
 * NOTE(review): the walk runs on every fill; presumably vu_log_write()
 * returns early when logging is off.  Confirm before optimizing.
 */
static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min, desc_len;
    uint64_t desc_addr, read_len;
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        /* Use the table in place if it maps contiguously, else copy it. */
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return;
        }
        i = 0;
    }

    do {
        /* More descriptors than the table holds implies a loop. */
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            /* Log only the part of this buffer covered by the remaining len. */
            min = MIN(le32toh(desc[i].len), len);
            vu_log_write(dev, le64toh(desc[i].addr), min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}
2866
2867void
2868vu_queue_fill(VuDev *dev, VuVirtq *vq,
2869              const VuVirtqElement *elem,
2870              unsigned int len, unsigned int idx)
2871{
2872    struct vring_used_elem uelem;
2873
2874    if (unlikely(dev->broken) ||
2875        unlikely(!vq->vring.avail)) {
2876        return;
2877    }
2878
2879    vu_log_queue_fill(dev, vq, elem, len);
2880
2881    idx = (idx + vq->used_idx) % vq->vring.num;
2882
2883    uelem.id = htole32(elem->index);
2884    uelem.len = htole32(len);
2885    vring_used_write(dev, vq, &uelem, idx);
2886}
2887
2888static inline
2889void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2890{
2891    vq->vring.used->idx = htole16(val);
2892    vu_log_write(dev,
2893                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2894                 sizeof(vq->vring.used->idx));
2895
2896    vq->used_idx = val;
2897}
2898
/*
 * Publish @count used entries previously written by vu_queue_fill().  The
 * used index is advanced by @count and vq->inuse is reduced by @count.
 *
 * Sometimes vq->signalled_used is no longer safely behind the new index:
 * the signed 16-bit distance new - signalled_used is smaller than @count.
 * In that case signalled_used_valid is cleared, so the next vring_notify()
 * notifies unconditionally.
 */
void
vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
{
    uint16_t old, new;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    /* Make sure buffer is written before we update index. */
    smp_wmb();

    old = vq->used_idx;
    new = old + count;
    vring_used_idx_set(dev, vq, new);
    vq->inuse -= count;
    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
        vq->signalled_used_valid = false;
    }
}
2920
2921void
2922vu_queue_push(VuDev *dev, VuVirtq *vq,
2923              const VuVirtqElement *elem, unsigned int len)
2924{
2925    vu_queue_fill(dev, vq, elem, len, 0);
2926    vu_queue_inflight_pre_put(dev, vq, elem->index);
2927    vu_queue_flush(dev, vq, 1);
2928    vu_queue_inflight_post_put(dev, vq, elem->index);
2929}
2930