qemu/nbd.c
<<
>>
Prefs
   1/*
   2 *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
   3 *
   4 *  Network Block Device
   5 *
   6 *  This program is free software; you can redistribute it and/or modify
   7 *  it under the terms of the GNU General Public License as published by
   8 *  the Free Software Foundation; under version 2 of the License.
   9 *
  10 *  This program is distributed in the hope that it will be useful,
  11 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
  12 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  13 *  GNU General Public License for more details.
  14 *
  15 *  You should have received a copy of the GNU General Public License
  16 *  along with this program; if not, see <http://www.gnu.org/licenses/>.
  17 */
  18
  19#include "block/nbd.h"
  20#include "block/block.h"
  21
  22#include "block/coroutine.h"
  23
  24#include <errno.h>
  25#include <string.h>
  26#ifndef _WIN32
  27#include <sys/ioctl.h>
  28#endif
  29#if defined(__sun__) || defined(__HAIKU__)
  30#include <sys/ioccom.h>
  31#endif
  32#include <ctype.h>
  33#include <inttypes.h>
  34
  35#ifdef __linux__
  36#include <linux/fs.h>
  37#endif
  38
  39#include "qemu/sockets.h"
  40#include "qemu/queue.h"
  41
  42//#define DEBUG_NBD
  43
  44#ifdef DEBUG_NBD
  45#define TRACE(msg, ...) do { \
  46    LOG(msg, ## __VA_ARGS__); \
  47} while(0)
  48#else
  49#define TRACE(msg, ...) \
  50    do { } while (0)
  51#endif
  52
  53#define LOG(msg, ...) do { \
  54    fprintf(stderr, "%s:%s():L%d: " msg "\n", \
  55            __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__); \
  56} while(0)
  57
  58/* This is all part of the "official" NBD API */
  59
  60#define NBD_REQUEST_SIZE        (4 + 4 + 8 + 8 + 4)
  61#define NBD_REPLY_SIZE          (4 + 4 + 8)
  62#define NBD_REQUEST_MAGIC       0x25609513
  63#define NBD_REPLY_MAGIC         0x67446698
  64#define NBD_OPTS_MAGIC          0x49484156454F5054LL
  65#define NBD_CLIENT_MAGIC        0x0000420281861253LL
  66
  67#define NBD_SET_SOCK            _IO(0xab, 0)
  68#define NBD_SET_BLKSIZE         _IO(0xab, 1)
  69#define NBD_SET_SIZE            _IO(0xab, 2)
  70#define NBD_DO_IT               _IO(0xab, 3)
  71#define NBD_CLEAR_SOCK          _IO(0xab, 4)
  72#define NBD_CLEAR_QUE           _IO(0xab, 5)
  73#define NBD_PRINT_DEBUG         _IO(0xab, 6)
  74#define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
  75#define NBD_DISCONNECT          _IO(0xab, 8)
  76#define NBD_SET_TIMEOUT         _IO(0xab, 9)
  77#define NBD_SET_FLAGS           _IO(0xab, 10)
  78
  79#define NBD_OPT_EXPORT_NAME     (1 << 0)
  80
  81/* Definitions for opaque data types */
  82
  83typedef struct NBDRequest NBDRequest;
  84
  85struct NBDRequest {
  86    QSIMPLEQ_ENTRY(NBDRequest) entry;
  87    NBDClient *client;
  88    uint8_t *data;
  89};
  90
  91struct NBDExport {
  92    int refcount;
  93    void (*close)(NBDExport *exp);
  94
  95    BlockDriverState *bs;
  96    char *name;
  97    off_t dev_offset;
  98    off_t size;
  99    uint32_t nbdflags;
 100    QTAILQ_HEAD(, NBDClient) clients;
 101    QSIMPLEQ_HEAD(, NBDRequest) requests;
 102    QTAILQ_ENTRY(NBDExport) next;
 103};
 104
 105static QTAILQ_HEAD(, NBDExport) exports = QTAILQ_HEAD_INITIALIZER(exports);
 106
 107struct NBDClient {
 108    int refcount;
 109    void (*close)(NBDClient *client);
 110
 111    NBDExport *exp;
 112    int sock;
 113
 114    Coroutine *recv_coroutine;
 115
 116    CoMutex send_lock;
 117    Coroutine *send_coroutine;
 118
 119    QTAILQ_ENTRY(NBDClient) next;
 120    int nb_requests;
 121    bool closing;
 122};
 123
 124/* That's all folks */
 125
 126ssize_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
 127{
 128    size_t offset = 0;
 129    int err;
 130
 131    if (qemu_in_coroutine()) {
 132        if (do_read) {
 133            return qemu_co_recv(fd, buffer, size);
 134        } else {
 135            return qemu_co_send(fd, buffer, size);
 136        }
 137    }
 138
 139    while (offset < size) {
 140        ssize_t len;
 141
 142        if (do_read) {
 143            len = qemu_recv(fd, buffer + offset, size - offset, 0);
 144        } else {
 145            len = send(fd, buffer + offset, size - offset, 0);
 146        }
 147
 148        if (len < 0) {
 149            err = socket_error();
 150
 151            /* recoverable error */
 152            if (err == EINTR || (offset > 0 && err == EAGAIN)) {
 153                continue;
 154            }
 155
 156            /* unrecoverable error */
 157            return -err;
 158        }
 159
 160        /* eof */
 161        if (len == 0) {
 162            break;
 163        }
 164
 165        offset += len;
 166    }
 167
 168    return offset;
 169}
 170
 171static ssize_t read_sync(int fd, void *buffer, size_t size)
 172{
 173    /* Sockets are kept in blocking mode in the negotiation phase.  After
 174     * that, a non-readable socket simply means that another thread stole
 175     * our request/reply.  Synchronization is done with recv_coroutine, so
 176     * that this is coroutine-safe.
 177     */
 178    return nbd_wr_sync(fd, buffer, size, true);
 179}
 180
 181static ssize_t write_sync(int fd, void *buffer, size_t size)
 182{
 183    int ret;
 184    do {
 185        /* For writes, we do expect the socket to be writable.  */
 186        ret = nbd_wr_sync(fd, buffer, size, false);
 187    } while (ret == -EAGAIN);
 188    return ret;
 189}
 190
 191static void combine_addr(char *buf, size_t len, const char* address,
 192                         uint16_t port)
 193{
 194    /* If the address-part contains a colon, it's an IPv6 IP so needs [] */
 195    if (strstr(address, ":")) {
 196        snprintf(buf, len, "[%s]:%u", address, port);
 197    } else {
 198        snprintf(buf, len, "%s:%u", address, port);
 199    }
 200}
 201
 202int tcp_socket_outgoing(const char *address, uint16_t port)
 203{
 204    char address_and_port[128];
 205    combine_addr(address_and_port, 128, address, port);
 206    return tcp_socket_outgoing_spec(address_and_port);
 207}
 208
 209int tcp_socket_outgoing_spec(const char *address_and_port)
 210{
 211    Error *local_err = NULL;
 212    int fd = inet_connect(address_and_port, &local_err);
 213
 214    if (local_err != NULL) {
 215        qerror_report_err(local_err);
 216        error_free(local_err);
 217    }
 218    return fd;
 219}
 220
 221int tcp_socket_incoming(const char *address, uint16_t port)
 222{
 223    char address_and_port[128];
 224    combine_addr(address_and_port, 128, address, port);
 225    return tcp_socket_incoming_spec(address_and_port);
 226}
 227
 228int tcp_socket_incoming_spec(const char *address_and_port)
 229{
 230    Error *local_err = NULL;
 231    int fd = inet_listen(address_and_port, NULL, 0, SOCK_STREAM, 0, &local_err);
 232
 233    if (local_err != NULL) {
 234        qerror_report_err(local_err);
 235        error_free(local_err);
 236    }
 237    return fd;
 238}
 239
 240int unix_socket_incoming(const char *path)
 241{
 242    Error *local_err = NULL;
 243    int fd = unix_listen(path, NULL, 0, &local_err);
 244
 245    if (local_err != NULL) {
 246        qerror_report_err(local_err);
 247        error_free(local_err);
 248    }
 249    return fd;
 250}
 251
 252int unix_socket_outgoing(const char *path)
 253{
 254    Error *local_err = NULL;
 255    int fd = unix_connect(path, &local_err);
 256
 257    if (local_err != NULL) {
 258        qerror_report_err(local_err);
 259        error_free(local_err);
 260    }
 261    return fd;
 262}
 263
 264/* Basic flow for negotiation
 265
 266   Server         Client
 267   Negotiate
 268
 269   or
 270
 271   Server         Client
 272   Negotiate #1
 273                  Option
 274   Negotiate #2
 275
 276   ----
 277
 278   followed by
 279
 280   Server         Client
 281                  Request
 282   Response
 283                  Request
 284   Response
 285                  ...
 286   ...
 287                  Request (type == 2)
 288
 289*/
 290
 291static int nbd_receive_options(NBDClient *client)
 292{
 293    int csock = client->sock;
 294    char name[256];
 295    uint32_t tmp, length;
 296    uint64_t magic;
 297    int rc;
 298
 299    /* Client sends:
 300        [ 0 ..   3]   reserved (0)
 301        [ 4 ..  11]   NBD_OPTS_MAGIC
 302        [12 ..  15]   NBD_OPT_EXPORT_NAME
 303        [16 ..  19]   length
 304        [20 ..  xx]   export name (length bytes)
 305     */
 306
 307    rc = -EINVAL;
 308    if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 309        LOG("read failed");
 310        goto fail;
 311    }
 312    TRACE("Checking reserved");
 313    if (tmp != 0) {
 314        LOG("Bad reserved received");
 315        goto fail;
 316    }
 317
 318    if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
 319        LOG("read failed");
 320        goto fail;
 321    }
 322    TRACE("Checking reserved");
 323    if (magic != be64_to_cpu(NBD_OPTS_MAGIC)) {
 324        LOG("Bad magic received");
 325        goto fail;
 326    }
 327
 328    if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 329        LOG("read failed");
 330        goto fail;
 331    }
 332    TRACE("Checking option");
 333    if (tmp != be32_to_cpu(NBD_OPT_EXPORT_NAME)) {
 334        LOG("Bad option received");
 335        goto fail;
 336    }
 337
 338    if (read_sync(csock, &length, sizeof(length)) != sizeof(length)) {
 339        LOG("read failed");
 340        goto fail;
 341    }
 342    TRACE("Checking length");
 343    length = be32_to_cpu(length);
 344    if (length > 255) {
 345        LOG("Bad length received");
 346        goto fail;
 347    }
 348    if (read_sync(csock, name, length) != length) {
 349        LOG("read failed");
 350        goto fail;
 351    }
 352    name[length] = '\0';
 353
 354    client->exp = nbd_export_find(name);
 355    if (!client->exp) {
 356        LOG("export not found");
 357        goto fail;
 358    }
 359
 360    QTAILQ_INSERT_TAIL(&client->exp->clients, client, next);
 361    nbd_export_get(client->exp);
 362
 363    TRACE("Option negotiation succeeded.");
 364    rc = 0;
 365fail:
 366    return rc;
 367}
 368
 369static int nbd_send_negotiate(NBDClient *client)
 370{
 371    int csock = client->sock;
 372    char buf[8 + 8 + 8 + 128];
 373    int rc;
 374    const int myflags = (NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
 375                         NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
 376
 377    /* Negotiation header without options:
 378        [ 0 ..   7]   passwd       ("NBDMAGIC")
 379        [ 8 ..  15]   magic        (NBD_CLIENT_MAGIC)
 380        [16 ..  23]   size
 381        [24 ..  25]   server flags (0)
 382        [24 ..  27]   export flags
 383        [28 .. 151]   reserved     (0)
 384
 385       Negotiation header with options, part 1:
 386        [ 0 ..   7]   passwd       ("NBDMAGIC")
 387        [ 8 ..  15]   magic        (NBD_OPTS_MAGIC)
 388        [16 ..  17]   server flags (0)
 389
 390       part 2 (after options are sent):
 391        [18 ..  25]   size
 392        [26 ..  27]   export flags
 393        [28 .. 151]   reserved     (0)
 394     */
 395
 396    socket_set_block(csock);
 397    rc = -EINVAL;
 398
 399    TRACE("Beginning negotiation.");
 400    memset(buf, 0, sizeof(buf));
 401    memcpy(buf, "NBDMAGIC", 8);
 402    if (client->exp) {
 403        assert ((client->exp->nbdflags & ~65535) == 0);
 404        cpu_to_be64w((uint64_t*)(buf + 8), NBD_CLIENT_MAGIC);
 405        cpu_to_be64w((uint64_t*)(buf + 16), client->exp->size);
 406        cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
 407    } else {
 408        cpu_to_be64w((uint64_t*)(buf + 8), NBD_OPTS_MAGIC);
 409    }
 410
 411    if (client->exp) {
 412        if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
 413            LOG("write failed");
 414            goto fail;
 415        }
 416    } else {
 417        if (write_sync(csock, buf, 18) != 18) {
 418            LOG("write failed");
 419            goto fail;
 420        }
 421        rc = nbd_receive_options(client);
 422        if (rc < 0) {
 423            LOG("option negotiation failed");
 424            goto fail;
 425        }
 426
 427        assert ((client->exp->nbdflags & ~65535) == 0);
 428        cpu_to_be64w((uint64_t*)(buf + 18), client->exp->size);
 429        cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
 430        if (write_sync(csock, buf + 18, sizeof(buf) - 18) != sizeof(buf) - 18) {
 431            LOG("write failed");
 432            goto fail;
 433        }
 434    }
 435
 436    TRACE("Negotiation succeeded.");
 437    rc = 0;
 438fail:
 439    socket_set_nonblock(csock);
 440    return rc;
 441}
 442
 443int nbd_receive_negotiate(int csock, const char *name, uint32_t *flags,
 444                          off_t *size, size_t *blocksize)
 445{
 446    char buf[256];
 447    uint64_t magic, s;
 448    uint16_t tmp;
 449    int rc;
 450
 451    TRACE("Receiving negotiation.");
 452
 453    socket_set_block(csock);
 454    rc = -EINVAL;
 455
 456    if (read_sync(csock, buf, 8) != 8) {
 457        LOG("read failed");
 458        goto fail;
 459    }
 460
 461    buf[8] = '\0';
 462    if (strlen(buf) == 0) {
 463        LOG("server connection closed");
 464        goto fail;
 465    }
 466
 467    TRACE("Magic is %c%c%c%c%c%c%c%c",
 468          qemu_isprint(buf[0]) ? buf[0] : '.',
 469          qemu_isprint(buf[1]) ? buf[1] : '.',
 470          qemu_isprint(buf[2]) ? buf[2] : '.',
 471          qemu_isprint(buf[3]) ? buf[3] : '.',
 472          qemu_isprint(buf[4]) ? buf[4] : '.',
 473          qemu_isprint(buf[5]) ? buf[5] : '.',
 474          qemu_isprint(buf[6]) ? buf[6] : '.',
 475          qemu_isprint(buf[7]) ? buf[7] : '.');
 476
 477    if (memcmp(buf, "NBDMAGIC", 8) != 0) {
 478        LOG("Invalid magic received");
 479        goto fail;
 480    }
 481
 482    if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
 483        LOG("read failed");
 484        goto fail;
 485    }
 486    magic = be64_to_cpu(magic);
 487    TRACE("Magic is 0x%" PRIx64, magic);
 488
 489    if (name) {
 490        uint32_t reserved = 0;
 491        uint32_t opt;
 492        uint32_t namesize;
 493
 494        TRACE("Checking magic (opts_magic)");
 495        if (magic != NBD_OPTS_MAGIC) {
 496            LOG("Bad magic received");
 497            goto fail;
 498        }
 499        if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 500            LOG("flags read failed");
 501            goto fail;
 502        }
 503        *flags = be16_to_cpu(tmp) << 16;
 504        /* reserved for future use */
 505        if (write_sync(csock, &reserved, sizeof(reserved)) !=
 506            sizeof(reserved)) {
 507            LOG("write failed (reserved)");
 508            goto fail;
 509        }
 510        /* write the export name */
 511        magic = cpu_to_be64(magic);
 512        if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
 513            LOG("write failed (magic)");
 514            goto fail;
 515        }
 516        opt = cpu_to_be32(NBD_OPT_EXPORT_NAME);
 517        if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
 518            LOG("write failed (opt)");
 519            goto fail;
 520        }
 521        namesize = cpu_to_be32(strlen(name));
 522        if (write_sync(csock, &namesize, sizeof(namesize)) !=
 523            sizeof(namesize)) {
 524            LOG("write failed (namesize)");
 525            goto fail;
 526        }
 527        if (write_sync(csock, (char*)name, strlen(name)) != strlen(name)) {
 528            LOG("write failed (name)");
 529            goto fail;
 530        }
 531    } else {
 532        TRACE("Checking magic (cli_magic)");
 533
 534        if (magic != NBD_CLIENT_MAGIC) {
 535            LOG("Bad magic received");
 536            goto fail;
 537        }
 538    }
 539
 540    if (read_sync(csock, &s, sizeof(s)) != sizeof(s)) {
 541        LOG("read failed");
 542        goto fail;
 543    }
 544    *size = be64_to_cpu(s);
 545    *blocksize = 1024;
 546    TRACE("Size is %" PRIu64, *size);
 547
 548    if (!name) {
 549        if (read_sync(csock, flags, sizeof(*flags)) != sizeof(*flags)) {
 550            LOG("read failed (flags)");
 551            goto fail;
 552        }
 553        *flags = be32_to_cpup(flags);
 554    } else {
 555        if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 556            LOG("read failed (tmp)");
 557            goto fail;
 558        }
 559        *flags |= be32_to_cpu(tmp);
 560    }
 561    if (read_sync(csock, &buf, 124) != 124) {
 562        LOG("read failed (buf)");
 563        goto fail;
 564    }
 565    rc = 0;
 566
 567fail:
 568    socket_set_nonblock(csock);
 569    return rc;
 570}
 571
 572#ifdef __linux__
 573int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
 574{
 575    TRACE("Setting NBD socket");
 576
 577    if (ioctl(fd, NBD_SET_SOCK, csock) < 0) {
 578        int serrno = errno;
 579        LOG("Failed to set NBD socket");
 580        return -serrno;
 581    }
 582
 583    TRACE("Setting block size to %lu", (unsigned long)blocksize);
 584
 585    if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) < 0) {
 586        int serrno = errno;
 587        LOG("Failed setting NBD block size");
 588        return -serrno;
 589    }
 590
 591        TRACE("Setting size to %zd block(s)", (size_t)(size / blocksize));
 592
 593    if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) < 0) {
 594        int serrno = errno;
 595        LOG("Failed setting size (in blocks)");
 596        return -serrno;
 597    }
 598
 599    if (ioctl(fd, NBD_SET_FLAGS, flags) < 0) {
 600        if (errno == ENOTTY) {
 601            int read_only = (flags & NBD_FLAG_READ_ONLY) != 0;
 602            TRACE("Setting readonly attribute");
 603
 604            if (ioctl(fd, BLKROSET, (unsigned long) &read_only) < 0) {
 605                int serrno = errno;
 606                LOG("Failed setting read-only attribute");
 607                return -serrno;
 608            }
 609        } else {
 610            int serrno = errno;
 611            LOG("Failed setting flags");
 612            return -serrno;
 613        }
 614    }
 615
 616    TRACE("Negotiation ended");
 617
 618    return 0;
 619}
 620
 621int nbd_disconnect(int fd)
 622{
 623    ioctl(fd, NBD_CLEAR_QUE);
 624    ioctl(fd, NBD_DISCONNECT);
 625    ioctl(fd, NBD_CLEAR_SOCK);
 626    return 0;
 627}
 628
 629int nbd_client(int fd)
 630{
 631    int ret;
 632    int serrno;
 633
 634    TRACE("Doing NBD loop");
 635
 636    ret = ioctl(fd, NBD_DO_IT);
 637    if (ret < 0 && errno == EPIPE) {
 638        /* NBD_DO_IT normally returns EPIPE when someone has disconnected
 639         * the socket via NBD_DISCONNECT.  We do not want to return 1 in
 640         * that case.
 641         */
 642        ret = 0;
 643    }
 644    serrno = errno;
 645
 646    TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
 647
 648    TRACE("Clearing NBD queue");
 649    ioctl(fd, NBD_CLEAR_QUE);
 650
 651    TRACE("Clearing NBD socket");
 652    ioctl(fd, NBD_CLEAR_SOCK);
 653
 654    errno = serrno;
 655    return ret;
 656}
 657#else
 658int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
 659{
 660    return -ENOTSUP;
 661}
 662
 663int nbd_disconnect(int fd)
 664{
 665    return -ENOTSUP;
 666}
 667
 668int nbd_client(int fd)
 669{
 670    return -ENOTSUP;
 671}
 672#endif
 673
 674ssize_t nbd_send_request(int csock, struct nbd_request *request)
 675{
 676    uint8_t buf[NBD_REQUEST_SIZE];
 677    ssize_t ret;
 678
 679    cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
 680    cpu_to_be32w((uint32_t*)(buf + 4), request->type);
 681    cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
 682    cpu_to_be64w((uint64_t*)(buf + 16), request->from);
 683    cpu_to_be32w((uint32_t*)(buf + 24), request->len);
 684
 685    TRACE("Sending request to client: "
 686          "{ .from = %" PRIu64", .len = %u, .handle = %" PRIu64", .type=%i}",
 687          request->from, request->len, request->handle, request->type);
 688
 689    ret = write_sync(csock, buf, sizeof(buf));
 690    if (ret < 0) {
 691        return ret;
 692    }
 693
 694    if (ret != sizeof(buf)) {
 695        LOG("writing to socket failed");
 696        return -EINVAL;
 697    }
 698    return 0;
 699}
 700
 701static ssize_t nbd_receive_request(int csock, struct nbd_request *request)
 702{
 703    uint8_t buf[NBD_REQUEST_SIZE];
 704    uint32_t magic;
 705    ssize_t ret;
 706
 707    ret = read_sync(csock, buf, sizeof(buf));
 708    if (ret < 0) {
 709        return ret;
 710    }
 711
 712    if (ret != sizeof(buf)) {
 713        LOG("read failed");
 714        return -EINVAL;
 715    }
 716
 717    /* Request
 718       [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
 719       [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
 720       [ 8 .. 15]   handle
 721       [16 .. 23]   from
 722       [24 .. 27]   len
 723     */
 724
 725    magic = be32_to_cpup((uint32_t*)buf);
 726    request->type  = be32_to_cpup((uint32_t*)(buf + 4));
 727    request->handle = be64_to_cpup((uint64_t*)(buf + 8));
 728    request->from  = be64_to_cpup((uint64_t*)(buf + 16));
 729    request->len   = be32_to_cpup((uint32_t*)(buf + 24));
 730
 731    TRACE("Got request: "
 732          "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
 733          magic, request->type, request->from, request->len);
 734
 735    if (magic != NBD_REQUEST_MAGIC) {
 736        LOG("invalid magic (got 0x%x)", magic);
 737        return -EINVAL;
 738    }
 739    return 0;
 740}
 741
 742ssize_t nbd_receive_reply(int csock, struct nbd_reply *reply)
 743{
 744    uint8_t buf[NBD_REPLY_SIZE];
 745    uint32_t magic;
 746    ssize_t ret;
 747
 748    ret = read_sync(csock, buf, sizeof(buf));
 749    if (ret < 0) {
 750        return ret;
 751    }
 752
 753    if (ret != sizeof(buf)) {
 754        LOG("read failed");
 755        return -EINVAL;
 756    }
 757
 758    /* Reply
 759       [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
 760       [ 4 ..  7]    error   (0 == no error)
 761       [ 7 .. 15]    handle
 762     */
 763
 764    magic = be32_to_cpup((uint32_t*)buf);
 765    reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
 766    reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
 767
 768    TRACE("Got reply: "
 769          "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
 770          magic, reply->error, reply->handle);
 771
 772    if (magic != NBD_REPLY_MAGIC) {
 773        LOG("invalid magic (got 0x%x)", magic);
 774        return -EINVAL;
 775    }
 776    return 0;
 777}
 778
 779static ssize_t nbd_send_reply(int csock, struct nbd_reply *reply)
 780{
 781    uint8_t buf[NBD_REPLY_SIZE];
 782    ssize_t ret;
 783
 784    /* Reply
 785       [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
 786       [ 4 ..  7]    error   (0 == no error)
 787       [ 7 .. 15]    handle
 788     */
 789    cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
 790    cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
 791    cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
 792
 793    TRACE("Sending response to client");
 794
 795    ret = write_sync(csock, buf, sizeof(buf));
 796    if (ret < 0) {
 797        return ret;
 798    }
 799
 800    if (ret != sizeof(buf)) {
 801        LOG("writing to socket failed");
 802        return -EINVAL;
 803    }
 804    return 0;
 805}
 806
 807#define MAX_NBD_REQUESTS 16
 808
 809void nbd_client_get(NBDClient *client)
 810{
 811    client->refcount++;
 812}
 813
 814void nbd_client_put(NBDClient *client)
 815{
 816    if (--client->refcount == 0) {
 817        /* The last reference should be dropped by client->close,
 818         * which is called by nbd_client_close.
 819         */
 820        assert(client->closing);
 821
 822        qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL);
 823        close(client->sock);
 824        client->sock = -1;
 825        if (client->exp) {
 826            QTAILQ_REMOVE(&client->exp->clients, client, next);
 827            nbd_export_put(client->exp);
 828        }
 829        g_free(client);
 830    }
 831}
 832
 833void nbd_client_close(NBDClient *client)
 834{
 835    if (client->closing) {
 836        return;
 837    }
 838
 839    client->closing = true;
 840
 841    /* Force requests to finish.  They will drop their own references,
 842     * then we'll close the socket and free the NBDClient.
 843     */
 844    shutdown(client->sock, 2);
 845
 846    /* Also tell the client, so that they release their reference.  */
 847    if (client->close) {
 848        client->close(client);
 849    }
 850}
 851
 852static NBDRequest *nbd_request_get(NBDClient *client)
 853{
 854    NBDRequest *req;
 855    NBDExport *exp = client->exp;
 856
 857    assert(client->nb_requests <= MAX_NBD_REQUESTS - 1);
 858    client->nb_requests++;
 859
 860    if (QSIMPLEQ_EMPTY(&exp->requests)) {
 861        req = g_malloc0(sizeof(NBDRequest));
 862        req->data = qemu_blockalign(exp->bs, NBD_BUFFER_SIZE);
 863    } else {
 864        req = QSIMPLEQ_FIRST(&exp->requests);
 865        QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
 866    }
 867    nbd_client_get(client);
 868    req->client = client;
 869    return req;
 870}
 871
 872static void nbd_request_put(NBDRequest *req)
 873{
 874    NBDClient *client = req->client;
 875    QSIMPLEQ_INSERT_HEAD(&client->exp->requests, req, entry);
 876    if (client->nb_requests-- == MAX_NBD_REQUESTS) {
 877        qemu_notify_event();
 878    }
 879    nbd_client_put(client);
 880}
 881
 882NBDExport *nbd_export_new(BlockDriverState *bs, off_t dev_offset,
 883                          off_t size, uint32_t nbdflags,
 884                          void (*close)(NBDExport *))
 885{
 886    NBDExport *exp = g_malloc0(sizeof(NBDExport));
 887    QSIMPLEQ_INIT(&exp->requests);
 888    exp->refcount = 1;
 889    QTAILQ_INIT(&exp->clients);
 890    exp->bs = bs;
 891    exp->dev_offset = dev_offset;
 892    exp->nbdflags = nbdflags;
 893    exp->size = size == -1 ? bdrv_getlength(bs) : size;
 894    exp->close = close;
 895    return exp;
 896}
 897
 898NBDExport *nbd_export_find(const char *name)
 899{
 900    NBDExport *exp;
 901    QTAILQ_FOREACH(exp, &exports, next) {
 902        if (strcmp(name, exp->name) == 0) {
 903            return exp;
 904        }
 905    }
 906
 907    return NULL;
 908}
 909
 910void nbd_export_set_name(NBDExport *exp, const char *name)
 911{
 912    if (exp->name == name) {
 913        return;
 914    }
 915
 916    nbd_export_get(exp);
 917    if (exp->name != NULL) {
 918        g_free(exp->name);
 919        exp->name = NULL;
 920        QTAILQ_REMOVE(&exports, exp, next);
 921        nbd_export_put(exp);
 922    }
 923    if (name != NULL) {
 924        nbd_export_get(exp);
 925        exp->name = g_strdup(name);
 926        QTAILQ_INSERT_TAIL(&exports, exp, next);
 927    }
 928    nbd_export_put(exp);
 929}
 930
 931void nbd_export_close(NBDExport *exp)
 932{
 933    NBDClient *client, *next;
 934
 935    nbd_export_get(exp);
 936    QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
 937        nbd_client_close(client);
 938    }
 939    nbd_export_set_name(exp, NULL);
 940    nbd_export_put(exp);
 941}
 942
 943void nbd_export_get(NBDExport *exp)
 944{
 945    assert(exp->refcount > 0);
 946    exp->refcount++;
 947}
 948
 949void nbd_export_put(NBDExport *exp)
 950{
 951    assert(exp->refcount > 0);
 952    if (exp->refcount == 1) {
 953        nbd_export_close(exp);
 954    }
 955
 956    if (--exp->refcount == 0) {
 957        assert(exp->name == NULL);
 958
 959        if (exp->close) {
 960            exp->close(exp);
 961        }
 962
 963        while (!QSIMPLEQ_EMPTY(&exp->requests)) {
 964            NBDRequest *first = QSIMPLEQ_FIRST(&exp->requests);
 965            QSIMPLEQ_REMOVE_HEAD(&exp->requests, entry);
 966            qemu_vfree(first->data);
 967            g_free(first);
 968        }
 969
 970        g_free(exp);
 971    }
 972}
 973
 974BlockDriverState *nbd_export_get_blockdev(NBDExport *exp)
 975{
 976    return exp->bs;
 977}
 978
 979void nbd_export_close_all(void)
 980{
 981    NBDExport *exp, *next;
 982
 983    QTAILQ_FOREACH_SAFE(exp, &exports, next, next) {
 984        nbd_export_close(exp);
 985    }
 986}
 987
 988static int nbd_can_read(void *opaque);
 989static void nbd_read(void *opaque);
 990static void nbd_restart_write(void *opaque);
 991
 992static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
 993                                 int len)
 994{
 995    NBDClient *client = req->client;
 996    int csock = client->sock;
 997    ssize_t rc, ret;
 998
 999    qemu_co_mutex_lock(&client->send_lock);
1000    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read,
1001                         nbd_restart_write, client);
1002    client->send_coroutine = qemu_coroutine_self();
1003
1004    if (!len) {
1005        rc = nbd_send_reply(csock, reply);
1006    } else {
1007        socket_set_cork(csock, 1);
1008        rc = nbd_send_reply(csock, reply);
1009        if (rc >= 0) {
1010            ret = qemu_co_send(csock, req->data, len);
1011            if (ret != len) {
1012                rc = -EIO;
1013            }
1014        }
1015        socket_set_cork(csock, 0);
1016    }
1017
1018    client->send_coroutine = NULL;
1019    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
1020    qemu_co_mutex_unlock(&client->send_lock);
1021    return rc;
1022}
1023
1024static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
1025{
1026    NBDClient *client = req->client;
1027    int csock = client->sock;
1028    ssize_t rc;
1029
1030    client->recv_coroutine = qemu_coroutine_self();
1031    rc = nbd_receive_request(csock, request);
1032    if (rc < 0) {
1033        if (rc != -EAGAIN) {
1034            rc = -EIO;
1035        }
1036        goto out;
1037    }
1038
1039    if (request->len > NBD_BUFFER_SIZE) {
1040        LOG("len (%u) is larger than max len (%u)",
1041            request->len, NBD_BUFFER_SIZE);
1042        rc = -EINVAL;
1043        goto out;
1044    }
1045
1046    if ((request->from + request->len) < request->from) {
1047        LOG("integer overflow detected! "
1048            "you're probably being attacked");
1049        rc = -EINVAL;
1050        goto out;
1051    }
1052
1053    TRACE("Decoding type");
1054
1055    if ((request->type & NBD_CMD_MASK_COMMAND) == NBD_CMD_WRITE) {
1056        TRACE("Reading %u byte(s)", request->len);
1057
1058        if (qemu_co_recv(csock, req->data, request->len) != request->len) {
1059            LOG("reading from socket failed");
1060            rc = -EIO;
1061            goto out;
1062        }
1063    }
1064    rc = 0;
1065
1066out:
1067    client->recv_coroutine = NULL;
1068    return rc;
1069}
1070
1071static void nbd_trip(void *opaque)
1072{
1073    NBDClient *client = opaque;
1074    NBDExport *exp = client->exp;
1075    NBDRequest *req;
1076    struct nbd_request request;
1077    struct nbd_reply reply;
1078    ssize_t ret;
1079
1080    TRACE("Reading request.");
1081    if (client->closing) {
1082        return;
1083    }
1084
1085    req = nbd_request_get(client);
1086    ret = nbd_co_receive_request(req, &request);
1087    if (ret == -EAGAIN) {
1088        goto done;
1089    }
1090    if (ret == -EIO) {
1091        goto out;
1092    }
1093
1094    reply.handle = request.handle;
1095    reply.error = 0;
1096
1097    if (ret < 0) {
1098        reply.error = -ret;
1099        goto error_reply;
1100    }
1101
1102    if ((request.from + request.len) > exp->size) {
1103            LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
1104            ", Offset: %" PRIu64 "\n",
1105                    request.from, request.len,
1106                    (uint64_t)exp->size, (uint64_t)exp->dev_offset);
1107        LOG("requested operation past EOF--bad client?");
1108        goto invalid_request;
1109    }
1110
1111    switch (request.type & NBD_CMD_MASK_COMMAND) {
1112    case NBD_CMD_READ:
1113        TRACE("Request type is READ");
1114
1115        if (request.type & NBD_CMD_FLAG_FUA) {
1116            ret = bdrv_co_flush(exp->bs);
1117            if (ret < 0) {
1118                LOG("flush failed");
1119                reply.error = -ret;
1120                goto error_reply;
1121            }
1122        }
1123
1124        ret = bdrv_read(exp->bs, (request.from + exp->dev_offset) / 512,
1125                        req->data, request.len / 512);
1126        if (ret < 0) {
1127            LOG("reading from file failed");
1128            reply.error = -ret;
1129            goto error_reply;
1130        }
1131
1132        TRACE("Read %u byte(s)", request.len);
1133        if (nbd_co_send_reply(req, &reply, request.len) < 0)
1134            goto out;
1135        break;
1136    case NBD_CMD_WRITE:
1137        TRACE("Request type is WRITE");
1138
1139        if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
1140            TRACE("Server is read-only, return error");
1141            reply.error = EROFS;
1142            goto error_reply;
1143        }
1144
1145        TRACE("Writing to device");
1146
1147        ret = bdrv_write(exp->bs, (request.from + exp->dev_offset) / 512,
1148                         req->data, request.len / 512);
1149        if (ret < 0) {
1150            LOG("writing to file failed");
1151            reply.error = -ret;
1152            goto error_reply;
1153        }
1154
1155        if (request.type & NBD_CMD_FLAG_FUA) {
1156            ret = bdrv_co_flush(exp->bs);
1157            if (ret < 0) {
1158                LOG("flush failed");
1159                reply.error = -ret;
1160                goto error_reply;
1161            }
1162        }
1163
1164        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1165            goto out;
1166        }
1167        break;
1168    case NBD_CMD_DISC:
1169        TRACE("Request type is DISCONNECT");
1170        errno = 0;
1171        goto out;
1172    case NBD_CMD_FLUSH:
1173        TRACE("Request type is FLUSH");
1174
1175        ret = bdrv_co_flush(exp->bs);
1176        if (ret < 0) {
1177            LOG("flush failed");
1178            reply.error = -ret;
1179        }
1180        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1181            goto out;
1182        }
1183        break;
1184    case NBD_CMD_TRIM:
1185        TRACE("Request type is TRIM");
1186        ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
1187                              request.len / 512);
1188        if (ret < 0) {
1189            LOG("discard failed");
1190            reply.error = -ret;
1191        }
1192        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1193            goto out;
1194        }
1195        break;
1196    default:
1197        LOG("invalid request type (%u) received", request.type);
1198    invalid_request:
1199        reply.error = -EINVAL;
1200    error_reply:
1201        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1202            goto out;
1203        }
1204        break;
1205    }
1206
1207    TRACE("Request/Reply complete");
1208
1209done:
1210    nbd_request_put(req);
1211    return;
1212
1213out:
1214    nbd_request_put(req);
1215    nbd_client_close(client);
1216}
1217
1218static int nbd_can_read(void *opaque)
1219{
1220    NBDClient *client = opaque;
1221
1222    return client->recv_coroutine || client->nb_requests < MAX_NBD_REQUESTS;
1223}
1224
1225static void nbd_read(void *opaque)
1226{
1227    NBDClient *client = opaque;
1228
1229    if (client->recv_coroutine) {
1230        qemu_coroutine_enter(client->recv_coroutine, NULL);
1231    } else {
1232        qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
1233    }
1234}
1235
1236static void nbd_restart_write(void *opaque)
1237{
1238    NBDClient *client = opaque;
1239
1240    qemu_coroutine_enter(client->send_coroutine, NULL);
1241}
1242
1243NBDClient *nbd_client_new(NBDExport *exp, int csock,
1244                          void (*close)(NBDClient *))
1245{
1246    NBDClient *client;
1247    client = g_malloc0(sizeof(NBDClient));
1248    client->refcount = 1;
1249    client->exp = exp;
1250    client->sock = csock;
1251    if (nbd_send_negotiate(client) < 0) {
1252        g_free(client);
1253        return NULL;
1254    }
1255    client->close = close;
1256    qemu_co_mutex_init(&client->send_lock);
1257    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
1258
1259    if (exp) {
1260        QTAILQ_INSERT_TAIL(&exp->clients, client, next);
1261        nbd_export_get(exp);
1262    }
1263    return client;
1264}
1265