linux/net/vmw_vsock/hyperv_transport.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Hyper-V transport for vsock
   4 *
   5 * Hyper-V Sockets supplies a byte-stream based communication mechanism
   6 * between the host and the VM. This driver implements the necessary
   7 * support in the VM by introducing the new vsock transport.
   8 *
   9 * Copyright (c) 2017, Microsoft Corporation.
  10 */
  11#include <linux/module.h>
  12#include <linux/vmalloc.h>
  13#include <linux/hyperv.h>
  14#include <net/sock.h>
  15#include <net/af_vsock.h>
  16
  17/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some
  18 * stricter requirements on the hv_sock ring buffer size of six 4K pages. Newer
  19 * hosts don't have this limitation; but, keep the defaults the same for compat.
  20 */
  21#define PAGE_SIZE_4K            4096
  22#define RINGBUFFER_HVS_RCV_SIZE (PAGE_SIZE_4K * 6)
  23#define RINGBUFFER_HVS_SND_SIZE (PAGE_SIZE_4K * 6)
  24#define RINGBUFFER_HVS_MAX_SIZE (PAGE_SIZE_4K * 64)
  25
  26/* The MTU is 16KB per the host side's design */
  27#define HVS_MTU_SIZE            (1024 * 16)
  28
  29/* How long to wait for graceful shutdown of a connection */
  30#define HVS_CLOSE_TIMEOUT (8 * HZ)
  31
/*
 * Header that precedes every payload exchanged with the host over the
 * VMBus pipe. hvs_send_data() sets pkt_type to 1 and data_size to the
 * number of payload bytes that follow. A data_size of 0 is treated as a FIN
 * (see hvs_shutdown_lock_held() and hvs_update_recv_data()).
 * The struct is sent to the host as-is, so its layout must not change.
 */
struct vmpipe_proto_header {
	u32 pkt_type;
	u32 data_size;
};

/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
 * data from the ringbuffer into the userspace buffer.
 *
 * In this file the struct is only overlaid (via a cast) on the bytes that
 * follow a received struct vmpacket_descriptor. Its size bounds the payload
 * that may be addressed through it.
 */
struct hvs_recv_buf {
	/* The header before the payload data */
	struct vmpipe_proto_header hdr;

	/* The payload; hvs_update_recv_data() rejects data_size > HVS_MTU_SIZE */
	u8 data[HVS_MTU_SIZE];
};
  47
/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use
 * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the
 * guest and the host processing as one VMBUS packet is the smallest processing
 * unit.
 *
 * Note: the buffer can be eliminated in the future when we add new VMBus
 * ringbuffer APIs that allow us to directly copy data from userspace buffer
 * to VMBus ringbuffer.
 *
 * HVS_SEND_BUF_SIZE is chosen so that header + payload fill exactly one 4K
 * page; hvs_stream_enqueue() enforces this with BUILD_BUG_ON().
 */
#define HVS_SEND_BUF_SIZE (PAGE_SIZE_4K - sizeof(struct vmpipe_proto_header))

/* Staging buffer for one outgoing packet: pipe header followed by payload. */
struct hvs_send_buf {
	/* The header before the payload data */
	struct vmpipe_proto_header hdr;

	/* The payload */
	u8 data[HVS_SEND_BUF_SIZE];
};

/* Bytes of VMBus packet descriptor plus pipe header that precede the
 * payload of every packet.
 */
#define HVS_HEADER_LEN	(sizeof(struct vmpacket_descriptor) + \
			 sizeof(struct vmpipe_proto_header))

/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and
 * __hv_pkt_iter_next().
 */
#define VMBUS_PKT_TRAILER_SIZE	(sizeof(u64))

/* Ring buffer space consumed by one packet carrying payload_len bytes:
 * headers, payload rounded up to a multiple of 8, and the trailer.
 */
#define HVS_PKT_LEN(payload_len)	(HVS_HEADER_LEN + \
					 ALIGN((payload_len), 8) + \
					 VMBUS_PKT_TRAILER_SIZE)
  78
/* Overlay of a service GUID whose first sizeof(unsigned int) bytes hold an
 * AF_VSOCK port. hvs_connect() uses it to stamp a port into a copy of
 * srv_id_template.
 */
union hvs_service_id {
	guid_t	srv_id;

	struct {
		unsigned int svm_port;
		unsigned char b[sizeof(guid_t) - sizeof(unsigned int)];
	};
};

