qemu/nbd.c
<<
>>
Prefs
   1/*
   2 *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
   3 *
   4 *  Network Block Device
   5 *
   6 *  This program is free software; you can redistribute it and/or modify
   7 *  it under the terms of the GNU General Public License as published by
   8 *  the Free Software Foundation; under version 2 of the License.
   9 *
  10 *  This program is distributed in the hope that it will be useful,
  11 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
  12 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  13 *  GNU General Public License for more details.
  14 *
  15 *  You should have received a copy of the GNU General Public License
  16 *  along with this program; if not, see <http://www.gnu.org/licenses/>.
  17 */
  18
  19#include "block/nbd.h"
  20#include "block/block.h"
  21
  22#include "block/coroutine.h"
  23
  24#include <errno.h>
  25#include <string.h>
  26#ifndef _WIN32
  27#include <sys/ioctl.h>
  28#endif
  29#if defined(__sun__) || defined(__HAIKU__)
  30#include <sys/ioccom.h>
  31#endif
  32#include <ctype.h>
  33#include <inttypes.h>
  34
  35#ifdef __linux__
  36#include <linux/fs.h>
  37#endif
  38
  39#include "qemu/sockets.h"
  40#include "qemu/queue.h"
  41
  42//#define DEBUG_NBD
  43
  44#ifdef DEBUG_NBD
  45#define TRACE(msg, ...) do { \
  46    LOG(msg, ## __VA_ARGS__); \
  47} while(0)
  48#else
  49#define TRACE(msg, ...) \
  50    do { } while (0)
  51#endif
  52
  53#define LOG(msg, ...) do { \
  54    fprintf(stderr, "%s:%s():L%d: " msg "\n", \
  55            __FILE__, __FUNCTION__, __LINE__, ## __VA_ARGS__); \
  56} while(0)
  57
/* This is all part of the "official" NBD API */

/* Wire sizes, in bytes, of a request header (magic, type, handle, from,
 * len) and of a reply header (magic, error, handle). */
#define NBD_REQUEST_SIZE        (4 + 4 + 8 + 8 + 4)
#define NBD_REPLY_SIZE          (4 + 4 + 8)
/* Magic values written at offset 0 of request and reply headers. */
#define NBD_REQUEST_MAGIC       0x25609513
#define NBD_REPLY_MAGIC         0x67446698
/* Handshake magics sent after "NBDMAGIC": NBD_OPTS_MAGIC starts option
 * negotiation, NBD_CLIENT_MAGIC introduces the single-shot header used
 * when the export is already known (see nbd_send_negotiate()). */
#define NBD_OPTS_MAGIC          0x49484156454F5054LL
#define NBD_CLIENT_MAGIC        0x0000420281861253LL

/* ioctl request numbers (type 0xab) for the kernel NBD driver; the ones
 * used here are issued from the __linux__-only helpers below. */
#define NBD_SET_SOCK            _IO(0xab, 0)
#define NBD_SET_BLKSIZE         _IO(0xab, 1)
#define NBD_SET_SIZE            _IO(0xab, 2)
#define NBD_DO_IT               _IO(0xab, 3)
#define NBD_CLEAR_SOCK          _IO(0xab, 4)
#define NBD_CLEAR_QUE           _IO(0xab, 5)
#define NBD_PRINT_DEBUG         _IO(0xab, 6)
#define NBD_SET_SIZE_BLOCKS     _IO(0xab, 7)
#define NBD_DISCONNECT          _IO(0xab, 8)
#define NBD_SET_TIMEOUT         _IO(0xab, 9)
#define NBD_SET_FLAGS           _IO(0xab, 10)

/* Option code that selects an export by name during option negotiation. */
#define NBD_OPT_EXPORT_NAME     (1 << 0)
  80
/* Definitions for opaque data types */

typedef struct NBDRequest NBDRequest;

/* One in-flight request; created by nbd_request_get() and released by
 * nbd_request_put(). */
struct NBDRequest {
    /* Queue linkage (not referenced by the code in this part of the file). */
    QSIMPLEQ_ENTRY(NBDRequest) entry;
    /* Owning client; nbd_request_get() takes a reference on it. */
    NBDClient *client;
    /* Payload buffer from qemu_blockalign() for READ/WRITE commands,
     * released with qemu_vfree() in nbd_request_put(); NULL otherwise. */
    uint8_t *data;
};

/* An exported block device.  Reference-counted through nbd_export_get()
 * and nbd_export_put(); freed when the count drops to zero. */
struct NBDExport {
    int refcount;
    /* Optional callback run by nbd_export_put() right before freeing. */
    void (*close)(NBDExport *exp);

    BlockDriverState *bs;
    /* Owned copy of the export name; non-NULL exactly while the export is
     * linked into the global 'exports' list (see nbd_export_set_name()). */
    char *name;
    /* Added to every client-supplied offset before accessing bs. */
    off_t dev_offset;
    /* Size advertised to clients during negotiation. */
    off_t size;
    /* Export flags advertised to clients; must fit in 16 bits. */
    uint32_t nbdflags;
    /* Clients currently attached to this export. */
    QTAILQ_HEAD(, NBDClient) clients;
    /* Linkage in the global 'exports' list. */
    QTAILQ_ENTRY(NBDExport) next;
};

