linux/net/vmw_vsock/hyperv_transport.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Hyper-V transport for vsock
   4 *
   5 * Hyper-V Sockets supplies a byte-stream based communication mechanism
   6 * between the host and the VM. This driver implements the necessary
   7 * support in the VM by introducing the new vsock transport.
   8 *
   9 * Copyright (c) 2017, Microsoft Corporation.
  10 */
  11#include <linux/module.h>
  12#include <linux/vmalloc.h>
  13#include <linux/hyperv.h>
  14#include <net/sock.h>
  15#include <net/af_vsock.h>
  16
  17/* The host side's design of the feature requires 6 exact 4KB pages for
  18 * recv/send rings respectively -- this is suboptimal considering memory
  19 * consumption, however unluckily we have to live with it, before the
  20 * host comes up with a better design in the future.
  21 */
  22#define PAGE_SIZE_4K            4096
  23#define RINGBUFFER_HVS_RCV_SIZE (PAGE_SIZE_4K * 6)
  24#define RINGBUFFER_HVS_SND_SIZE (PAGE_SIZE_4K * 6)
  25
  26/* The MTU is 16KB per the host side's design */
  27#define HVS_MTU_SIZE            (1024 * 16)
  28
  29/* How long to wait for graceful shutdown of a connection */
  30#define HVS_CLOSE_TIMEOUT (8 * HZ)
  31
  32struct vmpipe_proto_header {
  33        u32 pkt_type;
  34        u32 data_size;
  35};
  36
  37/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
  38 * data from the ringbuffer into the userspace buffer.
  39 */
  40struct hvs_recv_buf {
  41        /* The header before the payload data */
  42        struct vmpipe_proto_header hdr;
  43
  44        /* The payload */
  45        u8 data[HVS_MTU_SIZE];
  46};
  47
  48/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use
  49 * a small size, i.e. HVS_SEND_BUF_SIZE, to minimize the dynamically-allocated
  50 * buffer, because tests show there is no significant performance difference.
  51 *
  52 * Note: the buffer can be eliminated in the future when we add new VMBus
  53 * ringbuffer APIs that allow us to directly copy data from userspace buffer
  54 * to VMBus ringbuffer.
  55 */
  56#define HVS_SEND_BUF_SIZE (PAGE_SIZE_4K - sizeof(struct vmpipe_proto_header))
  57
  58struct hvs_send_buf {
  59        /* The header before the payload data */
  60        struct vmpipe_proto_header hdr;
  61
  62        /* The payload */
  63        u8 data[HVS_SEND_BUF_SIZE];
  64};
  65
  66#define HVS_HEADER_LEN  (sizeof(struct vmpacket_descriptor) + \
  67                         sizeof(struct vmpipe_proto_header))
  68
  69/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and
  70 * __hv_pkt_iter_next().
  71 */
  72#define VMBUS_PKT_TRAILER_SIZE  (sizeof(u64))
  73
  74#define HVS_PKT_LEN(payload_len)        (HVS_HEADER_LEN + \
  75                                         ALIGN((payload_len), 8) + \
  76                                         VMBUS_PKT_TRAILER_SIZE)
  77
  78union hvs_service_id {
  79        uuid_le srv_id;
  80
  81        struct {
  82                unsigned int svm_port;
  83                unsigned char b[sizeof(uuid_le) - sizeof(unsigned int)];
  84        };
  85};
  86
  87/* Per-socket state (accessed via vsk->trans) */
  88struct hvsock {
  89        struct vsock_sock *vsk;
  90
  91        uuid_le vm_srv_id;
  92        uuid_le host_srv_id;
  93
  94        struct vmbus_channel *chan;
  95        struct vmpacket_descriptor *recv_desc;
  96
  97        /* The length of the payload not delivered to userland yet */
  98        u32 recv_data_len;
  99        /* The offset of the payload */
 100        u32 recv_data_off;
 101
 102        /* Have we sent the zero-length packet (FIN)? */
 103        bool fin_sent;
 104};
 105
 106/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
 107 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
 108 * when we write apps to connect to the host, we can only use VMADDR_CID_ANY
 109 * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we
 110 * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY
 111 * as the local cid.
 112 *
 113 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
 114 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-
 115 * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with
 116 * the below sockaddr:
 117 *
 118 * struct SOCKADDR_HV
 119 * {
 120 *    ADDRESS_FAMILY Family;
 121 *    USHORT Reserved;
 122 *    GUID VmId;
 123 *    GUID ServiceId;
 124 * };
 125 * Note: VmID is not used by Linux VM and actually it isn't transmitted via
 126 * VMBus, because here it's obvious the host and the VM can easily identify
 127 * each other. Though the VmID is useful on the host, especially in the case
 128 * of Windows container, Linux VM doesn't need it at all.
 129 *
 130 * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit
 131 * the available GUID space of SOCKADDR_HV so that we can create a mapping
 132 * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing
 133 * Hyper-V Sockets apps on the host and in Linux VM is:
 134 *
 135 ****************************************************************************
 136 * The only valid Service GUIDs, from the perspectives of both the host and *
 137 * Linux VM, that can be connected by the other end, must conform to this   *
 138 * format: <port>-facb-11e6-bd58-64006a7986d3, and the "port" must be in    *
 139 * this range [0, 0x7FFFFFFF].                                              *
 140 ****************************************************************************
 141 *
 142 * When we write apps on the host to connect(), the GUID ServiceID is used.
 143 * When we write apps in Linux VM to connect(), we only need to specify the
 144 * port and the driver will form the GUID and use that to request the host.
 145 *
 146 * From the perspective of Linux VM:
 147 * 1. the local ephemeral port (i.e. the local auto-bound port when we call
 148 * connect() without explicit bind()) is generated by __vsock_bind_stream(),
 149 * and the range is [1024, 0xFFFFFFFF).
 150 * 2. the remote ephemeral port (i.e. the auto-generated remote port for
 151 * a connect request initiated by the host's connect()) is generated by
 152 * hvs_remote_addr_init() and the range is [0x80000000, 0xFFFFFFFF).
 153 */
 154
 155#define MAX_LISTEN_PORT                 ((u32)0x7FFFFFFF)
 156#define MAX_VM_LISTEN_PORT              MAX_LISTEN_PORT
 157#define MAX_HOST_LISTEN_PORT            MAX_LISTEN_PORT
 158#define MIN_HOST_EPHEMERAL_PORT         (MAX_HOST_LISTEN_PORT + 1)
 159
 160/* 00000000-facb-11e6-bd58-64006a7986d3 */
 161static const uuid_le srv_id_template =
 162        UUID_LE(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
 163                0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
 164
 165static bool is_valid_srv_id(const uuid_le *id)
 166{
 167        return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(uuid_le) - 4);
 168}
 169
 170static unsigned int get_port_by_srv_id(const uuid_le *svr_id)
 171{
 172        return *((unsigned int *)svr_id);
 173}
 174
 175static void hvs_addr_init(struct sockaddr_vm *addr, const uuid_le *svr_id)
 176{
 177        unsigned int port = get_port_by_srv_id(svr_id);
 178
 179        vsock_addr_init(addr, VMADDR_CID_ANY, port);
 180}
 181
 182static void hvs_remote_addr_init(struct sockaddr_vm *remote,
 183                                 struct sockaddr_vm *local)
 184{
 185        static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
 186        struct sock *sk;
 187
 188        vsock_addr_init(remote, VMADDR_CID_ANY, VMADDR_PORT_ANY);
 189
 190        while (1) {
 191                /* Wrap around ? */
 192                if (host_ephemeral_port < MIN_HOST_EPHEMERAL_PORT ||
 193                    host_ephemeral_port == VMADDR_PORT_ANY)
 194                        host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
 195
 196                remote->svm_port = host_ephemeral_port++;
 197
 198                sk = vsock_find_connected_socket(remote, local);
 199                if (!sk) {
 200                        /* Found an available ephemeral port */
 201                        return;
 202                }
 203
 204                /* Release refcnt got in vsock_find_connected_socket */
 205                sock_put(sk);
 206        }
 207}
 208
 209static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
 210{
 211        set_channel_pending_send_size(chan,
 212                                      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));
 213
 214        virt_mb();
 215}
 216
 217static bool hvs_channel_readable(struct vmbus_channel *chan)
 218{
 219        u32 readable = hv_get_bytes_to_read(&chan->inbound);
 220
 221        /* 0-size payload means FIN */
 222        return readable >= HVS_PKT_LEN(0);
 223}
 224
 225static int hvs_channel_readable_payload(struct vmbus_channel *chan)
 226{
 227        u32 readable = hv_get_bytes_to_read(&chan->inbound);
 228
 229        if (readable > HVS_PKT_LEN(0)) {
 230                /* At least we have 1 byte to read. We don't need to return
 231                 * the exact readable bytes: see vsock_stream_recvmsg() ->
 232                 * vsock_stream_has_data().
 233                 */
 234                return 1;
 235        }
 236
 237        if (readable == HVS_PKT_LEN(0)) {
 238                /* 0-size payload means FIN */
 239                return 0;
 240        }
 241
 242        /* No payload or FIN */
 243        return -1;
 244}
 245
 246static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
 247{
 248        u32 writeable = hv_get_bytes_to_write(&chan->outbound);
 249        size_t ret;
 250
 251        /* The ringbuffer mustn't be 100% full, and we should reserve a
 252         * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
 253         * and hvs_shutdown().
 254         */
 255        if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0))
 256                return 0;
 257
 258        ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0);
 259
 260        return round_down(ret, 8);
 261}
 262
 263static int hvs_send_data(struct vmbus_channel *chan,
 264                         struct hvs_send_buf *send_buf, size_t to_write)
 265{
 266        send_buf->hdr.pkt_type = 1;
 267        send_buf->hdr.data_size = to_write;
 268        return vmbus_sendpacket(chan, &send_buf->hdr,
 269                                sizeof(send_buf->hdr) + to_write,
 270                                0, VM_PKT_DATA_INBAND, 0);
 271}
 272
 273static void hvs_channel_cb(void *ctx)
 274{
 275        struct sock *sk = (struct sock *)ctx;
 276        struct vsock_sock *vsk = vsock_sk(sk);
 277        struct hvsock *hvs = vsk->trans;
 278        struct vmbus_channel *chan = hvs->chan;
 279
 280        if (hvs_channel_readable(chan))
 281                sk->sk_data_ready(sk);
 282
 283        if (hv_get_bytes_to_write(&chan->outbound) > 0)
 284                sk->sk_write_space(sk);
 285}
 286
 287static void hvs_do_close_lock_held(struct vsock_sock *vsk,
 288                                   bool cancel_timeout)
 289{
 290        struct sock *sk = sk_vsock(vsk);
 291
 292        sock_set_flag(sk, SOCK_DONE);
 293        vsk->peer_shutdown = SHUTDOWN_MASK;
 294        if (vsock_stream_has_data(vsk) <= 0)
 295                sk->sk_state = TCP_CLOSING;
 296        sk->sk_state_change(sk);
 297        if (vsk->close_work_scheduled &&
 298            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
 299                vsk->close_work_scheduled = false;
 300                vsock_remove_sock(vsk);
 301
 302                /* Release the reference taken while scheduling the timeout */
 303                sock_put(sk);
 304        }
 305}
 306
 307static void hvs_close_connection(struct vmbus_channel *chan)
 308{
 309        struct sock *sk = get_per_channel_state(chan);
 310
 311        lock_sock(sk);
 312        hvs_do_close_lock_held(vsock_sk(sk), true);
 313        release_sock(sk);
 314}
 315
 316static void hvs_open_connection(struct vmbus_channel *chan)
 317{
 318        uuid_le *if_instance, *if_type;
 319        unsigned char conn_from_host;
 320
 321        struct sockaddr_vm addr;
 322        struct sock *sk, *new = NULL;
 323        struct vsock_sock *vnew = NULL;
 324        struct hvsock *hvs, *hvs_new = NULL;
 325        int ret;
 326
 327        if_type = &chan->offermsg.offer.if_type;
 328        if_instance = &chan->offermsg.offer.if_instance;
 329        conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];
 330
 331        /* The host or the VM should only listen on a port in
 332         * [0, MAX_LISTEN_PORT]
 333         */
 334        if (!is_valid_srv_id(if_type) ||
 335            get_port_by_srv_id(if_type) > MAX_LISTEN_PORT)
 336                return;
 337
 338        hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
 339        sk = vsock_find_bound_socket(&addr);
 340        if (!sk)
 341                return;
 342
 343        lock_sock(sk);
 344        if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
 345            (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
 346                goto out;
 347
 348        if (conn_from_host) {
 349                if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
 350                        goto out;
 351
 352                new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
 353                                     sk->sk_type, 0);
 354                if (!new)
 355                        goto out;
 356
 357                new->sk_state = TCP_SYN_SENT;
 358                vnew = vsock_sk(new);
 359                hvs_new = vnew->trans;
 360                hvs_new->chan = chan;
 361        } else {
 362                hvs = vsock_sk(sk)->trans;
 363                hvs->chan = chan;
 364        }
 365
 366        set_channel_read_mode(chan, HV_CALL_DIRECT);
 367        ret = vmbus_open(chan, RINGBUFFER_HVS_SND_SIZE,
 368                         RINGBUFFER_HVS_RCV_SIZE, NULL, 0,
 369                         hvs_channel_cb, conn_from_host ? new : sk);
 370        if (ret != 0) {
 371                if (conn_from_host) {
 372                        hvs_new->chan = NULL;
 373                        sock_put(new);
 374                } else {
 375                        hvs->chan = NULL;
 376                }
 377                goto out;
 378        }
 379
 380        set_per_channel_state(chan, conn_from_host ? new : sk);
 381        vmbus_set_chn_rescind_callback(chan, hvs_close_connection);
 382
 383        /* Set the pending send size to max packet size to always get
 384         * notifications from the host when there is enough writable space.
 385         * The host is optimized to send notifications only when the pending
 386         * size boundary is crossed, and not always.
 387         */
 388        hvs_set_channel_pending_send_size(chan);
 389
 390        if (conn_from_host) {
 391                new->sk_state = TCP_ESTABLISHED;
 392                sk->sk_ack_backlog++;
 393
 394                hvs_addr_init(&vnew->local_addr, if_type);
 395                hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
 396
 397                hvs_new->vm_srv_id = *if_type;
 398                hvs_new->host_srv_id = *if_instance;
 399
 400                vsock_insert_connected(vnew);
 401
 402                vsock_enqueue_accept(sk, new);
 403        } else {
 404                sk->sk_state = TCP_ESTABLISHED;
 405                sk->sk_socket->state = SS_CONNECTED;
 406
 407                vsock_insert_connected(vsock_sk(sk));
 408        }
 409
 410        sk->sk_state_change(sk);
 411
 412out:
 413        /* Release refcnt obtained when we called vsock_find_bound_socket() */
 414        sock_put(sk);
 415
 416        release_sock(sk);
 417}
 418
 419static u32 hvs_get_local_cid(void)
 420{
 421        return VMADDR_CID_ANY;
 422}
 423
 424static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
 425{
 426        struct hvsock *hvs;
 427
 428        hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
 429        if (!hvs)
 430                return -ENOMEM;
 431
 432        vsk->trans = hvs;
 433        hvs->vsk = vsk;
 434
 435        return 0;
 436}
 437
 438static int hvs_connect(struct vsock_sock *vsk)
 439{
 440        union hvs_service_id vm, host;
 441        struct hvsock *h = vsk->trans;
 442
 443        vm.srv_id = srv_id_template;
 444        vm.svm_port = vsk->local_addr.svm_port;
 445        h->vm_srv_id = vm.srv_id;
 446
 447        host.srv_id = srv_id_template;
 448        host.svm_port = vsk->remote_addr.svm_port;
 449        h->host_srv_id = host.srv_id;
 450
 451        return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
 452}
 453
 454static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
 455{
 456        struct vmpipe_proto_header hdr;
 457
 458        if (hvs->fin_sent || !hvs->chan)
 459                return;
 460
 461        /* It can't fail: see hvs_channel_writable_bytes(). */
 462        (void)hvs_send_data(hvs->chan, (struct hvs_send_buf *)&hdr, 0);
 463        hvs->fin_sent = true;
 464}
 465
 466static int hvs_shutdown(struct vsock_sock *vsk, int mode)
 467{
 468        struct sock *sk = sk_vsock(vsk);
 469
 470        if (!(mode & SEND_SHUTDOWN))
 471                return 0;
 472
 473        lock_sock(sk);
 474        hvs_shutdown_lock_held(vsk->trans, mode);
 475        release_sock(sk);
 476        return 0;
 477}
 478
 479static void hvs_close_timeout(struct work_struct *work)
 480{
 481        struct vsock_sock *vsk =
 482                container_of(work, struct vsock_sock, close_work.work);
 483        struct sock *sk = sk_vsock(vsk);
 484
 485        sock_hold(sk);
 486        lock_sock(sk);
 487        if (!sock_flag(sk, SOCK_DONE))
 488                hvs_do_close_lock_held(vsk, false);
 489
 490        vsk->close_work_scheduled = false;
 491        release_sock(sk);
 492        sock_put(sk);
 493}
 494
 495/* Returns true, if it is safe to remove socket; false otherwise */
 496static bool hvs_close_lock_held(struct vsock_sock *vsk)
 497{
 498        struct sock *sk = sk_vsock(vsk);
 499
 500        if (!(sk->sk_state == TCP_ESTABLISHED ||
 501              sk->sk_state == TCP_CLOSING))
 502                return true;
 503
 504        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
 505                hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);
 506
 507        if (sock_flag(sk, SOCK_DONE))
 508                return true;
 509
 510        /* This reference will be dropped by the delayed close routine */
 511        sock_hold(sk);
 512        INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
 513        vsk->close_work_scheduled = true;
 514        schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
 515        return false;
 516}
 517
 518static void hvs_release(struct vsock_sock *vsk)
 519{
 520        struct sock *sk = sk_vsock(vsk);
 521        bool remove_sock;
 522
 523        lock_sock(sk);
 524        remove_sock = hvs_close_lock_held(vsk);
 525        release_sock(sk);
 526        if (remove_sock)
 527                vsock_remove_sock(vsk);
 528}
 529
 530static void hvs_destruct(struct vsock_sock *vsk)
 531{
 532        struct hvsock *hvs = vsk->trans;
 533        struct vmbus_channel *chan = hvs->chan;
 534
 535        if (chan)
 536                vmbus_hvsock_device_unregister(chan);
 537
 538        kfree(hvs);
 539}
 540
 541static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
 542{
 543        return -EOPNOTSUPP;
 544}
 545
 546static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
 547                             size_t len, int flags)
 548{
 549        return -EOPNOTSUPP;
 550}
 551
 552static int hvs_dgram_enqueue(struct vsock_sock *vsk,
 553                             struct sockaddr_vm *remote, struct msghdr *msg,
 554                             size_t dgram_len)
 555{
 556        return -EOPNOTSUPP;
 557}
 558
 559static bool hvs_dgram_allow(u32 cid, u32 port)
 560{
 561        return false;
 562}
 563
 564static int hvs_update_recv_data(struct hvsock *hvs)
 565{
 566        struct hvs_recv_buf *recv_buf;
 567        u32 payload_len;
 568
 569        recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
 570        payload_len = recv_buf->hdr.data_size;
 571
 572        if (payload_len > HVS_MTU_SIZE)
 573                return -EIO;
 574
 575        if (payload_len == 0)
 576                hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
 577
 578        hvs->recv_data_len = payload_len;
 579        hvs->recv_data_off = 0;
 580
 581        return 0;
 582}
 583
 584static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
 585                                  size_t len, int flags)
 586{
 587        struct hvsock *hvs = vsk->trans;
 588        bool need_refill = !hvs->recv_desc;
 589        struct hvs_recv_buf *recv_buf;
 590        u32 to_read;
 591        int ret;
 592
 593        if (flags & MSG_PEEK)
 594                return -EOPNOTSUPP;
 595
 596        if (need_refill) {
 597                hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
 598                ret = hvs_update_recv_data(hvs);
 599                if (ret)
 600                        return ret;
 601        }
 602
 603        recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
 604        to_read = min_t(u32, len, hvs->recv_data_len);
 605        ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
 606        if (ret != 0)
 607                return ret;
 608
 609        hvs->recv_data_len -= to_read;
 610        if (hvs->recv_data_len == 0) {
 611                hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
 612                if (hvs->recv_desc) {
 613                        ret = hvs_update_recv_data(hvs);
 614                        if (ret)
 615                                return ret;
 616                }
 617        } else {
 618                hvs->recv_data_off += to_read;
 619        }
 620
 621        return to_read;
 622}
 623
 624static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
 625                                  size_t len)
 626{
 627        struct hvsock *hvs = vsk->trans;
 628        struct vmbus_channel *chan = hvs->chan;
 629        struct hvs_send_buf *send_buf;
 630        ssize_t to_write, max_writable, ret;
 631
 632        BUILD_BUG_ON(sizeof(*send_buf) != PAGE_SIZE_4K);
 633
 634        send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
 635        if (!send_buf)
 636                return -ENOMEM;
 637
 638        max_writable = hvs_channel_writable_bytes(chan);
 639        to_write = min_t(ssize_t, len, max_writable);
 640        to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
 641
 642        ret = memcpy_from_msg(send_buf->data, msg, to_write);
 643        if (ret < 0)
 644                goto out;
 645
 646        ret = hvs_send_data(hvs->chan, send_buf, to_write);
 647        if (ret < 0)
 648                goto out;
 649
 650        ret = to_write;
 651out:
 652        kfree(send_buf);
 653        return ret;
 654}
 655
 656static s64 hvs_stream_has_data(struct vsock_sock *vsk)
 657{
 658        struct hvsock *hvs = vsk->trans;
 659        s64 ret;
 660
 661        if (hvs->recv_data_len > 0)
 662                return 1;
 663
 664        switch (hvs_channel_readable_payload(hvs->chan)) {
 665        case 1:
 666                ret = 1;
 667                break;
 668        case 0:
 669                vsk->peer_shutdown |= SEND_SHUTDOWN;
 670                ret = 0;
 671                break;
 672        default: /* -1 */
 673                ret = 0;
 674                break;
 675        }
 676
 677        return ret;
 678}
 679
 680static s64 hvs_stream_has_space(struct vsock_sock *vsk)
 681{
 682        struct hvsock *hvs = vsk->trans;
 683
 684        return hvs_channel_writable_bytes(hvs->chan);
 685}
 686
 687static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
 688{
 689        return HVS_MTU_SIZE + 1;
 690}
 691
 692static bool hvs_stream_is_active(struct vsock_sock *vsk)
 693{
 694        struct hvsock *hvs = vsk->trans;
 695
 696        return hvs->chan != NULL;
 697}
 698
 699static bool hvs_stream_allow(u32 cid, u32 port)
 700{
 701        /* The host's port range [MIN_HOST_EPHEMERAL_PORT, 0xFFFFFFFF) is
 702         * reserved as ephemeral ports, which are used as the host's ports
 703         * when the host initiates connections.
 704         *
 705         * Perform this check in the guest so an immediate error is produced
 706         * instead of a timeout.
 707         */
 708        if (port > MAX_HOST_LISTEN_PORT)
 709                return false;
 710
 711        if (cid == VMADDR_CID_HOST)
 712                return true;
 713
 714        return false;
 715}
 716
 717static
 718int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable)
 719{
 720        struct hvsock *hvs = vsk->trans;
 721
 722        *readable = hvs_channel_readable(hvs->chan);
 723        return 0;
 724}
 725
 726static
 727int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable)
 728{
 729        *writable = hvs_stream_has_space(vsk) > 0;
 730
 731        return 0;
 732}
 733
 734static
 735int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
 736                         struct vsock_transport_recv_notify_data *d)
 737{
 738        return 0;
 739}
 740
 741static
 742int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
 743                              struct vsock_transport_recv_notify_data *d)
 744{
 745        return 0;
 746}
 747
 748static
 749int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
 750                                struct vsock_transport_recv_notify_data *d)
 751{
 752        return 0;
 753}
 754
 755static
 756int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
 757                                 ssize_t copied, bool data_read,
 758                                 struct vsock_transport_recv_notify_data *d)
 759{
 760        return 0;
 761}
 762
 763static
 764int hvs_notify_send_init(struct vsock_sock *vsk,
 765                         struct vsock_transport_send_notify_data *d)
 766{
 767        return 0;
 768}
 769
 770static
 771int hvs_notify_send_pre_block(struct vsock_sock *vsk,
 772                              struct vsock_transport_send_notify_data *d)
 773{
 774        return 0;
 775}
 776
 777static
 778int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
 779                                struct vsock_transport_send_notify_data *d)
 780{
 781        return 0;
 782}
 783
 784static
 785int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
 786                                 struct vsock_transport_send_notify_data *d)
 787{
 788        return 0;
 789}
 790
 791static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val)
 792{
 793        /* Ignored. */
 794}
 795
 796static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
 797{
 798        /* Ignored. */
 799}
 800
 801static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
 802{
 803        /* Ignored. */
 804}
 805
 806static u64 hvs_get_buffer_size(struct vsock_sock *vsk)
 807{
 808        return -ENOPROTOOPT;
 809}
 810
 811static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk)
 812{
 813        return -ENOPROTOOPT;
 814}
 815
 816static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk)
 817{
 818        return -ENOPROTOOPT;
 819}
 820
 821static struct vsock_transport hvs_transport = {
 822        .get_local_cid            = hvs_get_local_cid,
 823
 824        .init                     = hvs_sock_init,
 825        .destruct                 = hvs_destruct,
 826        .release                  = hvs_release,
 827        .connect                  = hvs_connect,
 828        .shutdown                 = hvs_shutdown,
 829
 830        .dgram_bind               = hvs_dgram_bind,
 831        .dgram_dequeue            = hvs_dgram_dequeue,
 832        .dgram_enqueue            = hvs_dgram_enqueue,
 833        .dgram_allow              = hvs_dgram_allow,
 834
 835        .stream_dequeue           = hvs_stream_dequeue,
 836        .stream_enqueue           = hvs_stream_enqueue,
 837        .stream_has_data          = hvs_stream_has_data,
 838        .stream_has_space         = hvs_stream_has_space,
 839        .stream_rcvhiwat          = hvs_stream_rcvhiwat,
 840        .stream_is_active         = hvs_stream_is_active,
 841        .stream_allow             = hvs_stream_allow,
 842
 843        .notify_poll_in           = hvs_notify_poll_in,
 844        .notify_poll_out          = hvs_notify_poll_out,
 845        .notify_recv_init         = hvs_notify_recv_init,
 846        .notify_recv_pre_block    = hvs_notify_recv_pre_block,
 847        .notify_recv_pre_dequeue  = hvs_notify_recv_pre_dequeue,
 848        .notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
 849        .notify_send_init         = hvs_notify_send_init,
 850        .notify_send_pre_block    = hvs_notify_send_pre_block,
 851        .notify_send_pre_enqueue  = hvs_notify_send_pre_enqueue,
 852        .notify_send_post_enqueue = hvs_notify_send_post_enqueue,
 853
 854        .set_buffer_size          = hvs_set_buffer_size,
 855        .set_min_buffer_size      = hvs_set_min_buffer_size,
 856        .set_max_buffer_size      = hvs_set_max_buffer_size,
 857        .get_buffer_size          = hvs_get_buffer_size,
 858        .get_min_buffer_size      = hvs_get_min_buffer_size,
 859        .get_max_buffer_size      = hvs_get_max_buffer_size,
 860};
 861
 862static int hvs_probe(struct hv_device *hdev,
 863                     const struct hv_vmbus_device_id *dev_id)
 864{
 865        struct vmbus_channel *chan = hdev->channel;
 866
 867        hvs_open_connection(chan);
 868
 869        /* Always return success to suppress the unnecessary error message
 870         * in vmbus_probe(): on error the host will rescind the device in
 871         * 30 seconds and we can do cleanup at that time in
 872         * vmbus_onoffer_rescind().
 873         */
 874        return 0;
 875}
 876
 877static int hvs_remove(struct hv_device *hdev)
 878{
 879        struct vmbus_channel *chan = hdev->channel;
 880
 881        vmbus_close(chan);
 882
 883        return 0;
 884}
 885
 886/* This isn't really used. See vmbus_match() and vmbus_probe() */
 887static const struct hv_vmbus_device_id id_table[] = {
 888        {},
 889};
 890
 891static struct hv_driver hvs_drv = {
 892        .name           = "hv_sock",
 893        .hvsock         = true,
 894        .id_table       = id_table,
 895        .probe          = hvs_probe,
 896        .remove         = hvs_remove,
 897};
 898
 899static int __init hvs_init(void)
 900{
 901        int ret;
 902
 903        if (vmbus_proto_version < VERSION_WIN10)
 904                return -ENODEV;
 905
 906        ret = vmbus_driver_register(&hvs_drv);
 907        if (ret != 0)
 908                return ret;
 909
 910        ret = vsock_core_init(&hvs_transport);
 911        if (ret) {
 912                vmbus_driver_unregister(&hvs_drv);
 913                return ret;
 914        }
 915
 916        return 0;
 917}
 918
 919static void __exit hvs_exit(void)
 920{
 921        vsock_core_exit();
 922        vmbus_driver_unregister(&hvs_drv);
 923}
 924
 925module_init(hvs_init);
 926module_exit(hvs_exit);
 927
 928MODULE_DESCRIPTION("Hyper-V Sockets");
 929MODULE_VERSION("1.0.0");
 930MODULE_LICENSE("GPL");
 931MODULE_ALIAS_NETPROTO(PF_VSOCK);
 932