/* Per-socket state (accessed via vsk->trans) */
struct hvsock {
	/* Back-pointer to the owning vsock socket (set in hvs_sock_init()). */
	struct vsock_sock *vsk;

	/* Service GUIDs of the VM end and the host end of the connection. */
	guid_t vm_srv_id;
	guid_t host_srv_id;

	/* VMBus channel of this connection; NULL until one is opened. */
	struct vmbus_channel *chan;
	/* Inbound packet currently being consumed by hvs_stream_dequeue(), or NULL. */
	struct vmpacket_descriptor *recv_desc;

	/* The length of the payload not delivered to userland yet */
	u32 recv_data_len;
	/* The offset of the payload */
	u32 recv_data_off;

	/* Have we sent the zero-length packet (FIN)? */
	bool fin_sent;
};
 106
 107/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
 108 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
 * when we write apps to connect to the host, the remote cid must be
 * VMADDR_CID_HOST (hvs_stream_allow() rejects every other cid, including
 * VMADDR_CID_ANY), and when we write apps to bind() & listen() in the VM,
 * we can only use VMADDR_CID_ANY as the local cid.
 113 *
 114 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service
 * and the endpoint is <VmID, ServiceId> with
 117 * the below sockaddr:
 118 *
 119 * struct SOCKADDR_HV
 120 * {
 121 *    ADDRESS_FAMILY Family;
 122 *    USHORT Reserved;
 123 *    GUID VmId;
 124 *    GUID ServiceId;
 125 * };
 126 * Note: VmID is not used by Linux VM and actually it isn't transmitted via
 127 * VMBus, because here it's obvious the host and the VM can easily identify
 128 * each other. Though the VmID is useful on the host, especially in the case
 129 * of Windows container, Linux VM doesn't need it at all.
 130 *
 131 * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit
 132 * the available GUID space of SOCKADDR_HV so that we can create a mapping
 133 * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing
 134 * Hyper-V Sockets apps on the host and in Linux VM is:
 135 *
 136 ****************************************************************************
 137 * The only valid Service GUIDs, from the perspectives of both the host and *
 138 * Linux VM, that can be connected by the other end, must conform to this   *
 139 * format: <port>-facb-11e6-bd58-64006a7986d3, and the "port" must be in    *
 140 * this range [0, 0x7FFFFFFF].                                              *
 141 ****************************************************************************
 142 *
 143 * When we write apps on the host to connect(), the GUID ServiceID is used.
 144 * When we write apps in Linux VM to connect(), we only need to specify the
 145 * port and the driver will form the GUID and use that to request the host.
 146 *
 147 * From the perspective of Linux VM:
 148 * 1. the local ephemeral port (i.e. the local auto-bound port when we call
 149 * connect() without explicit bind()) is generated by __vsock_bind_stream(),
 150 * and the range is [1024, 0xFFFFFFFF).
 151 * 2. the remote ephemeral port (i.e. the auto-generated remote port for
 152 * a connect request initiated by the host's connect()) is generated by
 153 * hvs_remote_addr_init() and the range is [0x80000000, 0xFFFFFFFF).
 154 */
 155
/* Both sides may only listen on ports in [0, MAX_LISTEN_PORT]. Ports above
 * that are reserved as host-side ephemeral ports (see hvs_remote_addr_init()
 * and hvs_stream_allow()).
 */
#define MAX_LISTEN_PORT			((u32)0x7FFFFFFF)
#define MAX_VM_LISTEN_PORT		MAX_LISTEN_PORT
#define MAX_HOST_LISTEN_PORT		MAX_LISTEN_PORT
#define MIN_HOST_EPHEMERAL_PORT		(MAX_HOST_LISTEN_PORT + 1)

/* 00000000-facb-11e6-bd58-64006a7986d3
 *
 * Template for every service GUID this transport uses. The first bytes
 * carry the port (see get_port_by_srv_id()); bytes 4..15 must match this
 * template exactly (see is_valid_srv_id()).
 */
static const guid_t srv_id_template =
	GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
		  0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
 165
 166static bool is_valid_srv_id(const guid_t *id)
 167{
 168        return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
 169}
 170
 171static unsigned int get_port_by_srv_id(const guid_t *svr_id)
 172{
 173        return *((unsigned int *)svr_id);
 174}
 175
 176static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id)
 177{
 178        unsigned int port = get_port_by_srv_id(svr_id);
 179
 180        vsock_addr_init(addr, VMADDR_CID_ANY, port);
 181}
 182
 183static void hvs_remote_addr_init(struct sockaddr_vm *remote,
 184                                 struct sockaddr_vm *local)
 185{
 186        static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
 187        struct sock *sk;
 188
 189        vsock_addr_init(remote, VMADDR_CID_ANY, VMADDR_PORT_ANY);
 190
 191        while (1) {
 192                /* Wrap around ? */
 193                if (host_ephemeral_port < MIN_HOST_EPHEMERAL_PORT ||
 194                    host_ephemeral_port == VMADDR_PORT_ANY)
 195                        host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
 196
 197                remote->svm_port = host_ephemeral_port++;
 198
 199                sk = vsock_find_connected_socket(remote, local);
 200                if (!sk) {
 201                        /* Found an available ephemeral port */
 202                        return;
 203                }
 204
 205                /* Release refcnt got in vsock_find_connected_socket */
 206                sock_put(sk);
 207        }
 208}
 209
 210static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
 211{
 212        set_channel_pending_send_size(chan,
 213                                      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));
 214
 215        virt_mb();
 216}
 217
 218static bool hvs_channel_readable(struct vmbus_channel *chan)
 219{
 220        u32 readable = hv_get_bytes_to_read(&chan->inbound);
 221
 222        /* 0-size payload means FIN */
 223        return readable >= HVS_PKT_LEN(0);
 224}
 225
 226static int hvs_channel_readable_payload(struct vmbus_channel *chan)
 227{
 228        u32 readable = hv_get_bytes_to_read(&chan->inbound);
 229
 230        if (readable > HVS_PKT_LEN(0)) {
 231                /* At least we have 1 byte to read. We don't need to return
 232                 * the exact readable bytes: see vsock_stream_recvmsg() ->
 233                 * vsock_stream_has_data().
 234                 */
 235                return 1;
 236        }
 237
 238        if (readable == HVS_PKT_LEN(0)) {
 239                /* 0-size payload means FIN */
 240                return 0;
 241        }
 242
 243        /* No payload or FIN */
 244        return -1;
 245}
 246
 247static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
 248{
 249        u32 writeable = hv_get_bytes_to_write(&chan->outbound);
 250        size_t ret;
 251
 252        /* The ringbuffer mustn't be 100% full, and we should reserve a
 253         * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
 254         * and hvs_shutdown().
 255         */
 256        if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0))
 257                return 0;
 258
 259        ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0);
 260
 261        return round_down(ret, 8);
 262}
 263
 264static int hvs_send_data(struct vmbus_channel *chan,
 265                         struct hvs_send_buf *send_buf, size_t to_write)
 266{
 267        send_buf->hdr.pkt_type = 1;
 268        send_buf->hdr.data_size = to_write;
 269        return vmbus_sendpacket(chan, &send_buf->hdr,
 270                                sizeof(send_buf->hdr) + to_write,
 271                                0, VM_PKT_DATA_INBAND, 0);
 272}
 273
 274static void hvs_channel_cb(void *ctx)
 275{
 276        struct sock *sk = (struct sock *)ctx;
 277        struct vsock_sock *vsk = vsock_sk(sk);
 278        struct hvsock *hvs = vsk->trans;
 279        struct vmbus_channel *chan = hvs->chan;
 280
 281        if (hvs_channel_readable(chan))
 282                sk->sk_data_ready(sk);
 283
 284        if (hv_get_bytes_to_write(&chan->outbound) > 0)
 285                sk->sk_write_space(sk);
 286}
 287