/* All named exports, searched by nbd_export_find(). */
static QTAILQ_HEAD(, NBDExport) exports = QTAILQ_HEAD_INITIALIZER(exports);

/* A connected client.  Reference-counted through nbd_client_get() and
 * nbd_client_put(); every in-flight NBDRequest holds one reference. */
struct NBDClient {
    int refcount;
    /* Optional callback invoked by nbd_client_close(). */
    void (*close)(NBDClient *client);

    /* Export being served; set by option negotiation when not preset. */
    NBDExport *exp;
    int sock;

    /* Coroutine currently inside nbd_co_receive_request(), or NULL. */
    Coroutine *recv_coroutine;

    /* Serialises replies; send_coroutine is the holder while it sends. */
    CoMutex send_lock;
    Coroutine *send_coroutine;

    /* Linkage in exp->clients. */
    QTAILQ_ENTRY(NBDClient) next;
    /* Requests in flight; bounded by MAX_NBD_REQUESTS. */
    int nb_requests;
    /* Set once by nbd_client_close(), which also shuts the socket down. */
    bool closing;
};
 122
 123/* That's all folks */
 124
 125ssize_t nbd_wr_sync(int fd, void *buffer, size_t size, bool do_read)
 126{
 127    size_t offset = 0;
 128    int err;
 129
 130    if (qemu_in_coroutine()) {
 131        if (do_read) {
 132            return qemu_co_recv(fd, buffer, size);
 133        } else {
 134            return qemu_co_send(fd, buffer, size);
 135        }
 136    }
 137
 138    while (offset < size) {
 139        ssize_t len;
 140
 141        if (do_read) {
 142            len = qemu_recv(fd, buffer + offset, size - offset, 0);
 143        } else {
 144            len = send(fd, buffer + offset, size - offset, 0);
 145        }
 146
 147        if (len < 0) {
 148            err = socket_error();
 149
 150            /* recoverable error */
 151            if (err == EINTR || (offset > 0 && err == EAGAIN)) {
 152                continue;
 153            }
 154
 155            /* unrecoverable error */
 156            return -err;
 157        }
 158
 159        /* eof */
 160        if (len == 0) {
 161            break;
 162        }
 163
 164        offset += len;
 165    }
 166
 167    return offset;
 168}
 169
 170static ssize_t read_sync(int fd, void *buffer, size_t size)
 171{
 172    /* Sockets are kept in blocking mode in the negotiation phase.  After
 173     * that, a non-readable socket simply means that another thread stole
 174     * our request/reply.  Synchronization is done with recv_coroutine, so
 175     * that this is coroutine-safe.
 176     */
 177    return nbd_wr_sync(fd, buffer, size, true);
 178}
 179
 180static ssize_t write_sync(int fd, void *buffer, size_t size)
 181{
 182    int ret;
 183    do {
 184        /* For writes, we do expect the socket to be writable.  */
 185        ret = nbd_wr_sync(fd, buffer, size, false);
 186    } while (ret == -EAGAIN);
 187    return ret;
 188}
 189
 190static void combine_addr(char *buf, size_t len, const char* address,
 191                         uint16_t port)
 192{
 193    /* If the address-part contains a colon, it's an IPv6 IP so needs [] */
 194    if (strstr(address, ":")) {
 195        snprintf(buf, len, "[%s]:%u", address, port);
 196    } else {
 197        snprintf(buf, len, "%s:%u", address, port);
 198    }
 199}
 200
 201int tcp_socket_outgoing_opts(QemuOpts *opts)
 202{
 203    Error *local_err = NULL;
 204    int fd = inet_connect_opts(opts, &local_err, NULL, NULL);
 205    if (local_err != NULL) {
 206        qerror_report_err(local_err);
 207        error_free(local_err);
 208    }
 209
 210    return fd;
 211}
 212
 213int tcp_socket_incoming(const char *address, uint16_t port)
 214{
 215    char address_and_port[128];
 216    combine_addr(address_and_port, 128, address, port);
 217    return tcp_socket_incoming_spec(address_and_port);
 218}
 219
 220int tcp_socket_incoming_spec(const char *address_and_port)
 221{
 222    Error *local_err = NULL;
 223    int fd = inet_listen(address_and_port, NULL, 0, SOCK_STREAM, 0, &local_err);
 224
 225    if (local_err != NULL) {
 226        qerror_report_err(local_err);
 227        error_free(local_err);
 228    }
 229    return fd;
 230}
 231
 232int unix_socket_incoming(const char *path)
 233{
 234    Error *local_err = NULL;
 235    int fd = unix_listen(path, NULL, 0, &local_err);
 236
 237    if (local_err != NULL) {
 238        qerror_report_err(local_err);
 239        error_free(local_err);
 240    }
 241    return fd;
 242}
 243
 244int unix_socket_outgoing(const char *path)
 245{
 246    Error *local_err = NULL;
 247    int fd = unix_connect(path, &local_err);
 248
 249    if (local_err != NULL) {
 250        qerror_report_err(local_err);
 251        error_free(local_err);
 252    }
 253    return fd;
 254}
 255
 256/* Basic flow for negotiation
 257
 258   Server         Client
 259   Negotiate
 260
 261   or
 262
 263   Server         Client
 264   Negotiate #1
 265                  Option
 266   Negotiate #2
 267
 268   ----
 269
 270   followed by
 271
 272   Server         Client
 273                  Request
 274   Response
 275                  Request
 276   Response
 277                  ...
 278   ...
 279                  Request (type == 2)
 280
 281*/
 282
 283static int nbd_receive_options(NBDClient *client)
 284{
 285    int csock = client->sock;
 286    char name[256];
 287    uint32_t tmp, length;
 288    uint64_t magic;
 289    int rc;
 290
 291    /* Client sends:
 292        [ 0 ..   3]   reserved (0)
 293        [ 4 ..  11]   NBD_OPTS_MAGIC
 294        [12 ..  15]   NBD_OPT_EXPORT_NAME
 295        [16 ..  19]   length
 296        [20 ..  xx]   export name (length bytes)
 297     */
 298
 299    rc = -EINVAL;
 300    if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 301        LOG("read failed");
 302        goto fail;
 303    }
 304    TRACE("Checking reserved");
 305    if (tmp != 0) {
 306        LOG("Bad reserved received");
 307        goto fail;
 308    }
 309
 310    if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
 311        LOG("read failed");
 312        goto fail;
 313    }
 314    TRACE("Checking reserved");
 315    if (magic != be64_to_cpu(NBD_OPTS_MAGIC)) {
 316        LOG("Bad magic received");
 317        goto fail;
 318    }
 319
 320    if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 321        LOG("read failed");
 322        goto fail;
 323    }
 324    TRACE("Checking option");
 325    if (tmp != be32_to_cpu(NBD_OPT_EXPORT_NAME)) {
 326        LOG("Bad option received");
 327        goto fail;
 328    }
 329
 330    if (read_sync(csock, &length, sizeof(length)) != sizeof(length)) {
 331        LOG("read failed");
 332        goto fail;
 333    }
 334    TRACE("Checking length");
 335    length = be32_to_cpu(length);
 336    if (length > 255) {
 337        LOG("Bad length received");
 338        goto fail;
 339    }
 340    if (read_sync(csock, name, length) != length) {
 341        LOG("read failed");
 342        goto fail;
 343    }
 344    name[length] = '\0';
 345
 346    client->exp = nbd_export_find(name);
 347    if (!client->exp) {
 348        LOG("export not found");
 349        goto fail;
 350    }
 351
 352    QTAILQ_INSERT_TAIL(&client->exp->clients, client, next);
 353    nbd_export_get(client->exp);
 354
 355    TRACE("Option negotiation succeeded.");
 356    rc = 0;
 357fail:
 358    return rc;
 359}
 360
 361static int nbd_send_negotiate(NBDClient *client)
 362{
 363    int csock = client->sock;
 364    char buf[8 + 8 + 8 + 128];
 365    int rc;
 366    const int myflags = (NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
 367                         NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
 368
 369    /* Negotiation header without options:
 370        [ 0 ..   7]   passwd       ("NBDMAGIC")
 371        [ 8 ..  15]   magic        (NBD_CLIENT_MAGIC)
 372        [16 ..  23]   size
 373        [24 ..  25]   server flags (0)
 374        [24 ..  27]   export flags
 375        [28 .. 151]   reserved     (0)
 376
 377       Negotiation header with options, part 1:
 378        [ 0 ..   7]   passwd       ("NBDMAGIC")
 379        [ 8 ..  15]   magic        (NBD_OPTS_MAGIC)
 380        [16 ..  17]   server flags (0)
 381
 382       part 2 (after options are sent):
 383        [18 ..  25]   size
 384        [26 ..  27]   export flags
 385        [28 .. 151]   reserved     (0)
 386     */
 387
 388    qemu_set_block(csock);
 389    rc = -EINVAL;
 390
 391    TRACE("Beginning negotiation.");
 392    memset(buf, 0, sizeof(buf));
 393    memcpy(buf, "NBDMAGIC", 8);
 394    if (client->exp) {
 395        assert ((client->exp->nbdflags & ~65535) == 0);
 396        cpu_to_be64w((uint64_t*)(buf + 8), NBD_CLIENT_MAGIC);
 397        cpu_to_be64w((uint64_t*)(buf + 16), client->exp->size);
 398        cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
 399    } else {
 400        cpu_to_be64w((uint64_t*)(buf + 8), NBD_OPTS_MAGIC);
 401    }
 402
 403    if (client->exp) {
 404        if (write_sync(csock, buf, sizeof(buf)) != sizeof(buf)) {
 405            LOG("write failed");
 406            goto fail;
 407        }
 408    } else {
 409        if (write_sync(csock, buf, 18) != 18) {
 410            LOG("write failed");
 411            goto fail;
 412        }
 413        rc = nbd_receive_options(client);
 414        if (rc < 0) {
 415            LOG("option negotiation failed");
 416            goto fail;
 417        }
 418
 419        assert ((client->exp->nbdflags & ~65535) == 0);
 420        cpu_to_be64w((uint64_t*)(buf + 18), client->exp->size);
 421        cpu_to_be16w((uint16_t*)(buf + 26), client->exp->nbdflags | myflags);
 422        if (write_sync(csock, buf + 18, sizeof(buf) - 18) != sizeof(buf) - 18) {
 423            LOG("write failed");
 424            goto fail;
 425        }
 426    }
 427
 428    TRACE("Negotiation succeeded.");
 429    rc = 0;
 430fail:
 431    qemu_set_nonblock(csock);
 432    return rc;
 433}
 434
 435int nbd_receive_negotiate(int csock, const char *name, uint32_t *flags,
 436                          off_t *size, size_t *blocksize)
 437{
 438    char buf[256];
 439    uint64_t magic, s;
 440    uint16_t tmp;
 441    int rc;
 442
 443    TRACE("Receiving negotiation.");
 444
 445    qemu_set_block(csock);
 446    rc = -EINVAL;
 447
 448    if (read_sync(csock, buf, 8) != 8) {
 449        LOG("read failed");
 450        goto fail;
 451    }
 452
 453    buf[8] = '\0';
 454    if (strlen(buf) == 0) {
 455        LOG("server connection closed");
 456        goto fail;
 457    }
 458
 459    TRACE("Magic is %c%c%c%c%c%c%c%c",
 460          qemu_isprint(buf[0]) ? buf[0] : '.',
 461          qemu_isprint(buf[1]) ? buf[1] : '.',
 462          qemu_isprint(buf[2]) ? buf[2] : '.',
 463          qemu_isprint(buf[3]) ? buf[3] : '.',
 464          qemu_isprint(buf[4]) ? buf[4] : '.',
 465          qemu_isprint(buf[5]) ? buf[5] : '.',
 466          qemu_isprint(buf[6]) ? buf[6] : '.',
 467          qemu_isprint(buf[7]) ? buf[7] : '.');
 468
 469    if (memcmp(buf, "NBDMAGIC", 8) != 0) {
 470        LOG("Invalid magic received");
 471        goto fail;
 472    }
 473
 474    if (read_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
 475        LOG("read failed");
 476        goto fail;
 477    }
 478    magic = be64_to_cpu(magic);
 479    TRACE("Magic is 0x%" PRIx64, magic);
 480
 481    if (name) {
 482        uint32_t reserved = 0;
 483        uint32_t opt;
 484        uint32_t namesize;
 485
 486        TRACE("Checking magic (opts_magic)");
 487        if (magic != NBD_OPTS_MAGIC) {
 488            LOG("Bad magic received");
 489            goto fail;
 490        }
 491        if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 492            LOG("flags read failed");
 493            goto fail;
 494        }
 495        *flags = be16_to_cpu(tmp) << 16;
 496        /* reserved for future use */
 497        if (write_sync(csock, &reserved, sizeof(reserved)) !=
 498            sizeof(reserved)) {
 499            LOG("write failed (reserved)");
 500            goto fail;
 501        }
 502        /* write the export name */
 503        magic = cpu_to_be64(magic);
 504        if (write_sync(csock, &magic, sizeof(magic)) != sizeof(magic)) {
 505            LOG("write failed (magic)");
 506            goto fail;
 507        }
 508        opt = cpu_to_be32(NBD_OPT_EXPORT_NAME);
 509        if (write_sync(csock, &opt, sizeof(opt)) != sizeof(opt)) {
 510            LOG("write failed (opt)");
 511            goto fail;
 512        }
 513        namesize = cpu_to_be32(strlen(name));
 514        if (write_sync(csock, &namesize, sizeof(namesize)) !=
 515            sizeof(namesize)) {
 516            LOG("write failed (namesize)");
 517            goto fail;
 518        }
 519        if (write_sync(csock, (char*)name, strlen(name)) != strlen(name)) {
 520            LOG("write failed (name)");
 521            goto fail;
 522        }
 523    } else {
 524        TRACE("Checking magic (cli_magic)");
 525
 526        if (magic != NBD_CLIENT_MAGIC) {
 527            LOG("Bad magic received");
 528            goto fail;
 529        }
 530    }
 531
 532    if (read_sync(csock, &s, sizeof(s)) != sizeof(s)) {
 533        LOG("read failed");
 534        goto fail;
 535    }
 536    *size = be64_to_cpu(s);
 537    *blocksize = 1024;
 538    TRACE("Size is %" PRIu64, *size);
 539
 540    if (!name) {
 541        if (read_sync(csock, flags, sizeof(*flags)) != sizeof(*flags)) {
 542            LOG("read failed (flags)");
 543            goto fail;
 544        }
 545        *flags = be32_to_cpup(flags);
 546    } else {
 547        if (read_sync(csock, &tmp, sizeof(tmp)) != sizeof(tmp)) {
 548            LOG("read failed (tmp)");
 549            goto fail;
 550        }
 551        *flags |= be32_to_cpu(tmp);
 552    }
 553    if (read_sync(csock, &buf, 124) != 124) {
 554        LOG("read failed (buf)");
 555        goto fail;
 556    }
 557    rc = 0;
 558
 559fail:
 560    qemu_set_nonblock(csock);
 561    return rc;
 562}
 563
 564#ifdef __linux__
 565int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
 566{
 567    TRACE("Setting NBD socket");
 568
 569    if (ioctl(fd, NBD_SET_SOCK, csock) < 0) {
 570        int serrno = errno;
 571        LOG("Failed to set NBD socket");
 572        return -serrno;
 573    }
 574
 575    TRACE("Setting block size to %lu", (unsigned long)blocksize);
 576
 577    if (ioctl(fd, NBD_SET_BLKSIZE, blocksize) < 0) {
 578        int serrno = errno;
 579        LOG("Failed setting NBD block size");
 580        return -serrno;
 581    }
 582
 583        TRACE("Setting size to %zd block(s)", (size_t)(size / blocksize));
 584
 585    if (ioctl(fd, NBD_SET_SIZE_BLOCKS, size / blocksize) < 0) {
 586        int serrno = errno;
 587        LOG("Failed setting size (in blocks)");
 588        return -serrno;
 589    }
 590
 591    if (ioctl(fd, NBD_SET_FLAGS, flags) < 0) {
 592        if (errno == ENOTTY) {
 593            int read_only = (flags & NBD_FLAG_READ_ONLY) != 0;
 594            TRACE("Setting readonly attribute");
 595
 596            if (ioctl(fd, BLKROSET, (unsigned long) &read_only) < 0) {
 597                int serrno = errno;
 598                LOG("Failed setting read-only attribute");
 599                return -serrno;
 600            }
 601        } else {
 602            int serrno = errno;
 603            LOG("Failed setting flags");
 604            return -serrno;
 605        }
 606    }
 607
 608    TRACE("Negotiation ended");
 609
 610    return 0;
 611}
 612
 613int nbd_disconnect(int fd)
 614{
 615    ioctl(fd, NBD_CLEAR_QUE);
 616    ioctl(fd, NBD_DISCONNECT);
 617    ioctl(fd, NBD_CLEAR_SOCK);
 618    return 0;
 619}
 620
 621int nbd_client(int fd)
 622{
 623    int ret;
 624    int serrno;
 625
 626    TRACE("Doing NBD loop");
 627
 628    ret = ioctl(fd, NBD_DO_IT);
 629    if (ret < 0 && errno == EPIPE) {
 630        /* NBD_DO_IT normally returns EPIPE when someone has disconnected
 631         * the socket via NBD_DISCONNECT.  We do not want to return 1 in
 632         * that case.
 633         */
 634        ret = 0;
 635    }
 636    serrno = errno;
 637
 638    TRACE("NBD loop returned %d: %s", ret, strerror(serrno));
 639
 640    TRACE("Clearing NBD queue");
 641    ioctl(fd, NBD_CLEAR_QUE);
 642
 643    TRACE("Clearing NBD socket");
 644    ioctl(fd, NBD_CLEAR_SOCK);
 645
 646    errno = serrno;
 647    return ret;
 648}
 649#else
 650int nbd_init(int fd, int csock, uint32_t flags, off_t size, size_t blocksize)
 651{
 652    return -ENOTSUP;
 653}
 654
 655int nbd_disconnect(int fd)
 656{
 657    return -ENOTSUP;
 658}
 659
 660int nbd_client(int fd)
 661{
 662    return -ENOTSUP;
 663}
 664#endif
 665
 666ssize_t nbd_send_request(int csock, struct nbd_request *request)
 667{
 668    uint8_t buf[NBD_REQUEST_SIZE];
 669    ssize_t ret;
 670
 671    cpu_to_be32w((uint32_t*)buf, NBD_REQUEST_MAGIC);
 672    cpu_to_be32w((uint32_t*)(buf + 4), request->type);
 673    cpu_to_be64w((uint64_t*)(buf + 8), request->handle);
 674    cpu_to_be64w((uint64_t*)(buf + 16), request->from);
 675    cpu_to_be32w((uint32_t*)(buf + 24), request->len);
 676
 677    TRACE("Sending request to client: "
 678          "{ .from = %" PRIu64", .len = %u, .handle = %" PRIu64", .type=%i}",
 679          request->from, request->len, request->handle, request->type);
 680
 681    ret = write_sync(csock, buf, sizeof(buf));
 682    if (ret < 0) {
 683        return ret;
 684    }
 685
 686    if (ret != sizeof(buf)) {
 687        LOG("writing to socket failed");
 688        return -EINVAL;
 689    }
 690    return 0;
 691}
 692
 693static ssize_t nbd_receive_request(int csock, struct nbd_request *request)
 694{
 695    uint8_t buf[NBD_REQUEST_SIZE];
 696    uint32_t magic;
 697    ssize_t ret;
 698
 699    ret = read_sync(csock, buf, sizeof(buf));
 700    if (ret < 0) {
 701        return ret;
 702    }
 703
 704    if (ret != sizeof(buf)) {
 705        LOG("read failed");
 706        return -EINVAL;
 707    }
 708
 709    /* Request
 710       [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
 711       [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
 712       [ 8 .. 15]   handle
 713       [16 .. 23]   from
 714       [24 .. 27]   len
 715     */
 716
 717    magic = be32_to_cpup((uint32_t*)buf);
 718    request->type  = be32_to_cpup((uint32_t*)(buf + 4));
 719    request->handle = be64_to_cpup((uint64_t*)(buf + 8));
 720    request->from  = be64_to_cpup((uint64_t*)(buf + 16));
 721    request->len   = be32_to_cpup((uint32_t*)(buf + 24));
 722
 723    TRACE("Got request: "
 724          "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
 725          magic, request->type, request->from, request->len);
 726
 727    if (magic != NBD_REQUEST_MAGIC) {
 728        LOG("invalid magic (got 0x%x)", magic);
 729        return -EINVAL;
 730    }
 731    return 0;
 732}
 733
 734ssize_t nbd_receive_reply(int csock, struct nbd_reply *reply)
 735{
 736    uint8_t buf[NBD_REPLY_SIZE];
 737    uint32_t magic;
 738    ssize_t ret;
 739
 740    ret = read_sync(csock, buf, sizeof(buf));
 741    if (ret < 0) {
 742        return ret;
 743    }
 744
 745    if (ret != sizeof(buf)) {
 746        LOG("read failed");
 747        return -EINVAL;
 748    }
 749
 750    /* Reply
 751       [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
 752       [ 4 ..  7]    error   (0 == no error)
 753       [ 7 .. 15]    handle
 754     */
 755
 756    magic = be32_to_cpup((uint32_t*)buf);
 757    reply->error  = be32_to_cpup((uint32_t*)(buf + 4));
 758    reply->handle = be64_to_cpup((uint64_t*)(buf + 8));
 759
 760    TRACE("Got reply: "
 761          "{ magic = 0x%x, .error = %d, handle = %" PRIu64" }",
 762          magic, reply->error, reply->handle);
 763
 764    if (magic != NBD_REPLY_MAGIC) {
 765        LOG("invalid magic (got 0x%x)", magic);
 766        return -EINVAL;
 767    }
 768    return 0;
 769}
 770
 771static ssize_t nbd_send_reply(int csock, struct nbd_reply *reply)
 772{
 773    uint8_t buf[NBD_REPLY_SIZE];
 774    ssize_t ret;
 775
 776    /* Reply
 777       [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
 778       [ 4 ..  7]    error   (0 == no error)
 779       [ 7 .. 15]    handle
 780     */
 781    cpu_to_be32w((uint32_t*)buf, NBD_REPLY_MAGIC);
 782    cpu_to_be32w((uint32_t*)(buf + 4), reply->error);
 783    cpu_to_be64w((uint64_t*)(buf + 8), reply->handle);
 784
 785    TRACE("Sending response to client");
 786
 787    ret = write_sync(csock, buf, sizeof(buf));
 788    if (ret < 0) {
 789        return ret;
 790    }
 791
 792    if (ret != sizeof(buf)) {
 793        LOG("writing to socket failed");
 794        return -EINVAL;
 795    }
 796    return 0;
 797}
 798
