qemu/contrib/libvhost-user/libvhost-user.c
   1/*
   2 * Vhost User library
   3 *
   4 * Copyright IBM, Corp. 2007
   5 * Copyright (c) 2016 Red Hat, Inc.
   6 *
   7 * Authors:
   8 *  Anthony Liguori <aliguori@us.ibm.com>
   9 *  Marc-André Lureau <mlureau@redhat.com>
  10 *  Victor Kaplansky <victork@redhat.com>
  11 *
  12 * This work is licensed under the terms of the GNU GPL, version 2 or
  13 * later.  See the COPYING file in the top-level directory.
  14 */
  15
  16/* this code avoids GLib dependency */
  17#include <stdlib.h>
  18#include <stdio.h>
  19#include <unistd.h>
  20#include <stdarg.h>
  21#include <errno.h>
  22#include <string.h>
  23#include <assert.h>
  24#include <inttypes.h>
  25#include <sys/types.h>
  26#include <sys/socket.h>
  27#include <sys/eventfd.h>
  28#include <sys/mman.h>
  29#include <linux/vhost.h>
  30
  31#include "qemu/compiler.h"
  32#include "qemu/atomic.h"
  33
  34#include "libvhost-user.h"
  35
  36/* usually provided by GLib */
  37#ifndef MIN
  38#define MIN(x, y) ({                            \
  39            typeof(x) _min1 = (x);              \
  40            typeof(y) _min2 = (y);              \
  41            (void) (&_min1 == &_min2);          \
  42            _min1 < _min2 ? _min1 : _min2; })
  43#endif
  44
  45#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
  46
  47/* The version of the protocol we support */
  48#define VHOST_USER_VERSION 1
  49#define LIBVHOST_USER_DEBUG 0
  50
  51#define DPRINT(...)                             \
  52    do {                                        \
  53        if (LIBVHOST_USER_DEBUG) {              \
  54            fprintf(stderr, __VA_ARGS__);        \
  55        }                                       \
  56    } while (0)
  57
  58static const char *
  59vu_request_to_string(unsigned int req)
  60{
  61#define REQ(req) [req] = #req
  62    static const char *vu_request_str[] = {
  63        REQ(VHOST_USER_NONE),
  64        REQ(VHOST_USER_GET_FEATURES),
  65        REQ(VHOST_USER_SET_FEATURES),
  66        REQ(VHOST_USER_SET_OWNER),
  67        REQ(VHOST_USER_RESET_OWNER),
  68        REQ(VHOST_USER_SET_MEM_TABLE),
  69        REQ(VHOST_USER_SET_LOG_BASE),
  70        REQ(VHOST_USER_SET_LOG_FD),
  71        REQ(VHOST_USER_SET_VRING_NUM),
  72        REQ(VHOST_USER_SET_VRING_ADDR),
  73        REQ(VHOST_USER_SET_VRING_BASE),
  74        REQ(VHOST_USER_GET_VRING_BASE),
  75        REQ(VHOST_USER_SET_VRING_KICK),
  76        REQ(VHOST_USER_SET_VRING_CALL),
  77        REQ(VHOST_USER_SET_VRING_ERR),
  78        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
  79        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
  80        REQ(VHOST_USER_GET_QUEUE_NUM),
  81        REQ(VHOST_USER_SET_VRING_ENABLE),
  82        REQ(VHOST_USER_SEND_RARP),
  83        REQ(VHOST_USER_NET_SET_MTU),
  84        REQ(VHOST_USER_SET_SLAVE_REQ_FD),
  85        REQ(VHOST_USER_IOTLB_MSG),
  86        REQ(VHOST_USER_SET_VRING_ENDIAN),
  87        REQ(VHOST_USER_MAX),
  88    };
  89#undef REQ
  90
  91    if (req < VHOST_USER_MAX) {
  92        return vu_request_str[req];
  93    } else {
  94        return "unknown";
  95    }
  96}
  97
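/*
 * Report a fatal error: format the message, mark the device as broken so
 * subsequent operations bail out, and pass the text to the
 * application-supplied panic callback.
 */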
  98static void
  99vu_panic(VuDev *dev, const char *msg, ...)
 100{
 101    char *buf = NULL;
 102    va_list ap;
 103
 104    va_start(ap, msg);
 105    if (vasprintf(&buf, msg, ap) < 0) {
 106        buf = NULL;
 107    }
 108    va_end(ap);
 109
 110    dev->broken = true;
 111    dev->panic(dev, buf);
 112    free(buf);
 113
 114    /* FIXME: find a way to call virtio_error? */
 115}
 116
 117/* Translate guest physical address to our virtual address.  */
 118void *
 119vu_gpa_to_va(VuDev *dev, uint64_t guest_addr)
 120{
 121    int i;
 122
 123    /* Find matching memory region.  */
 124    for (i = 0; i < dev->nregions; i++) {
 125        VuDevRegion *r = &dev->regions[i];
 126
 127        if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
 128            return (void *)(uintptr_t)
 129                guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
 130        }
 131    }
 132
 133    return NULL;
 134}
 135
 136/* Translate qemu virtual address to our virtual address.  */
 137static void *
 138qva_to_va(VuDev *dev, uint64_t qemu_addr)
 139{
 140    int i;
 141
 142    /* Find matching memory region.  */
 143    for (i = 0; i < dev->nregions; i++) {
 144        VuDevRegion *r = &dev->regions[i];
 145
 146        if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
 147            return (void *)(uintptr_t)
 148                qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
 149        }
 150    }
 151
 152    return NULL;
 153}
 154
 155static void
 156vmsg_close_fds(VhostUserMsg *vmsg)
 157{
 158    int i;
 159
 160    for (i = 0; i < vmsg->fd_num; i++) {
 161        close(vmsg->fds[i]);
 162    }
 163}
 164
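/*
 * Read one vhost-user message from conn_fd: first the fixed-size header,
 * collecting any file descriptors passed as SCM_RIGHTS ancillary data,
 * then the payload announced in the header.  Returns false (closing any
 * collected fds) if the message is malformed or a read fails.
 */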
 165static bool
 166vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 167{
 168    char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
 169    struct iovec iov = {
 170        .iov_base = (char *)vmsg,
 171        .iov_len = VHOST_USER_HDR_SIZE,
 172    };
 173    struct msghdr msg = {
 174        .msg_iov = &iov,
 175        .msg_iovlen = 1,
 176        .msg_control = control,
 177        .msg_controllen = sizeof(control),
 178    };
 179    size_t fd_size;
 180    struct cmsghdr *cmsg;
 181    int rc;
 182
 183    do {
 184        rc = recvmsg(conn_fd, &msg, 0);
 185    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 186
 187    if (rc < 0) {
 188        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
 189        return false;
 190    }
 191
 192    vmsg->fd_num = 0;
 193    for (cmsg = CMSG_FIRSTHDR(&msg);
 194         cmsg != NULL;
 195         cmsg = CMSG_NXTHDR(&msg, cmsg))
 196    {
 197        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
 198            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
 199            vmsg->fd_num = fd_size / sizeof(int);
 200            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
 201            break;
 202        }
 203    }
 204
 205    if (vmsg->size > sizeof(vmsg->payload)) {
 206        vu_panic(dev,
 207                 "Error: too big message request: %d, size: vmsg->size: %u, "
 208                 "while sizeof(vmsg->payload) = %zu\n",
 209                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
 210        goto fail;
 211    }
 212
 213    if (vmsg->size) {
 214        do {
 215            rc = read(conn_fd, &vmsg->payload, vmsg->size);
 216        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 217
 218        if (rc <= 0) {
 219            vu_panic(dev, "Error while reading: %s", strerror(errno));
 220            goto fail;
 221        }
 222
 223        assert(rc == vmsg->size);
 224    }
 225
 226    return true;
 227
 228fail:
 229    vmsg_close_fds(vmsg);
 230
 231    return false;
 232}
 233
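/*
 * Send a reply on conn_fd: stamp the protocol version and the REPLY flag
 * into the header, then write the header followed by the payload (either
 * the in-place payload or the externally allocated vmsg->data buffer).
 */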
 234static bool
 235vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
 236{
 237    int rc;
 238    uint8_t *p = (uint8_t *)vmsg;
 239
 240    /* Set the version in the flags when sending the reply */
 241    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
 242    vmsg->flags |= VHOST_USER_VERSION;
 243    vmsg->flags |= VHOST_USER_REPLY_MASK;
 244
 245    do {
 246        rc = write(conn_fd, p, VHOST_USER_HDR_SIZE);
 247    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 248
 249    do {
 250        if (vmsg->data) {
 251            rc = write(conn_fd, vmsg->data, vmsg->size);
 252        } else {
 253            rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
 254        }
 255    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
 256
 257    if (rc <= 0) {
 258        vu_panic(dev, "Error while writing: %s", strerror(errno));
 259        return false;
 260    }
 261
 262    return true;
 263}
 264
 265/* Kick the log_call_fd if required. */
 266static void
 267vu_log_kick(VuDev *dev)
 268{
 269    if (dev->log_call_fd != -1) {
 270        DPRINT("Kicking the QEMU's log...\n");
 271        if (eventfd_write(dev->log_call_fd, 1) < 0) {
 272            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
 273        }
 274    }
 275}
 276
 277static void
 278vu_log_page(uint8_t *log_table, uint64_t page)
 279{
 280    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
 281    atomic_or(&log_table[page / 8], 1 << (page % 8));
 282}
 283
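/*
 * Mark the range [address, address + length) dirty in the shared log.
 * This is a no-op unless VHOST_F_LOG_ALL has been negotiated and a log
 * table is mapped; otherwise one bit per touched VHOST_LOG_PAGE is set
 * and the log eventfd is kicked so QEMU picks up the update.
 */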
 284static void
 285vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
 286{
 287    uint64_t page;
 288
 289    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
 290        !dev->log_table || !length) {
 291        return;
 292    }
 293
 294    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
 295
 296    page = address / VHOST_LOG_PAGE;
 297    while (page * VHOST_LOG_PAGE < address + length) {
 298        vu_log_page(dev->log_table, page);
  299        page += 1;
 300    }
 301
 302    vu_log_kick(dev);
 303}
 304
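/*
 * Watch callback installed on a virtqueue's kick eventfd: drain the
 * eventfd and invoke the handler registered for that virtqueue.
 */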
 305static void
 306vu_kick_cb(VuDev *dev, int condition, void *data)
 307{
 308    int index = (intptr_t)data;
 309    VuVirtq *vq = &dev->vq[index];
 310    int sock = vq->kick_fd;
 311    eventfd_t kick_data;
 312    ssize_t rc;
 313
 314    rc = eventfd_read(sock, &kick_data);
 315    if (rc == -1) {
 316        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
 317        dev->remove_watch(dev, dev->vq[index].kick_fd);
 318    } else {
 319        DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
 320               kick_data, vq->handler, index);
 321        if (vq->handler) {
 322            vq->handler(dev, index);
 323        }
 324    }
 325}
 326
 327static bool
 328vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 329{
 330    vmsg->payload.u64 =
 331        1ULL << VHOST_F_LOG_ALL |
 332        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
 333
 334    if (dev->iface->get_features) {
 335        vmsg->payload.u64 |= dev->iface->get_features(dev);
 336    }
 337
 338    vmsg->size = sizeof(vmsg->payload.u64);
 339
 340    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 341
 342    return true;
 343}
 344
 345static void
 346vu_set_enable_all_rings(VuDev *dev, bool enabled)
 347{
 348    int i;
 349
 350    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
 351        dev->vq[i].enable = enabled;
 352    }
 353}
 354
 355static bool
 356vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 357{
 358    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 359
 360    dev->features = vmsg->payload.u64;
 361
  362    if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
 363        vu_set_enable_all_rings(dev, true);
 364    }
 365
 366    if (dev->iface->set_features) {
 367        dev->iface->set_features(dev, dev->features);
 368    }
 369
 370    return false;
 371}
 372
 373static bool
 374vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
 375{
 376    return false;
 377}
 378
 379static void
 380vu_close_log(VuDev *dev)
 381{
 382    if (dev->log_table) {
 383        if (munmap(dev->log_table, dev->log_size) != 0) {
 384            perror("close log munmap() error");
 385        }
 386
 387        dev->log_table = NULL;
 388    }
 389    if (dev->log_call_fd != -1) {
 390        close(dev->log_call_fd);
 391        dev->log_call_fd = -1;
 392    }
 393}
 394
 395static bool
 396vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
 397{
 398    vu_set_enable_all_rings(dev, false);
 399
 400    return false;
 401}
 402
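/*
 * Handle VHOST_USER_SET_MEM_TABLE: record each memory region announced
 * by the master and mmap() the file descriptor passed along with it, so
 * that guest physical and QEMU virtual addresses can be translated later
 * via vu_gpa_to_va() and qva_to_va().
 */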
 403static bool
 404vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
 405{
 406    int i;
 407    VhostUserMemory *memory = &vmsg->payload.memory;
 408    dev->nregions = memory->nregions;
 409
 410    DPRINT("Nregions: %d\n", memory->nregions);
 411    for (i = 0; i < dev->nregions; i++) {
 412        void *mmap_addr;
 413        VhostUserMemoryRegion *msg_region = &memory->regions[i];
 414        VuDevRegion *dev_region = &dev->regions[i];
 415
 416        DPRINT("Region %d\n", i);
 417        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
 418               msg_region->guest_phys_addr);
 419        DPRINT("    memory_size:     0x%016"PRIx64"\n",
 420               msg_region->memory_size);
 421        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
 422               msg_region->userspace_addr);
 423        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
 424               msg_region->mmap_offset);
 425
 426        dev_region->gpa = msg_region->guest_phys_addr;
 427        dev_region->size = msg_region->memory_size;
 428        dev_region->qva = msg_region->userspace_addr;
 429        dev_region->mmap_offset = msg_region->mmap_offset;
 430
  431        /* We don't use the offset argument of mmap() since the
 432         * mapped address has to be page aligned, and we use huge
 433         * pages.  */
 434        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
 435                         PROT_READ | PROT_WRITE, MAP_SHARED,
 436                         vmsg->fds[i], 0);
 437
 438        if (mmap_addr == MAP_FAILED) {
 439            vu_panic(dev, "region mmap error: %s", strerror(errno));
 440        } else {
 441            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
 442            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
 443                   dev_region->mmap_addr);
 444        }
 445
 446        close(vmsg->fds[i]);
 447    }
 448
 449    return false;
 450}
 451
 452static bool
 453vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 454{
 455    int fd;
 456    uint64_t log_mmap_size, log_mmap_offset;
 457    void *rc;
 458
 459    if (vmsg->fd_num != 1 ||
 460        vmsg->size != sizeof(vmsg->payload.log)) {
 461        vu_panic(dev, "Invalid log_base message");
 462        return true;
 463    }
 464
 465    fd = vmsg->fds[0];
 466    log_mmap_offset = vmsg->payload.log.mmap_offset;
 467    log_mmap_size = vmsg->payload.log.mmap_size;
 468    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
 469    DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
 470
 471    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
 472              log_mmap_offset);
 473    if (rc == MAP_FAILED) {
 474        perror("log mmap error");
 475    }
 476    dev->log_table = rc;
 477    dev->log_size = log_mmap_size;
 478
 479    vmsg->size = sizeof(vmsg->payload.u64);
 480
 481    return true;
 482}
 483
 484static bool
 485vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
 486{
 487    if (vmsg->fd_num != 1) {
 488        vu_panic(dev, "Invalid log_fd message");
 489        return false;
 490    }
 491
 492    if (dev->log_call_fd != -1) {
 493        close(dev->log_call_fd);
 494    }
 495    dev->log_call_fd = vmsg->fds[0];
 496    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
 497
 498    return false;
 499}
 500
 501static bool
 502vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
 503{
 504    unsigned int index = vmsg->payload.state.index;
 505    unsigned int num = vmsg->payload.state.num;
 506
 507    DPRINT("State.index: %d\n", index);
 508    DPRINT("State.num:   %d\n", num);
 509    dev->vq[index].vring.num = num;
 510
 511    return false;
 512}
 513
 514static bool
 515vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
 516{
 517    struct vhost_vring_addr *vra = &vmsg->payload.addr;
 518    unsigned int index = vra->index;
 519    VuVirtq *vq = &dev->vq[index];
 520
 521    DPRINT("vhost_vring_addr:\n");
 522    DPRINT("    index:  %d\n", vra->index);
 523    DPRINT("    flags:  %d\n", vra->flags);
 524    DPRINT("    desc_user_addr:   0x%016llx\n", vra->desc_user_addr);
 525    DPRINT("    used_user_addr:   0x%016llx\n", vra->used_user_addr);
 526    DPRINT("    avail_user_addr:  0x%016llx\n", vra->avail_user_addr);
 527    DPRINT("    log_guest_addr:   0x%016llx\n", vra->log_guest_addr);
 528
 529    vq->vring.flags = vra->flags;
 530    vq->vring.desc = qva_to_va(dev, vra->desc_user_addr);
 531    vq->vring.used = qva_to_va(dev, vra->used_user_addr);
 532    vq->vring.avail = qva_to_va(dev, vra->avail_user_addr);
 533    vq->vring.log_guest_addr = vra->log_guest_addr;
 534
 535    DPRINT("Setting virtq addresses:\n");
 536    DPRINT("    vring_desc  at %p\n", vq->vring.desc);
 537    DPRINT("    vring_used  at %p\n", vq->vring.used);
 538    DPRINT("    vring_avail at %p\n", vq->vring.avail);
 539
 540    if (!(vq->vring.desc && vq->vring.used && vq->vring.avail)) {
 541        vu_panic(dev, "Invalid vring_addr message");
 542        return false;
 543    }
 544
 545    vq->used_idx = vq->vring.used->idx;
 546
 547    if (vq->last_avail_idx != vq->used_idx) {
 548        bool resume = dev->iface->queue_is_processed_in_order &&
 549            dev->iface->queue_is_processed_in_order(dev, index);
 550
 551        DPRINT("Last avail index != used index: %u != %u%s\n",
 552               vq->last_avail_idx, vq->used_idx,
 553               resume ? ", resuming" : "");
 554
 555        if (resume) {
 556            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
 557        }
 558    }
 559
 560    return false;
 561}
 562
 563static bool
 564vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 565{
 566    unsigned int index = vmsg->payload.state.index;
 567    unsigned int num = vmsg->payload.state.num;
 568
 569    DPRINT("State.index: %d\n", index);
 570    DPRINT("State.num:   %d\n", num);
 571    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
 572
 573    return false;
 574}
 575
 576static bool
 577vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
 578{
 579    unsigned int index = vmsg->payload.state.index;
 580
 581    DPRINT("State.index: %d\n", index);
 582    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
 583    vmsg->size = sizeof(vmsg->payload.state);
 584
 585    dev->vq[index].started = false;
 586    if (dev->iface->queue_set_started) {
 587        dev->iface->queue_set_started(dev, index, false);
 588    }
 589
 590    if (dev->vq[index].call_fd != -1) {
 591        close(dev->vq[index].call_fd);
 592        dev->vq[index].call_fd = -1;
 593    }
 594    if (dev->vq[index].kick_fd != -1) {
 595        dev->remove_watch(dev, dev->vq[index].kick_fd);
 596        close(dev->vq[index].kick_fd);
 597        dev->vq[index].kick_fd = -1;
 598    }
 599
 600    return true;
 601}
 602
 603static bool
 604vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
 605{
 606    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 607
 608    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
 609        vmsg_close_fds(vmsg);
 610        vu_panic(dev, "Invalid queue index: %u", index);
 611        return false;
 612    }
 613
 614    if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK ||
 615        vmsg->fd_num != 1) {
 616        vmsg_close_fds(vmsg);
 617        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
 618        return false;
 619    }
 620
 621    return true;
 622}
 623
 624static bool
 625vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
 626{
 627    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 628
 629    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 630
 631    if (!vu_check_queue_msg_file(dev, vmsg)) {
 632        return false;
 633    }
 634
 635    if (dev->vq[index].kick_fd != -1) {
 636        dev->remove_watch(dev, dev->vq[index].kick_fd);
 637        close(dev->vq[index].kick_fd);
 638        dev->vq[index].kick_fd = -1;
 639    }
 640
 641    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
 642        dev->vq[index].kick_fd = vmsg->fds[0];
 643        DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
 644    }
 645
 646    dev->vq[index].started = true;
 647    if (dev->iface->queue_set_started) {
 648        dev->iface->queue_set_started(dev, index, true);
 649    }
 650
 651    if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
 652        dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
 653                       vu_kick_cb, (void *)(long)index);
 654
 655        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
 656               dev->vq[index].kick_fd, index);
 657    }
 658
 659    return false;
 660}
 661
 662void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
 663                          vu_queue_handler_cb handler)
 664{
 665    int qidx = vq - dev->vq;
 666
 667    vq->handler = handler;
 668    if (vq->kick_fd >= 0) {
 669        if (handler) {
 670            dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
 671                           vu_kick_cb, (void *)(long)qidx);
 672        } else {
 673            dev->remove_watch(dev, vq->kick_fd);
 674        }
 675    }
 676}
 677
 678static bool
 679vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
 680{
 681    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 682
 683    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 684
 685    if (!vu_check_queue_msg_file(dev, vmsg)) {
 686        return false;
 687    }
 688
 689    if (dev->vq[index].call_fd != -1) {
 690        close(dev->vq[index].call_fd);
 691        dev->vq[index].call_fd = -1;
 692    }
 693
 694    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
 695        dev->vq[index].call_fd = vmsg->fds[0];
 696    }
 697
 698    DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);
 699
 700    return false;
 701}
 702
 703static bool
 704vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
 705{
 706    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
 707
 708    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
 709
 710    if (!vu_check_queue_msg_file(dev, vmsg)) {
 711        return false;
 712    }
 713
 714    if (dev->vq[index].err_fd != -1) {
 715        close(dev->vq[index].err_fd);
 716        dev->vq[index].err_fd = -1;
 717    }
 718
 719    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
 720        dev->vq[index].err_fd = vmsg->fds[0];
 721    }
 722
 723    return false;
 724}
 725
 726static bool
 727vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 728{
 729    uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
 730                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
 731
 732    if (dev->iface->get_protocol_features) {
 733        features |= dev->iface->get_protocol_features(dev);
 734    }
 735
 736    vmsg->payload.u64 = features;
 737    vmsg->size = sizeof(vmsg->payload.u64);
 738
 739    return true;
 740}
 741
 742static bool
 743vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
 744{
 745    uint64_t features = vmsg->payload.u64;
 746
 747    DPRINT("u64: 0x%016"PRIx64"\n", features);
 748
 749    dev->protocol_features = vmsg->payload.u64;
 750
 751    if (dev->iface->set_protocol_features) {
 752        dev->iface->set_protocol_features(dev, features);
 753    }
 754
 755    return false;
 756}
 757
 758static bool
 759vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
 760{
 761    DPRINT("Function %s() not implemented yet.\n", __func__);
 762    return false;
 763}
 764
 765static bool
 766vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
 767{
 768    unsigned int index = vmsg->payload.state.index;
 769    unsigned int enable = vmsg->payload.state.num;
 770
 771    DPRINT("State.index: %d\n", index);
 772    DPRINT("State.enable:   %d\n", enable);
 773
 774    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
 775        vu_panic(dev, "Invalid vring_enable index: %u", index);
 776        return false;
 777    }
 778
 779    dev->vq[index].enable = enable;
 780    return false;
 781}
 782
 783static bool
 784vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
 785{
 786    if (vmsg->fd_num != 1) {
 787        vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
 788        return false;
 789    }
 790
 791    if (dev->slave_fd != -1) {
 792        close(dev->slave_fd);
 793    }
 794    dev->slave_fd = vmsg->fds[0];
 795    DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
 796
 797    return false;
 798}
 799
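/*
 * Dispatch one request, giving the application's process_msg callback
 * the first chance to handle it.  Returns true when a reply (already
 * prepared in vmsg) has to be sent back to the master.
 */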
 800static bool
 801vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
 802{
 803    int do_reply = 0;
 804
 805    /* Print out generic part of the request. */
 806    DPRINT("================ Vhost user message ================\n");
 807    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
 808           vmsg->request);
 809    DPRINT("Flags:   0x%x\n", vmsg->flags);
 810    DPRINT("Size:    %d\n", vmsg->size);
 811
 812    if (vmsg->fd_num) {
 813        int i;
 814        DPRINT("Fds:");
 815        for (i = 0; i < vmsg->fd_num; i++) {
 816            DPRINT(" %d", vmsg->fds[i]);
 817        }
 818        DPRINT("\n");
 819    }
 820
 821    if (dev->iface->process_msg &&
 822        dev->iface->process_msg(dev, vmsg, &do_reply)) {
 823        return do_reply;
 824    }
 825
 826    switch (vmsg->request) {
 827    case VHOST_USER_GET_FEATURES:
 828        return vu_get_features_exec(dev, vmsg);
 829    case VHOST_USER_SET_FEATURES:
 830        return vu_set_features_exec(dev, vmsg);
 831    case VHOST_USER_GET_PROTOCOL_FEATURES:
 832        return vu_get_protocol_features_exec(dev, vmsg);
 833    case VHOST_USER_SET_PROTOCOL_FEATURES:
 834        return vu_set_protocol_features_exec(dev, vmsg);
 835    case VHOST_USER_SET_OWNER:
 836        return vu_set_owner_exec(dev, vmsg);
 837    case VHOST_USER_RESET_OWNER:
 838        return vu_reset_device_exec(dev, vmsg);
 839    case VHOST_USER_SET_MEM_TABLE:
 840        return vu_set_mem_table_exec(dev, vmsg);
 841    case VHOST_USER_SET_LOG_BASE:
 842        return vu_set_log_base_exec(dev, vmsg);
 843    case VHOST_USER_SET_LOG_FD:
 844        return vu_set_log_fd_exec(dev, vmsg);
 845    case VHOST_USER_SET_VRING_NUM:
 846        return vu_set_vring_num_exec(dev, vmsg);
 847    case VHOST_USER_SET_VRING_ADDR:
 848        return vu_set_vring_addr_exec(dev, vmsg);
 849    case VHOST_USER_SET_VRING_BASE:
 850        return vu_set_vring_base_exec(dev, vmsg);
 851    case VHOST_USER_GET_VRING_BASE:
 852        return vu_get_vring_base_exec(dev, vmsg);
 853    case VHOST_USER_SET_VRING_KICK:
 854        return vu_set_vring_kick_exec(dev, vmsg);
 855    case VHOST_USER_SET_VRING_CALL:
 856        return vu_set_vring_call_exec(dev, vmsg);
 857    case VHOST_USER_SET_VRING_ERR:
 858        return vu_set_vring_err_exec(dev, vmsg);
 859    case VHOST_USER_GET_QUEUE_NUM:
 860        return vu_get_queue_num_exec(dev, vmsg);
 861    case VHOST_USER_SET_VRING_ENABLE:
 862        return vu_set_vring_enable_exec(dev, vmsg);
 863    case VHOST_USER_SET_SLAVE_REQ_FD:
 864        return vu_set_slave_req_fd(dev, vmsg);
 865    case VHOST_USER_NONE:
 866        break;
 867    default:
 868        vmsg_close_fds(vmsg);
 869        vu_panic(dev, "Unhandled request: %d", vmsg->request);
 870    }
 871
 872    return false;
 873}
 874
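/*
 * Read and process a single message from the vhost-user socket, sending
 * a reply when the request asks for one.  Returns false on failure; the
 * usual reaction is to tear the connection down with vu_deinit().
 */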
 875bool
 876vu_dispatch(VuDev *dev)
 877{
 878    VhostUserMsg vmsg = { 0, };
 879    int reply_requested;
 880    bool success = false;
 881
 882    if (!vu_message_read(dev, dev->sock, &vmsg)) {
 883        goto end;
 884    }
 885
 886    reply_requested = vu_process_message(dev, &vmsg);
 887    if (!reply_requested) {
 888        success = true;
 889        goto end;
 890    }
 891
 892    if (!vu_message_write(dev, dev->sock, &vmsg)) {
 893        goto end;
 894    }
 895
 896    success = true;
 897
 898end:
 899    free(vmsg.data);
 900    return success;
 901}
 902
 903void
 904vu_deinit(VuDev *dev)
 905{
 906    int i;
 907
 908    for (i = 0; i < dev->nregions; i++) {
 909        VuDevRegion *r = &dev->regions[i];
 910        void *m = (void *) (uintptr_t) r->mmap_addr;
 911        if (m != MAP_FAILED) {
 912            munmap(m, r->size + r->mmap_offset);
 913        }
 914    }
 915    dev->nregions = 0;
 916
 917    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
 918        VuVirtq *vq = &dev->vq[i];
 919
 920        if (vq->call_fd != -1) {
 921            close(vq->call_fd);
 922            vq->call_fd = -1;
 923        }
 924
 925        if (vq->kick_fd != -1) {
 926            close(vq->kick_fd);
 927            vq->kick_fd = -1;
 928        }
 929
 930        if (vq->err_fd != -1) {
 931            close(vq->err_fd);
 932            vq->err_fd = -1;
 933        }
 934    }
 935
 936
 937    vu_close_log(dev);
 938    if (dev->slave_fd != -1) {
 939        close(dev->slave_fd);
 940        dev->slave_fd = -1;
 941    }
 942
 943    if (dev->sock != -1) {
 944        close(dev->sock);
 945    }
 946}
 947
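/*
 * Initialize a VuDev for an already connected vhost-user socket.  A
 * minimal usage sketch (the callback and interface names below are
 * placeholders; they are application-specific):
 *
 *     VuDev dev;
 *
 *     vu_init(&dev, conn_fd, my_panic_cb, my_set_watch_cb,
 *             my_remove_watch_cb, &my_iface);
 *
 *     // whenever conn_fd becomes readable:
 *     if (!vu_dispatch(&dev)) {
 *         vu_deinit(&dev);    // e.g. give up on protocol errors
 *     }
 */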
 948void
 949vu_init(VuDev *dev,
 950        int socket,
 951        vu_panic_cb panic,
 952        vu_set_watch_cb set_watch,
 953        vu_remove_watch_cb remove_watch,
 954        const VuDevIface *iface)
 955{
 956    int i;
 957
 958    assert(socket >= 0);
 959    assert(set_watch);
 960    assert(remove_watch);
 961    assert(iface);
 962    assert(panic);
 963
 964    memset(dev, 0, sizeof(*dev));
 965
 966    dev->sock = socket;
 967    dev->panic = panic;
 968    dev->set_watch = set_watch;
 969    dev->remove_watch = remove_watch;
 970    dev->iface = iface;
 971    dev->log_call_fd = -1;
 972    dev->slave_fd = -1;
 973    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
 974        dev->vq[i] = (VuVirtq) {
 975            .call_fd = -1, .kick_fd = -1, .err_fd = -1,
 976            .notification = true,
 977        };
 978    }
 979}
 980
 981VuVirtq *
 982vu_get_queue(VuDev *dev, int qidx)
 983{
 984    assert(qidx < VHOST_MAX_NR_VIRTQUEUE);
 985    return &dev->vq[qidx];
 986}
 987
 988bool
 989vu_queue_enabled(VuDev *dev, VuVirtq *vq)
 990{
 991    return vq->enable;
 992}
 993
 994bool
 995vu_queue_started(const VuDev *dev, const VuVirtq *vq)
 996{
 997    return vq->started;
 998}
 999
1000static inline uint16_t
1001vring_avail_flags(VuVirtq *vq)
1002{
1003    return vq->vring.avail->flags;
1004}
1005
1006static inline uint16_t
1007vring_avail_idx(VuVirtq *vq)
1008{
1009    vq->shadow_avail_idx = vq->vring.avail->idx;
1010
1011    return vq->shadow_avail_idx;
1012}
1013
1014static inline uint16_t
1015vring_avail_ring(VuVirtq *vq, int i)
1016{
1017    return vq->vring.avail->ring[i];
1018}
1019
1020static inline uint16_t
1021vring_get_used_event(VuVirtq *vq)
1022{
1023    return vring_avail_ring(vq, vq->vring.num);
1024}
1025
1026static int
1027virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
1028{
1029    uint16_t num_heads = vring_avail_idx(vq) - idx;
1030
1031    /* Check it isn't doing very strange things with descriptor numbers. */
1032    if (num_heads > vq->vring.num) {
1033        vu_panic(dev, "Guest moved used index from %u to %u",
1034                 idx, vq->shadow_avail_idx);
1035        return -1;
1036    }
1037    if (num_heads) {
1038        /* On success, callers read a descriptor at vq->last_avail_idx.
1039         * Make sure descriptor read does not bypass avail index read. */
1040        smp_rmb();
1041    }
1042
1043    return num_heads;
1044}
1045
1046static bool
1047virtqueue_get_head(VuDev *dev, VuVirtq *vq,
1048                   unsigned int idx, unsigned int *head)
1049{
1050    /* Grab the next descriptor number they're advertising, and increment
1051     * the index we've seen. */
1052    *head = vring_avail_ring(vq, idx % vq->vring.num);
1053
1054    /* If their number is silly, that's a fatal mistake. */
1055    if (*head >= vq->vring.num) {
 1056        vu_panic(dev, "Guest says index %u is available", *head);
1057        return false;
1058    }
1059
1060    return true;
1061}
1062
1063enum {
1064    VIRTQUEUE_READ_DESC_ERROR = -1,
1065    VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
1066    VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
1067};
1068
1069static int
1070virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
1071                         int i, unsigned int max, unsigned int *next)
1072{
1073    /* If this descriptor says it doesn't chain, we're done. */
1074    if (!(desc[i].flags & VRING_DESC_F_NEXT)) {
1075        return VIRTQUEUE_READ_DESC_DONE;
1076    }
1077
1078    /* Check they're not leading us off end of descriptors. */
1079    *next = desc[i].next;
1080    /* Make sure compiler knows to grab that: we don't want it changing! */
1081    smp_wmb();
1082
1083    if (*next >= max) {
 1084        vu_panic(dev, "Desc next is %u", *next);
1085        return VIRTQUEUE_READ_DESC_ERROR;
1086    }
1087
1088    return VIRTQUEUE_READ_DESC_MORE;
1089}
1090
1091void
1092vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
1093                         unsigned int *out_bytes,
1094                         unsigned max_in_bytes, unsigned max_out_bytes)
1095{
1096    unsigned int idx;
1097    unsigned int total_bufs, in_total, out_total;
1098    int rc;
1099
1100    idx = vq->last_avail_idx;
1101
1102    total_bufs = in_total = out_total = 0;
1103    if (unlikely(dev->broken) ||
1104        unlikely(!vq->vring.avail)) {
1105        goto done;
1106    }
1107
1108    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
1109        unsigned int max, num_bufs, indirect = 0;
1110        struct vring_desc *desc;
1111        unsigned int i;
1112
1113        max = vq->vring.num;
1114        num_bufs = total_bufs;
1115        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
1116            goto err;
1117        }
1118        desc = vq->vring.desc;
1119
1120        if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1121            if (desc[i].len % sizeof(struct vring_desc)) {
1122                vu_panic(dev, "Invalid size for indirect buffer table");
1123                goto err;
1124            }
1125
1126            /* If we've got too many, that implies a descriptor loop. */
1127            if (num_bufs >= max) {
1128                vu_panic(dev, "Looped descriptor");
1129                goto err;
1130            }
1131
1132            /* loop over the indirect descriptor table */
1133            indirect = 1;
1134            max = desc[i].len / sizeof(struct vring_desc);
1135            desc = vu_gpa_to_va(dev, desc[i].addr);
1136            num_bufs = i = 0;
1137        }
1138
1139        do {
1140            /* If we've got too many, that implies a descriptor loop. */
1141            if (++num_bufs > max) {
1142                vu_panic(dev, "Looped descriptor");
1143                goto err;
1144            }
1145
1146            if (desc[i].flags & VRING_DESC_F_WRITE) {
1147                in_total += desc[i].len;
1148            } else {
1149                out_total += desc[i].len;
1150            }
1151            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
1152                goto done;
1153            }
1154            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
1155        } while (rc == VIRTQUEUE_READ_DESC_MORE);
1156
1157        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
1158            goto err;
1159        }
1160
1161        if (!indirect) {
1162            total_bufs = num_bufs;
1163        } else {
1164            total_bufs++;
1165        }
1166    }
1167    if (rc < 0) {
1168        goto err;
1169    }
1170done:
1171    if (in_bytes) {
1172        *in_bytes = in_total;
1173    }
1174    if (out_bytes) {
1175        *out_bytes = out_total;
1176    }
1177    return;
1178
1179err:
1180    in_total = out_total = 0;
1181    goto done;
1182}
1183
1184bool
1185vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
1186                     unsigned int out_bytes)
1187{
1188    unsigned int in_total, out_total;
1189
1190    vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
1191                             in_bytes, out_bytes);
1192
1193    return in_bytes <= in_total && out_bytes <= out_total;
1194}
1195
1196/* Fetch avail_idx from VQ memory only when we really need to know if
1197 * guest has added some buffers. */
1198bool
1199vu_queue_empty(VuDev *dev, VuVirtq *vq)
1200{
1201    if (unlikely(dev->broken) ||
1202        unlikely(!vq->vring.avail)) {
1203        return true;
1204    }
1205
1206    if (vq->shadow_avail_idx != vq->last_avail_idx) {
1207        return false;
1208    }
1209
1210    return vring_avail_idx(vq) == vq->last_avail_idx;
1211}
1212
1213static inline
1214bool has_feature(uint64_t features, unsigned int fbit)
1215{
1216    assert(fbit < 64);
1217    return !!(features & (1ULL << fbit));
1218}
1219
1220static inline
1221bool vu_has_feature(VuDev *dev,
1222                    unsigned int fbit)
1223{
1224    return has_feature(dev->features, fbit);
1225}
1226
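/*
 * Decide whether the guest needs to be signalled for this queue,
 * honouring VIRTIO_F_NOTIFY_ON_EMPTY, VRING_AVAIL_F_NO_INTERRUPT and,
 * when VIRTIO_RING_F_EVENT_IDX is negotiated, the used_event index.
 */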
1227static bool
1228vring_notify(VuDev *dev, VuVirtq *vq)
1229{
1230    uint16_t old, new;
1231    bool v;
1232
1233    /* We need to expose used array entries before checking used event. */
1234    smp_mb();
1235
 1236    /* Always notify when the queue is empty (if the feature was negotiated) */
1237    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
1238        !vq->inuse && vu_queue_empty(dev, vq)) {
1239        return true;
1240    }
1241
1242    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1243        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
1244    }
1245
1246    v = vq->signalled_used_valid;
1247    vq->signalled_used_valid = true;
1248    old = vq->signalled_used;
1249    new = vq->signalled_used = vq->used_idx;
1250    return !v || vring_need_event(vring_get_used_event(vq), new, old);
1251}
1252
1253void
1254vu_queue_notify(VuDev *dev, VuVirtq *vq)
1255{
1256    if (unlikely(dev->broken) ||
1257        unlikely(!vq->vring.avail)) {
1258        return;
1259    }
1260
1261    if (!vring_notify(dev, vq)) {
1262        DPRINT("skipped notify...\n");
1263        return;
1264    }
1265
1266    if (eventfd_write(vq->call_fd, 1) < 0) {
1267        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
1268    }
1269}
1270
1271static inline void
1272vring_used_flags_set_bit(VuVirtq *vq, int mask)
1273{
1274    uint16_t *flags;
1275
1276    flags = (uint16_t *)((char*)vq->vring.used +
1277                         offsetof(struct vring_used, flags));
1278    *flags |= mask;
1279}
1280
1281static inline void
1282vring_used_flags_unset_bit(VuVirtq *vq, int mask)
1283{
1284    uint16_t *flags;
1285
1286    flags = (uint16_t *)((char*)vq->vring.used +
1287                         offsetof(struct vring_used, flags));
1288    *flags &= ~mask;
1289}
1290
1291static inline void
1292vring_set_avail_event(VuVirtq *vq, uint16_t val)
1293{
1294    if (!vq->notification) {
1295        return;
1296    }
1297
1298    *((uint16_t *) &vq->vring.used->ring[vq->vring.num]) = val;
1299}
1300
1301void
1302vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
1303{
1304    vq->notification = enable;
1305    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1306        vring_set_avail_event(vq, vring_avail_idx(vq));
1307    } else if (enable) {
1308        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
1309    } else {
1310        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
1311    }
1312    if (enable) {
1313        /* Expose avail event/used flags before caller checks the avail idx. */
1314        smp_mb();
1315    }
1316}
1317
1318static void
1319virtqueue_map_desc(VuDev *dev,
1320                   unsigned int *p_num_sg, struct iovec *iov,
1321                   unsigned int max_num_sg, bool is_write,
1322                   uint64_t pa, size_t sz)
1323{
1324    unsigned num_sg = *p_num_sg;
1325
1326    assert(num_sg <= max_num_sg);
1327
1328    if (!sz) {
1329        vu_panic(dev, "virtio: zero sized buffers are not allowed");
1330        return;
1331    }
1332
1333    iov[num_sg].iov_base = vu_gpa_to_va(dev, pa);
1334    iov[num_sg].iov_len = sz;
1335    num_sg++;
1336
1337    *p_num_sg = num_sg;
1338}
1339
1340/* Round number down to multiple */
1341#define ALIGN_DOWN(n, m) ((n) / (m) * (m))
1342
1343/* Round number up to multiple */
1344#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
1345
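/*
 * Allocate a VuVirtqElement together with its in/out iovec arrays in a
 * single allocation.  sz is the size of the caller's element structure,
 * which must begin with a VuVirtqElement; the iovec arrays are laid out
 * after it, suitably aligned.
 */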
1346static void *
1347virtqueue_alloc_element(size_t sz,
1348                                     unsigned out_num, unsigned in_num)
1349{
1350    VuVirtqElement *elem;
1351    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
1352    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
1353    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
1354
1355    assert(sz >= sizeof(VuVirtqElement));
1356    elem = malloc(out_sg_end);
1357    elem->out_num = out_num;
1358    elem->in_num = in_num;
1359    elem->in_sg = (void *)elem + in_sg_ofs;
1360    elem->out_sg = (void *)elem + out_sg_ofs;
1361    return elem;
1362}
1363
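/*
 * Pop the next available descriptor chain from the queue.  The chain is
 * mapped into the returned element's out_sg (guest-readable buffers) and
 * in_sg (device-writable buffers) arrays.  sz is the size of the
 * caller's element structure (at least sizeof(VuVirtqElement)); the
 * result is malloc()ed and must be freed by the caller.  Returns NULL if
 * the queue is empty or the device is broken.
 */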
1364void *
1365vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
1366{
1367    unsigned int i, head, max;
1368    VuVirtqElement *elem;
1369    unsigned out_num, in_num;
1370    struct iovec iov[VIRTQUEUE_MAX_SIZE];
1371    struct vring_desc *desc;
1372    int rc;
1373
1374    if (unlikely(dev->broken) ||
1375        unlikely(!vq->vring.avail)) {
1376        return NULL;
1377    }
1378
1379    if (vu_queue_empty(dev, vq)) {
1380        return NULL;
1381    }
 1382    /* Needed after vu_queue_empty(), see comment in
1383     * virtqueue_num_heads(). */
1384    smp_rmb();
1385
 1386    /* When we start there are no input or output buffers. */
1387    out_num = in_num = 0;
1388
1389    max = vq->vring.num;
1390    if (vq->inuse >= vq->vring.num) {
1391        vu_panic(dev, "Virtqueue size exceeded");
1392        return NULL;
1393    }
1394
1395    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
1396        return NULL;
1397    }
1398
1399    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
1400        vring_set_avail_event(vq, vq->last_avail_idx);
1401    }
1402
1403    i = head;
1404    desc = vq->vring.desc;
1405    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1406        if (desc[i].len % sizeof(struct vring_desc)) {
1407            vu_panic(dev, "Invalid size for indirect buffer table");
1408        }
1409
1410        /* loop over the indirect descriptor table */
1411        max = desc[i].len / sizeof(struct vring_desc);
1412        desc = vu_gpa_to_va(dev, desc[i].addr);
1413        i = 0;
1414    }
1415
1416    /* Collect all the descriptors */
1417    do {
1418        if (desc[i].flags & VRING_DESC_F_WRITE) {
1419            virtqueue_map_desc(dev, &in_num, iov + out_num,
1420                               VIRTQUEUE_MAX_SIZE - out_num, true,
1421                               desc[i].addr, desc[i].len);
1422        } else {
1423            if (in_num) {
1424                vu_panic(dev, "Incorrect order for descriptors");
1425                return NULL;
1426            }
1427            virtqueue_map_desc(dev, &out_num, iov,
1428                               VIRTQUEUE_MAX_SIZE, false,
1429                               desc[i].addr, desc[i].len);
1430        }
1431
1432        /* If we've got too many, that implies a descriptor loop. */
1433        if ((in_num + out_num) > max) {
1434            vu_panic(dev, "Looped descriptor");
1435        }
1436        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
1437    } while (rc == VIRTQUEUE_READ_DESC_MORE);
1438
1439    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
1440        return NULL;
1441    }
1442
1443    /* Now copy what we have collected and mapped */
1444    elem = virtqueue_alloc_element(sz, out_num, in_num);
1445    elem->index = head;
1446    for (i = 0; i < out_num; i++) {
1447        elem->out_sg[i] = iov[i];
1448    }
1449    for (i = 0; i < in_num; i++) {
1450        elem->in_sg[i] = iov[out_num + i];
1451    }
1452
1453    vq->inuse++;
1454
1455    return elem;
1456}
1457
1458bool
1459vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
1460{
1461    if (num > vq->inuse) {
1462        return false;
1463    }
1464    vq->last_avail_idx -= num;
1465    vq->inuse -= num;
1466    return true;
1467}
1468
1469static inline
1470void vring_used_write(VuDev *dev, VuVirtq *vq,
1471                      struct vring_used_elem *uelem, int i)
1472{
1473    struct vring_used *used = vq->vring.used;
1474
1475    used->ring[i] = *uelem;
1476    vu_log_write(dev, vq->vring.log_guest_addr +
1477                 offsetof(struct vring_used, ring[i]),
1478                 sizeof(used->ring[i]));
1479}
1480
1481
1482static void
1483vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
1484                  const VuVirtqElement *elem,
1485                  unsigned int len)
1486{
1487    struct vring_desc *desc = vq->vring.desc;
1488    unsigned int i, max, min;
1489    unsigned num_bufs = 0;
1490
1491    max = vq->vring.num;
1492    i = elem->index;
1493
1494    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
1495        if (desc[i].len % sizeof(struct vring_desc)) {
1496            vu_panic(dev, "Invalid size for indirect buffer table");
1497        }
1498
1499        /* loop over the indirect descriptor table */
1500        max = desc[i].len / sizeof(struct vring_desc);
1501        desc = vu_gpa_to_va(dev, desc[i].addr);
1502        i = 0;
1503    }
1504
1505    do {
1506        if (++num_bufs > max) {
1507            vu_panic(dev, "Looped descriptor");
1508            return;
1509        }
1510
1511        if (desc[i].flags & VRING_DESC_F_WRITE) {
1512            min = MIN(desc[i].len, len);
1513            vu_log_write(dev, desc[i].addr, min);
1514            len -= min;
1515        }
1516
1517    } while (len > 0 &&
1518             (virtqueue_read_next_desc(dev, desc, i, max, &i)
1519              == VIRTQUEUE_READ_DESC_MORE));
1520}
1521
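/*
 * Place a completed element into the used ring at offset idx from the
 * current used index, logging the write for dirty tracking.  The entry
 * only becomes visible to the guest once vu_queue_flush() advances the
 * used index.
 */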
1522void
1523vu_queue_fill(VuDev *dev, VuVirtq *vq,
1524              const VuVirtqElement *elem,
1525              unsigned int len, unsigned int idx)
1526{
1527    struct vring_used_elem uelem;
1528
1529    if (unlikely(dev->broken) ||
1530        unlikely(!vq->vring.avail)) {
1531        return;
1532    }
1533
1534    vu_log_queue_fill(dev, vq, elem, len);
1535
1536    idx = (idx + vq->used_idx) % vq->vring.num;
1537
1538    uelem.id = elem->index;
1539    uelem.len = len;
1540    vring_used_write(dev, vq, &uelem, idx);
1541}
1542
1543static inline
1544void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
1545{
1546    vq->vring.used->idx = val;
1547    vu_log_write(dev,
1548                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
1549                 sizeof(vq->vring.used->idx));
1550
1551    vq->used_idx = val;
1552}
1553
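/*
 * Publish `count` previously filled used-ring entries: issue a write
 * barrier, then advance (and log) the used index so the guest can see
 * them.
 */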
1554void
1555vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
1556{
1557    uint16_t old, new;
1558
1559    if (unlikely(dev->broken) ||
1560        unlikely(!vq->vring.avail)) {
1561        return;
1562    }
1563
1564    /* Make sure buffer is written before we update index. */
1565    smp_wmb();
1566
1567    old = vq->used_idx;
1568    new = old + count;
1569    vring_used_idx_set(dev, vq, new);
1570    vq->inuse -= count;
1571    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
1572        vq->signalled_used_valid = false;
1573    }
1574}
1575
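/*
 * Convenience wrapper: fill one element and flush it immediately.  A
 * minimal queue handler might look like the sketch below (illustrative
 * only; the processing step and the in_len byte count are
 * device-specific):
 *
 *     static void my_queue_handler(VuDev *dev, int qidx)
 *     {
 *         VuVirtq *vq = vu_get_queue(dev, qidx);
 *         VuVirtqElement *elem;
 *
 *         while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
 *             // read elem->out_sg, write the response into elem->in_sg
 *             vu_queue_push(dev, vq, elem, in_len);
 *             free(elem);
 *         }
 *         vu_queue_notify(dev, vq);
 *     }
 */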
1576void
1577vu_queue_push(VuDev *dev, VuVirtq *vq,
1578              const VuVirtqElement *elem, unsigned int len)
1579{
1580    vu_queue_fill(dev, vq, elem, len, 0);
1581    vu_queue_flush(dev, vq, 1);
1582}
1583