/*
 * Mark the connection as closed by the peer; caller holds the socket lock.
 *
 * Sets SOCK_DONE and a full peer shutdown. Moves the socket to TCP_CLOSING
 * once no unread data remains, and wakes waiters via sk_state_change().
 *
 * If a close timeout was scheduled by hvs_close_lock_held(), this removes
 * the socket and drops the reference taken for that timeout:
 *  - with @cancel_timeout (rescind path, hvs_close_connection()), only if
 *    cancel_delayed_work() actually cancelled the pending work;
 *  - without @cancel_timeout (called from hvs_close_timeout() itself),
 *    unconditionally.
 *
 * NOTE(review): if cancellation fails because hvs_close_timeout() is already
 * running, that handler skips this function once SOCK_DONE is set. Confirm
 * the timeout's reference is not leaked on that path.
 */
static void hvs_do_close_lock_held(struct vsock_sock *vsk,
				   bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	/* Only move to TCP_CLOSING once no unread data remains. */
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);
	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;
		vsock_remove_sock(vsk);

		/* Release the reference taken while scheduling the timeout */
		sock_put(sk);
	}
}
 307
 308static void hvs_close_connection(struct vmbus_channel *chan)
 309{
 310        struct sock *sk = get_per_channel_state(chan);
 311
 312        lock_sock(sk);
 313        hvs_do_close_lock_held(vsock_sk(sk), true);
 314        release_sock(sk);
 315
 316        /* Release the refcnt for the channel that's opened in
 317         * hvs_open_connection().
 318         */
 319        sock_put(sk);
 320}
 321
 322static void hvs_open_connection(struct vmbus_channel *chan)
 323{
 324        guid_t *if_instance, *if_type;
 325        unsigned char conn_from_host;
 326
 327        struct sockaddr_vm addr;
 328        struct sock *sk, *new = NULL;
 329        struct vsock_sock *vnew = NULL;
 330        struct hvsock *hvs = NULL;
 331        struct hvsock *hvs_new = NULL;
 332        int rcvbuf;
 333        int ret;
 334        int sndbuf;
 335
 336        if_type = &chan->offermsg.offer.if_type;
 337        if_instance = &chan->offermsg.offer.if_instance;
 338        conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];
 339
 340        /* The host or the VM should only listen on a port in
 341         * [0, MAX_LISTEN_PORT]
 342         */
 343        if (!is_valid_srv_id(if_type) ||
 344            get_port_by_srv_id(if_type) > MAX_LISTEN_PORT)
 345                return;
 346
 347        hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
 348        sk = vsock_find_bound_socket(&addr);
 349        if (!sk)
 350                return;
 351
 352        lock_sock(sk);
 353        if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
 354            (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
 355                goto out;
 356
 357        if (conn_from_host) {
 358                if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
 359                        goto out;
 360
 361                new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
 362                                     sk->sk_type, 0);
 363                if (!new)
 364                        goto out;
 365
 366                new->sk_state = TCP_SYN_SENT;
 367                vnew = vsock_sk(new);
 368                hvs_new = vnew->trans;
 369                hvs_new->chan = chan;
 370        } else {
 371                hvs = vsock_sk(sk)->trans;
 372                hvs->chan = chan;
 373        }
 374
 375        set_channel_read_mode(chan, HV_CALL_DIRECT);
 376
 377        /* Use the socket buffer sizes as hints for the VMBUS ring size. For
 378         * server side sockets, 'sk' is the parent socket and thus, this will
 379         * allow the child sockets to inherit the size from the parent. Keep
 380         * the mins to the default value and align to page size as per VMBUS
 381         * requirements.
 382         * For the max, the socket core library will limit the socket buffer
 383         * size that can be set by the user, but, since currently, the hv_sock
 384         * VMBUS ring buffer is physically contiguous allocation, restrict it
 385         * further.
 386         * Older versions of hv_sock host side code cannot handle bigger VMBUS
 387         * ring buffer size. Use the version number to limit the change to newer
 388         * versions.
 389         */
 390        if (vmbus_proto_version < VERSION_WIN10_V5) {
 391                sndbuf = RINGBUFFER_HVS_SND_SIZE;
 392                rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
 393        } else {
 394                sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
 395                sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
 396                sndbuf = ALIGN(sndbuf, PAGE_SIZE);
 397                rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
 398                rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
 399                rcvbuf = ALIGN(rcvbuf, PAGE_SIZE);
 400        }
 401
 402        ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
 403                         conn_from_host ? new : sk);
 404        if (ret != 0) {
 405                if (conn_from_host) {
 406                        hvs_new->chan = NULL;
 407                        sock_put(new);
 408                } else {
 409                        hvs->chan = NULL;
 410                }
 411                goto out;
 412        }
 413
 414        set_per_channel_state(chan, conn_from_host ? new : sk);
 415
 416        /* This reference will be dropped by hvs_close_connection(). */
 417        sock_hold(conn_from_host ? new : sk);
 418        vmbus_set_chn_rescind_callback(chan, hvs_close_connection);
 419
 420        /* Set the pending send size to max packet size to always get
 421         * notifications from the host when there is enough writable space.
 422         * The host is optimized to send notifications only when the pending
 423         * size boundary is crossed, and not always.
 424         */
 425        hvs_set_channel_pending_send_size(chan);
 426
 427        if (conn_from_host) {
 428                new->sk_state = TCP_ESTABLISHED;
 429                sk->sk_ack_backlog++;
 430
 431                hvs_addr_init(&vnew->local_addr, if_type);
 432                hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
 433
 434                hvs_new->vm_srv_id = *if_type;
 435                hvs_new->host_srv_id = *if_instance;
 436
 437                vsock_insert_connected(vnew);
 438
 439                vsock_enqueue_accept(sk, new);
 440        } else {
 441                sk->sk_state = TCP_ESTABLISHED;
 442                sk->sk_socket->state = SS_CONNECTED;
 443
 444                vsock_insert_connected(vsock_sk(sk));
 445        }
 446
 447        sk->sk_state_change(sk);
 448
 449out:
 450        /* Release refcnt obtained when we called vsock_find_bound_socket() */
 451        sock_put(sk);
 452
 453        release_sock(sk);
 454}
 455
 456static u32 hvs_get_local_cid(void)
 457{
 458        return VMADDR_CID_ANY;
 459}
 460
 461static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
 462{
 463        struct hvsock *hvs;
 464        struct sock *sk = sk_vsock(vsk);
 465
 466        hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
 467        if (!hvs)
 468                return -ENOMEM;
 469
 470        vsk->trans = hvs;
 471        hvs->vsk = vsk;
 472        sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE;
 473        sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
 474        return 0;
 475}
 476
 477static int hvs_connect(struct vsock_sock *vsk)
 478{
 479        union hvs_service_id vm, host;
 480        struct hvsock *h = vsk->trans;
 481
 482        vm.srv_id = srv_id_template;
 483        vm.svm_port = vsk->local_addr.svm_port;
 484        h->vm_srv_id = vm.srv_id;
 485
 486        host.srv_id = srv_id_template;
 487        host.svm_port = vsk->remote_addr.svm_port;
 488        h->host_srv_id = host.srv_id;
 489
 490        return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
 491}
 492
 493static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
 494{
 495        struct vmpipe_proto_header hdr;
 496
 497        if (hvs->fin_sent || !hvs->chan)
 498                return;
 499
 500        /* It can't fail: see hvs_channel_writable_bytes(). */
 501        (void)hvs_send_data(hvs->chan, (struct hvs_send_buf *)&hdr, 0);
 502        hvs->fin_sent = true;
 503}
 504
 505static int hvs_shutdown(struct vsock_sock *vsk, int mode)
 506{
 507        struct sock *sk = sk_vsock(vsk);
 508
 509        if (!(mode & SEND_SHUTDOWN))
 510                return 0;
 511
 512        lock_sock(sk);
 513        hvs_shutdown_lock_held(vsk->trans, mode);
 514        release_sock(sk);
 515        return 0;
 516}
 517