/* Upper bound on a client's in-flight requests (asserted in
 * nbd_request_get(); nbd_request_put() calls qemu_notify_event() when a
 * client drops back below it). */
#define MAX_NBD_REQUESTS 16
 800
 801void nbd_client_get(NBDClient *client)
 802{
 803    client->refcount++;
 804}
 805
 806void nbd_client_put(NBDClient *client)
 807{
 808    if (--client->refcount == 0) {
 809        /* The last reference should be dropped by client->close,
 810         * which is called by nbd_client_close.
 811         */
 812        assert(client->closing);
 813
 814        qemu_set_fd_handler2(client->sock, NULL, NULL, NULL, NULL);
 815        close(client->sock);
 816        client->sock = -1;
 817        if (client->exp) {
 818            QTAILQ_REMOVE(&client->exp->clients, client, next);
 819            nbd_export_put(client->exp);
 820        }
 821        g_free(client);
 822    }
 823}
 824
 825void nbd_client_close(NBDClient *client)
 826{
 827    if (client->closing) {
 828        return;
 829    }
 830
 831    client->closing = true;
 832
 833    /* Force requests to finish.  They will drop their own references,
 834     * then we'll close the socket and free the NBDClient.
 835     */
 836    shutdown(client->sock, 2);
 837
 838    /* Also tell the client, so that they release their reference.  */
 839    if (client->close) {
 840        client->close(client);
 841    }
 842}
 843
 844static NBDRequest *nbd_request_get(NBDClient *client)
 845{
 846    NBDRequest *req;
 847
 848    assert(client->nb_requests <= MAX_NBD_REQUESTS - 1);
 849    client->nb_requests++;
 850
 851    req = g_slice_new0(NBDRequest);
 852    nbd_client_get(client);
 853    req->client = client;
 854    return req;
 855}
 856
 857static void nbd_request_put(NBDRequest *req)
 858{
 859    NBDClient *client = req->client;
 860
 861    if (req->data) {
 862        qemu_vfree(req->data);
 863    }
 864    g_slice_free(NBDRequest, req);
 865
 866    if (client->nb_requests-- == MAX_NBD_REQUESTS) {
 867        qemu_notify_event();
 868    }
 869    nbd_client_put(client);
 870}
 871
 872NBDExport *nbd_export_new(BlockDriverState *bs, off_t dev_offset,
 873                          off_t size, uint32_t nbdflags,
 874                          void (*close)(NBDExport *))
 875{
 876    NBDExport *exp = g_malloc0(sizeof(NBDExport));
 877    exp->refcount = 1;
 878    QTAILQ_INIT(&exp->clients);
 879    exp->bs = bs;
 880    exp->dev_offset = dev_offset;
 881    exp->nbdflags = nbdflags;
 882    exp->size = size == -1 ? bdrv_getlength(bs) : size;
 883    exp->close = close;
 884    return exp;
 885}
 886
 887NBDExport *nbd_export_find(const char *name)
 888{
 889    NBDExport *exp;
 890    QTAILQ_FOREACH(exp, &exports, next) {
 891        if (strcmp(name, exp->name) == 0) {
 892            return exp;
 893        }
 894    }
 895
 896    return NULL;
 897}
 898
 899void nbd_export_set_name(NBDExport *exp, const char *name)
 900{
 901    if (exp->name == name) {
 902        return;
 903    }
 904
 905    nbd_export_get(exp);
 906    if (exp->name != NULL) {
 907        g_free(exp->name);
 908        exp->name = NULL;
 909        QTAILQ_REMOVE(&exports, exp, next);
 910        nbd_export_put(exp);
 911    }
 912    if (name != NULL) {
 913        nbd_export_get(exp);
 914        exp->name = g_strdup(name);
 915        QTAILQ_INSERT_TAIL(&exports, exp, next);
 916    }
 917    nbd_export_put(exp);
 918}
 919
 920void nbd_export_close(NBDExport *exp)
 921{
 922    NBDClient *client, *next;
 923
 924    nbd_export_get(exp);
 925    QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
 926        nbd_client_close(client);
 927    }
 928    nbd_export_set_name(exp, NULL);
 929    nbd_export_put(exp);
 930}
 931
 932void nbd_export_get(NBDExport *exp)
 933{
 934    assert(exp->refcount > 0);
 935    exp->refcount++;
 936}
 937
 938void nbd_export_put(NBDExport *exp)
 939{
 940    assert(exp->refcount > 0);
 941    if (exp->refcount == 1) {
 942        nbd_export_close(exp);
 943    }
 944
 945    if (--exp->refcount == 0) {
 946        assert(exp->name == NULL);
 947
 948        if (exp->close) {
 949            exp->close(exp);
 950        }
 951
 952        g_free(exp);
 953    }
 954}
 955
 956BlockDriverState *nbd_export_get_blockdev(NBDExport *exp)
 957{
 958    return exp->bs;
 959}
 960
 961void nbd_export_close_all(void)
 962{
 963    NBDExport *exp, *next;
 964
 965    QTAILQ_FOREACH_SAFE(exp, &exports, next, next) {
 966        nbd_export_close(exp);
 967    }
 968}
 969