/*
 * Delayed work scheduled by hvs_close_lock_held(); runs HVS_CLOSE_TIMEOUT
 * after a close was started. If the peer has not closed the connection by
 * then (SOCK_DONE not set), the close is completed from this side. Passing
 * cancel_timeout=false makes hvs_do_close_lock_held() remove the socket and
 * drop the reference taken when this work was scheduled.
 */
static void hvs_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	/* Keep sk alive until release_sock() below, because
	 * hvs_do_close_lock_held() may drop the reference this work held.
	 */
	sock_hold(sk);
	lock_sock(sk);
	if (!sock_flag(sk, SOCK_DONE))
		hvs_do_close_lock_held(vsk, false);

	vsk->close_work_scheduled = false;
	release_sock(sk);
	sock_put(sk);
}
 533
/* Returns true, if it is safe to remove socket; false otherwise */
/*
 * Start closing @vsk; caller holds the socket lock.
 *
 * Sockets that are neither TCP_ESTABLISHED nor TCP_CLOSING can be removed
 * at once. Otherwise the FIN is sent, unless sk_shutdown already covers
 * both directions. If the peer has already closed (SOCK_DONE), the socket
 * can be removed now. If not, a reference is taken and hvs_close_timeout()
 * is scheduled HVS_CLOSE_TIMEOUT later. Removal is then left to
 * hvs_do_close_lock_held(), either from the rescind path or from the
 * timeout.
 */