/* fd-handler callbacks installed by nbd_co_send_reply(); they are static,
 * so they are defined elsewhere in this file. */
static int nbd_can_read(void *opaque);
static void nbd_read(void *opaque);
static void nbd_restart_write(void *opaque);
 973
 974static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
 975                                 int len)
 976{
 977    NBDClient *client = req->client;
 978    int csock = client->sock;
 979    ssize_t rc, ret;
 980
 981    qemu_co_mutex_lock(&client->send_lock);
 982    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read,
 983                         nbd_restart_write, client);
 984    client->send_coroutine = qemu_coroutine_self();
 985
 986    if (!len) {
 987        rc = nbd_send_reply(csock, reply);
 988    } else {
 989        socket_set_cork(csock, 1);
 990        rc = nbd_send_reply(csock, reply);
 991        if (rc >= 0) {
 992            ret = qemu_co_send(csock, req->data, len);
 993            if (ret != len) {
 994                rc = -EIO;
 995            }
 996        }
 997        socket_set_cork(csock, 0);
 998    }
 999
1000    client->send_coroutine = NULL;
1001    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
1002    qemu_co_mutex_unlock(&client->send_lock);
1003    return rc;
1004}
1005
1006static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
1007{
1008    NBDClient *client = req->client;
1009    int csock = client->sock;
1010    uint32_t command;
1011    ssize_t rc;
1012
1013    client->recv_coroutine = qemu_coroutine_self();
1014    rc = nbd_receive_request(csock, request);
1015    if (rc < 0) {
1016        if (rc != -EAGAIN) {
1017            rc = -EIO;
1018        }
1019        goto out;
1020    }
1021
1022    if (request->len > NBD_MAX_BUFFER_SIZE) {
1023        LOG("len (%u) is larger than max len (%u)",
1024            request->len, NBD_MAX_BUFFER_SIZE);
1025        rc = -EINVAL;
1026        goto out;
1027    }
1028
1029    if ((request->from + request->len) < request->from) {
1030        LOG("integer overflow detected! "
1031            "you're probably being attacked");
1032        rc = -EINVAL;
1033        goto out;
1034    }
1035
1036    TRACE("Decoding type");
1037
1038    command = request->type & NBD_CMD_MASK_COMMAND;
1039    if (command == NBD_CMD_READ || command == NBD_CMD_WRITE) {
1040        req->data = qemu_blockalign(client->exp->bs, request->len);
1041    }
1042    if (command == NBD_CMD_WRITE) {
1043        TRACE("Reading %u byte(s)", request->len);
1044
1045        if (qemu_co_recv(csock, req->data, request->len) != request->len) {
1046            LOG("reading from socket failed");
1047            rc = -EIO;
1048            goto out;
1049        }
1050    }
1051    rc = 0;
1052
1053out:
1054    client->recv_coroutine = NULL;
1055    return rc;
1056}
1057
1058static void nbd_trip(void *opaque)
1059{
1060    NBDClient *client = opaque;
1061    NBDExport *exp = client->exp;
1062    NBDRequest *req;
1063    struct nbd_request request;
1064    struct nbd_reply reply;
1065    ssize_t ret;
1066
1067    TRACE("Reading request.");
1068    if (client->closing) {
1069        return;
1070    }
1071
1072    req = nbd_request_get(client);
1073    ret = nbd_co_receive_request(req, &request);
1074    if (ret == -EAGAIN) {
1075        goto done;
1076    }
1077    if (ret == -EIO) {
1078        goto out;
1079    }
1080
1081    reply.handle = request.handle;
1082    reply.error = 0;
1083
1084    if (ret < 0) {
1085        reply.error = -ret;
1086        goto error_reply;
1087    }
1088
1089    if ((request.from + request.len) > exp->size) {
1090            LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
1091            ", Offset: %" PRIu64 "\n",
1092                    request.from, request.len,
1093                    (uint64_t)exp->size, (uint64_t)exp->dev_offset);
1094        LOG("requested operation past EOF--bad client?");
1095        goto invalid_request;
1096    }
1097
1098    switch (request.type & NBD_CMD_MASK_COMMAND) {
1099    case NBD_CMD_READ:
1100        TRACE("Request type is READ");
1101
1102        if (request.type & NBD_CMD_FLAG_FUA) {
1103            ret = bdrv_co_flush(exp->bs);
1104            if (ret < 0) {
1105                LOG("flush failed");
1106                reply.error = -ret;
1107                goto error_reply;
1108            }
1109        }
1110
1111        ret = bdrv_read(exp->bs, (request.from + exp->dev_offset) / 512,
1112                        req->data, request.len / 512);
1113        if (ret < 0) {
1114            LOG("reading from file failed");
1115            reply.error = -ret;
1116            goto error_reply;
1117        }
1118
1119        TRACE("Read %u byte(s)", request.len);
1120        if (nbd_co_send_reply(req, &reply, request.len) < 0)
1121            goto out;
1122        break;
1123    case NBD_CMD_WRITE:
1124        TRACE("Request type is WRITE");
1125
1126        if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
1127            TRACE("Server is read-only, return error");
1128            reply.error = EROFS;
1129            goto error_reply;
1130        }
1131
1132        TRACE("Writing to device");
1133
1134        ret = bdrv_write(exp->bs, (request.from + exp->dev_offset) / 512,
1135                         req->data, request.len / 512);
1136        if (ret < 0) {
1137            LOG("writing to file failed");
1138            reply.error = -ret;
1139            goto error_reply;
1140        }
1141
1142        if (request.type & NBD_CMD_FLAG_FUA) {
1143            ret = bdrv_co_flush(exp->bs);
1144            if (ret < 0) {
1145                LOG("flush failed");
1146                reply.error = -ret;
1147                goto error_reply;
1148            }
1149        }
1150
1151        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1152            goto out;
1153        }
1154        break;
1155    case NBD_CMD_DISC:
1156        TRACE("Request type is DISCONNECT");
1157        errno = 0;
1158        goto out;
1159    case NBD_CMD_FLUSH:
1160        TRACE("Request type is FLUSH");
1161
1162        ret = bdrv_co_flush(exp->bs);
1163        if (ret < 0) {
1164            LOG("flush failed");
1165            reply.error = -ret;
1166        }
1167        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1168            goto out;
1169        }
1170        break;
1171    case NBD_CMD_TRIM:
1172        TRACE("Request type is TRIM");
1173        ret = bdrv_co_discard(exp->bs, (request.from + exp->dev_offset) / 512,
1174                              request.len / 512);
1175        if (ret < 0) {
1176            LOG("discard failed");
1177            reply.error = -ret;
1178        }
1179        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1180            goto out;
1181        }
1182        break;
1183    default:
1184        LOG("invalid request type (%u) received", request.type);
1185    invalid_request:
1186        reply.error = -EINVAL;
1187    error_reply:
1188        if (nbd_co_send_reply(req, &reply, 0) < 0) {
1189            goto out;
1190        }
1191        break;
1192    }
1193
1194    TRACE("Request/Reply complete");
1195
1196done:
1197    nbd_request_put(req);
1198    return;
1199
1200out:
1201    nbd_request_put(req);
1202    nbd_client_close(client);
1203}
1204
1205static int nbd_can_read(void *opaque)
1206{
1207    NBDClient *client = opaque;
1208
1209    return client->recv_coroutine || client->nb_requests < MAX_NBD_REQUESTS;
1210}
1211
1212static void nbd_read(void *opaque)
1213{
1214    NBDClient *client = opaque;
1215
1216    if (client->recv_coroutine) {
1217        qemu_coroutine_enter(client->recv_coroutine, NULL);
1218    } else {
1219        qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
1220    }
1221}
1222
1223static void nbd_restart_write(void *opaque)
1224{
1225    NBDClient *client = opaque;
1226
1227    qemu_coroutine_enter(client->send_coroutine, NULL);
1228}
1229
1230NBDClient *nbd_client_new(NBDExport *exp, int csock,
1231                          void (*close)(NBDClient *))
1232{
1233    NBDClient *client;
1234    client = g_malloc0(sizeof(NBDClient));
1235    client->refcount = 1;
1236    client->exp = exp;
1237    client->sock = csock;
1238    if (nbd_send_negotiate(client) < 0) {
1239        g_free(client);
1240        return NULL;
1241    }
1242    client->close = close;
1243    qemu_co_mutex_init(&client->send_lock);
1244    qemu_set_fd_handler2(csock, nbd_can_read, nbd_read, NULL, client);
1245
1246    if (exp) {
1247        QTAILQ_INSERT_TAIL(&exp->clients, client, next);
1248        nbd_export_get(exp);
1249    }
1250    return client;
1251}
1252