static bool hvs_close_lock_held(struct vsock_sock *vsk)
{
	struct sock *sk = sk_vsock(vsk);

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	/* This reference will be dropped by the delayed close routine */
	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
	return false;
}
 556
 557static void hvs_release(struct vsock_sock *vsk)
 558{
 559        struct sock *sk = sk_vsock(vsk);
 560        bool remove_sock;
 561
 562        lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
 563        remove_sock = hvs_close_lock_held(vsk);
 564        release_sock(sk);
 565        if (remove_sock)
 566                vsock_remove_sock(vsk);
 567}
 568
 569static void hvs_destruct(struct vsock_sock *vsk)
 570{
 571        struct hvsock *hvs = vsk->trans;
 572        struct vmbus_channel *chan = hvs->chan;
 573
 574        if (chan)
 575                vmbus_hvsock_device_unregister(chan);
 576
 577        kfree(hvs);
 578}
 579
 580static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
 581{
 582        return -EOPNOTSUPP;
 583}
 584
 585static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
 586                             size_t len, int flags)
 587{
 588        return -EOPNOTSUPP;
 589}
 590
 591static int hvs_dgram_enqueue(struct vsock_sock *vsk,
 592                             struct sockaddr_vm *remote, struct msghdr *msg,
 593                             size_t dgram_len)
 594{
 595        return -EOPNOTSUPP;
 596}
 597
 598static bool hvs_dgram_allow(u32 cid, u32 port)
 599{
 600        return false;
 601}
 602
 603static int hvs_update_recv_data(struct hvsock *hvs)
 604{
 605        struct hvs_recv_buf *recv_buf;
 606        u32 payload_len;
 607
 608        recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
 609        payload_len = recv_buf->hdr.data_size;
 610
 611        if (payload_len > HVS_MTU_SIZE)
 612                return -EIO;
 613
 614        if (payload_len == 0)
 615                hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;
 616
 617        hvs->recv_data_len = payload_len;
 618        hvs->recv_data_off = 0;
 619
 620        return 0;
 621}
 622
 623static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
 624                                  size_t len, int flags)
 625{
 626        struct hvsock *hvs = vsk->trans;
 627        bool need_refill = !hvs->recv_desc;
 628        struct hvs_recv_buf *recv_buf;
 629        u32 to_read;
 630        int ret;
 631
 632        if (flags & MSG_PEEK)
 633                return -EOPNOTSUPP;
 634
 635        if (need_refill) {
 636                hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
 637                ret = hvs_update_recv_data(hvs);
 638                if (ret)
 639                        return ret;
 640        }
 641
 642        recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
 643        to_read = min_t(u32, len, hvs->recv_data_len);
 644        ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
 645        if (ret != 0)
 646                return ret;
 647
 648        hvs->recv_data_len -= to_read;
 649        if (hvs->recv_data_len == 0) {
 650                hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
 651                if (hvs->recv_desc) {
 652                        ret = hvs_update_recv_data(hvs);
 653                        if (ret)
 654                                return ret;
 655                }
 656        } else {
 657                hvs->recv_data_off += to_read;
 658        }
 659
 660        return to_read;
 661}
 662
 663static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
 664                                  size_t len)
 665{
 666        struct hvsock *hvs = vsk->trans;
 667        struct vmbus_channel *chan = hvs->chan;
 668        struct hvs_send_buf *send_buf;
 669        ssize_t to_write, max_writable;
 670        ssize_t ret = 0;
 671        ssize_t bytes_written = 0;
 672
 673        BUILD_BUG_ON(sizeof(*send_buf) != PAGE_SIZE_4K);
 674
 675        send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
 676        if (!send_buf)
 677                return -ENOMEM;
 678
 679        /* Reader(s) could be draining data from the channel as we write.
 680         * Maximize bandwidth, by iterating until the channel is found to be
 681         * full.
 682         */
 683        while (len) {
 684                max_writable = hvs_channel_writable_bytes(chan);
 685                if (!max_writable)
 686                        break;
 687                to_write = min_t(ssize_t, len, max_writable);
 688                to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
 689                /* memcpy_from_msg is safe for loop as it advances the offsets
 690                 * within the message iterator.
 691                 */
 692                ret = memcpy_from_msg(send_buf->data, msg, to_write);
 693                if (ret < 0)
 694                        goto out;
 695
 696                ret = hvs_send_data(hvs->chan, send_buf, to_write);
 697                if (ret < 0)
 698                        goto out;
 699
 700                bytes_written += to_write;
 701                len -= to_write;
 702        }
 703out:
 704        /* If any data has been sent, return that */
 705        if (bytes_written)
 706                ret = bytes_written;
 707        kfree(send_buf);
 708        return ret;
 709}
 710
 711static s64 hvs_stream_has_data(struct vsock_sock *vsk)
 712{
 713        struct hvsock *hvs = vsk->trans;
 714        s64 ret;
 715
 716        if (hvs->recv_data_len > 0)
 717                return 1;
 718
 719        switch (hvs_channel_readable_payload(hvs->chan)) {
 720        case 1:
 721                ret = 1;
 722                break;
 723        case 0:
 724                vsk->peer_shutdown |= SEND_SHUTDOWN;
 725                ret = 0;
 726                break;
 727        default: /* -1 */
 728                ret = 0;
 729                break;
 730        }
 731
 732        return ret;
 733}
 734
 735static s64 hvs_stream_has_space(struct vsock_sock *vsk)
 736{
 737        struct hvsock *hvs = vsk->trans;
 738
 739        return hvs_channel_writable_bytes(hvs->chan);
 740}
 741
 742static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
 743{
 744        return HVS_MTU_SIZE + 1;
 745}
 746
 747static bool hvs_stream_is_active(struct vsock_sock *vsk)
 748{
 749        struct hvsock *hvs = vsk->trans;
 750
 751        return hvs->chan != NULL;
 752}
 753
 754static bool hvs_stream_allow(u32 cid, u32 port)
 755{
 756        /* The host's port range [MIN_HOST_EPHEMERAL_PORT, 0xFFFFFFFF) is
 757         * reserved as ephemeral ports, which are used as the host's ports
 758         * when the host initiates connections.
 759         *
 760         * Perform this check in the guest so an immediate error is produced
 761         * instead of a timeout.
 762         */
 763        if (port > MAX_HOST_LISTEN_PORT)
 764                return false;
 765
 766        if (cid == VMADDR_CID_HOST)
 767                return true;
 768
 769        return false;
 770}
 771
 772static
 773int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable)
 774{
 775        struct hvsock *hvs = vsk->trans;
 776
 777        *readable = hvs_channel_readable(hvs->chan);
 778        return 0;
 779}
 780
 781static
 782int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable)
 783{
 784        *writable = hvs_stream_has_space(vsk) > 0;
 785
 786        return 0;
 787}
 788
 789static
 790int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
 791                         struct vsock_transport_recv_notify_data *d)
 792{
 793        return 0;
 794}
 795
 796static
 797int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
 798                              struct vsock_transport_recv_notify_data *d)
 799{
 800        return 0;
 801}
 802
 803static
 804int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
 805                                struct vsock_transport_recv_notify_data *d)
 806{
 807        return 0;
 808}
 809
 810static
 811int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
 812                                 ssize_t copied, bool data_read,
 813                                 struct vsock_transport_recv_notify_data *d)
 814{
 815        return 0;
 816}
 817
 818static
 819int hvs_notify_send_init(struct vsock_sock *vsk,
 820                         struct vsock_transport_send_notify_data *d)
 821{
 822        return 0;
 823}
 824
 825static
 826int hvs_notify_send_pre_block(struct vsock_sock *vsk,
 827                              struct vsock_transport_send_notify_data *d)
 828{
 829        return 0;
 830}
 831
 832static
 833int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
 834                                struct vsock_transport_send_notify_data *d)
 835{
 836        return 0;
 837}
 838
 839static
 840int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
 841                                 struct vsock_transport_send_notify_data *d)
 842{
 843        return 0;
 844}
 845
 846static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val)
 847{
 848        /* Ignored. */
 849}
 850
 851static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
 852{
 853        /* Ignored. */
 854}
 855
 856static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
 857{
 858        /* Ignored. */
 859}
 860
 861static u64 hvs_get_buffer_size(struct vsock_sock *vsk)
 862{
 863        return -ENOPROTOOPT;
 864}
 865
 866static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk)
 867{
 868        return -ENOPROTOOPT;
 869}
 870
 871static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk)
 872{
 873        return -ENOPROTOOPT;
 874}
 875
 876static struct vsock_transport hvs_transport = {
 877        .get_local_cid            = hvs_get_local_cid,
 878
 879        .init                     = hvs_sock_init,
 880        .destruct                 = hvs_destruct,
 881        .release                  = hvs_release,
 882        .connect                  = hvs_connect,
 883        .shutdown                 = hvs_shutdown,
 884
 885        .dgram_bind               = hvs_dgram_bind,
 886        .dgram_dequeue            = hvs_dgram_dequeue,
 887        .dgram_enqueue            = hvs_dgram_enqueue,
 888        .dgram_allow              = hvs_dgram_allow,
 889
 890        .stream_dequeue           = hvs_stream_dequeue,
 891        .stream_enqueue           = hvs_stream_enqueue,
 892        .stream_has_data          = hvs_stream_has_data,
 893        .stream_has_space         = hvs_stream_has_space,
 894        .stream_rcvhiwat          = hvs_stream_rcvhiwat,
 895        .stream_is_active         = hvs_stream_is_active,
 896        .stream_allow             = hvs_stream_allow,
 897
 898        .notify_poll_in           = hvs_notify_poll_in,
 899        .notify_poll_out          = hvs_notify_poll_out,
 900        .notify_recv_init         = hvs_notify_recv_init,
 901        .notify_recv_pre_block    = hvs_notify_recv_pre_block,
 902        .notify_recv_pre_dequeue  = hvs_notify_recv_pre_dequeue,
 903        .notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
 904        .notify_send_init         = hvs_notify_send_init,
 905        .notify_send_pre_block    = hvs_notify_send_pre_block,
 906        .notify_send_pre_enqueue  = hvs_notify_send_pre_enqueue,
 907        .notify_send_post_enqueue = hvs_notify_send_post_enqueue,
 908
 909        .set_buffer_size          = hvs_set_buffer_size,
 910        .set_min_buffer_size      = hvs_set_min_buffer_size,
 911        .set_max_buffer_size      = hvs_set_max_buffer_size,
 912        .get_buffer_size          = hvs_get_buffer_size,
 913        .get_min_buffer_size      = hvs_get_min_buffer_size,
 914        .get_max_buffer_size      = hvs_get_max_buffer_size,
 915};
 916
 917static int hvs_probe(struct hv_device *hdev,
 918                     const struct hv_vmbus_device_id *dev_id)
 919{
 920        struct vmbus_channel *chan = hdev->channel;
 921
 922        hvs_open_connection(chan);
 923
 924        /* Always return success to suppress the unnecessary error message
 925         * in vmbus_probe(): on error the host will rescind the device in
 926         * 30 seconds and we can do cleanup at that time in
 927         * vmbus_onoffer_rescind().
 928         */
 929        return 0;
 930}
 931
 932static int hvs_remove(struct hv_device *hdev)
 933{
 934        struct vmbus_channel *chan = hdev->channel;
 935
 936        vmbus_close(chan);
 937
 938        return 0;
 939}
 940
 941/* This isn't really used. See vmbus_match() and vmbus_probe() */
 942static const struct hv_vmbus_device_id id_table[] = {
 943        {},
 944};
 945
 946static struct hv_driver hvs_drv = {
 947        .name           = "hv_sock",
 948        .hvsock         = true,
 949        .id_table       = id_table,
 950        .probe          = hvs_probe,
 951        .remove         = hvs_remove,
 952};
 953
 954static int __init hvs_init(void)
 955{
 956        int ret;
 957
 958        if (vmbus_proto_version < VERSION_WIN10)
 959                return -ENODEV;
 960
 961        ret = vmbus_driver_register(&hvs_drv);
 962        if (ret != 0)
 963                return ret;
 964
 965        ret = vsock_core_init(&hvs_transport);
 966        if (ret) {
 967                vmbus_driver_unregister(&hvs_drv);
 968                return ret;
 969        }
 970
 971        return 0;
 972}
 973
 974static void __exit hvs_exit(void)
 975{
 976        vsock_core_exit();
 977        vmbus_driver_unregister(&hvs_drv);
 978}
 979
/* Module entry and exit points. */
module_init(hvs_init);
module_exit(hvs_exit);

/* Module metadata. The PF_VSOCK net-protocol alias presumably lets the
 * module be auto-loaded when a PF_VSOCK socket is requested.
 */
MODULE_DESCRIPTION("Hyper-V Sockets");
MODULE_VERSION("1.0.0");
MODULE_LICENSE("GPL");
MODULE_ALIAS_NETPROTO(PF_VSOCK);
 987