linux/net/vmw_vsock/vmci_transport.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * VMware vSockets Driver
   4 *
   5 * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
   6 */
   7
   8#include <linux/types.h>
   9#include <linux/bitops.h>
  10#include <linux/cred.h>
  11#include <linux/init.h>
  12#include <linux/io.h>
  13#include <linux/kernel.h>
  14#include <linux/kmod.h>
  15#include <linux/list.h>
  16#include <linux/module.h>
  17#include <linux/mutex.h>
  18#include <linux/net.h>
  19#include <linux/poll.h>
  20#include <linux/skbuff.h>
  21#include <linux/smp.h>
  22#include <linux/socket.h>
  23#include <linux/stddef.h>
  24#include <linux/unistd.h>
  25#include <linux/wait.h>
  26#include <linux/workqueue.h>
  27#include <net/sock.h>
  28#include <net/af_vsock.h>
  29
  30#include "vmci_transport_notify.h"
  31
  32static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg);
  33static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg);
  34static void vmci_transport_peer_detach_cb(u32 sub_id,
  35                                          const struct vmci_event_data *ed,
  36                                          void *client_data);
  37static void vmci_transport_recv_pkt_work(struct work_struct *work);
  38static void vmci_transport_cleanup(struct work_struct *work);
  39static int vmci_transport_recv_listen(struct sock *sk,
  40                                      struct vmci_transport_packet *pkt);
  41static int vmci_transport_recv_connecting_server(
  42                                        struct sock *sk,
  43                                        struct sock *pending,
  44                                        struct vmci_transport_packet *pkt);
  45static int vmci_transport_recv_connecting_client(
  46                                        struct sock *sk,
  47                                        struct vmci_transport_packet *pkt);
  48static int vmci_transport_recv_connecting_client_negotiate(
  49                                        struct sock *sk,
  50                                        struct vmci_transport_packet *pkt);
  51static int vmci_transport_recv_connecting_client_invalid(
  52                                        struct sock *sk,
  53                                        struct vmci_transport_packet *pkt);
  54static int vmci_transport_recv_connected(struct sock *sk,
  55                                         struct vmci_transport_packet *pkt);
  56static bool vmci_transport_old_proto_override(bool *old_pkt_proto);
  57static u16 vmci_transport_new_proto_supported_versions(void);
  58static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
  59                                                  bool old_pkt_proto);
  60static bool vmci_check_transport(struct vsock_sock *vsk);
  61
  62struct vmci_transport_recv_pkt_info {
  63        struct work_struct work;
  64        struct sock *sk;
  65        struct vmci_transport_packet pkt;
  66};
  67
  68static LIST_HEAD(vmci_transport_cleanup_list);
  69static DEFINE_SPINLOCK(vmci_transport_cleanup_lock);
  70static DECLARE_WORK(vmci_transport_cleanup_work, vmci_transport_cleanup);
  71
  72static struct vmci_handle vmci_transport_stream_handle = { VMCI_INVALID_ID,
  73                                                           VMCI_INVALID_ID };
  74static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
  75
  76static int PROTOCOL_OVERRIDE = -1;
  77
  78/* Helper function to convert from a VMCI error code to a VSock error code. */
  79
  80static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
  81{
  82        switch (vmci_error) {
  83        case VMCI_ERROR_NO_MEM:
  84                return -ENOMEM;
  85        case VMCI_ERROR_DUPLICATE_ENTRY:
  86        case VMCI_ERROR_ALREADY_EXISTS:
  87                return -EADDRINUSE;
  88        case VMCI_ERROR_NO_ACCESS:
  89                return -EPERM;
  90        case VMCI_ERROR_NO_RESOURCES:
  91                return -ENOBUFS;
  92        case VMCI_ERROR_INVALID_RESOURCE:
  93                return -EHOSTUNREACH;
  94        case VMCI_ERROR_INVALID_ARGS:
  95        default:
  96                break;
  97        }
  98        return -EINVAL;
  99}
 100
 101static u32 vmci_transport_peer_rid(u32 peer_cid)
 102{
 103        if (VMADDR_CID_HYPERVISOR == peer_cid)
 104                return VMCI_TRANSPORT_HYPERVISOR_PACKET_RID;
 105
 106        return VMCI_TRANSPORT_PACKET_RID;
 107}
 108
 109static inline void
 110vmci_transport_packet_init(struct vmci_transport_packet *pkt,
 111                           struct sockaddr_vm *src,
 112                           struct sockaddr_vm *dst,
 113                           u8 type,
 114                           u64 size,
 115                           u64 mode,
 116                           struct vmci_transport_waiting_info *wait,
 117                           u16 proto,
 118                           struct vmci_handle handle)
 119{
 120        /* We register the stream control handler as an any cid handle so we
 121         * must always send from a source address of VMADDR_CID_ANY
 122         */
 123        pkt->dg.src = vmci_make_handle(VMADDR_CID_ANY,
 124                                       VMCI_TRANSPORT_PACKET_RID);
 125        pkt->dg.dst = vmci_make_handle(dst->svm_cid,
 126                                       vmci_transport_peer_rid(dst->svm_cid));
 127        pkt->dg.payload_size = sizeof(*pkt) - sizeof(pkt->dg);
 128        pkt->version = VMCI_TRANSPORT_PACKET_VERSION;
 129        pkt->type = type;
 130        pkt->src_port = src->svm_port;
 131        pkt->dst_port = dst->svm_port;
 132        memset(&pkt->proto, 0, sizeof(pkt->proto));
 133        memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2));
 134
 135        switch (pkt->type) {
 136        case VMCI_TRANSPORT_PACKET_TYPE_INVALID:
 137                pkt->u.size = 0;
 138                break;
 139
 140        case VMCI_TRANSPORT_PACKET_TYPE_REQUEST:
 141        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
 142                pkt->u.size = size;
 143                break;
 144
 145        case VMCI_TRANSPORT_PACKET_TYPE_OFFER:
 146        case VMCI_TRANSPORT_PACKET_TYPE_ATTACH:
 147                pkt->u.handle = handle;
 148                break;
 149
 150        case VMCI_TRANSPORT_PACKET_TYPE_WROTE:
 151        case VMCI_TRANSPORT_PACKET_TYPE_READ:
 152        case VMCI_TRANSPORT_PACKET_TYPE_RST:
 153                pkt->u.size = 0;
 154                break;
 155
 156        case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN:
 157                pkt->u.mode = mode;
 158                break;
 159
 160        case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ:
 161        case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE:
 162                memcpy(&pkt->u.wait, wait, sizeof(pkt->u.wait));
 163                break;
 164
 165        case VMCI_TRANSPORT_PACKET_TYPE_REQUEST2:
 166        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
 167                pkt->u.size = size;
 168                pkt->proto = proto;
 169                break;
 170        }
 171}
 172
 173static inline void
 174vmci_transport_packet_get_addresses(struct vmci_transport_packet *pkt,
 175                                    struct sockaddr_vm *local,
 176                                    struct sockaddr_vm *remote)
 177{
 178        vsock_addr_init(local, pkt->dg.dst.context, pkt->dst_port);
 179        vsock_addr_init(remote, pkt->dg.src.context, pkt->src_port);
 180}
 181
 182static int
 183__vmci_transport_send_control_pkt(struct vmci_transport_packet *pkt,
 184                                  struct sockaddr_vm *src,
 185                                  struct sockaddr_vm *dst,
 186                                  enum vmci_transport_packet_type type,
 187                                  u64 size,
 188                                  u64 mode,
 189                                  struct vmci_transport_waiting_info *wait,
 190                                  u16 proto,
 191                                  struct vmci_handle handle,
 192                                  bool convert_error)
 193{
 194        int err;
 195
 196        vmci_transport_packet_init(pkt, src, dst, type, size, mode, wait,
 197                                   proto, handle);
 198        err = vmci_datagram_send(&pkt->dg);
 199        if (convert_error && (err < 0))
 200                return vmci_transport_error_to_vsock_error(err);
 201
 202        return err;
 203}
 204
 205static int
 206vmci_transport_reply_control_pkt_fast(struct vmci_transport_packet *pkt,
 207                                      enum vmci_transport_packet_type type,
 208                                      u64 size,
 209                                      u64 mode,
 210                                      struct vmci_transport_waiting_info *wait,
 211                                      struct vmci_handle handle)
 212{
 213        struct vmci_transport_packet reply;
 214        struct sockaddr_vm src, dst;
 215
 216        if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) {
 217                return 0;
 218        } else {
 219                vmci_transport_packet_get_addresses(pkt, &src, &dst);
 220                return __vmci_transport_send_control_pkt(&reply, &src, &dst,
 221                                                         type,
 222                                                         size, mode, wait,
 223                                                         VSOCK_PROTO_INVALID,
 224                                                         handle, true);
 225        }
 226}
 227
 228static int
 229vmci_transport_send_control_pkt_bh(struct sockaddr_vm *src,
 230                                   struct sockaddr_vm *dst,
 231                                   enum vmci_transport_packet_type type,
 232                                   u64 size,
 233                                   u64 mode,
 234                                   struct vmci_transport_waiting_info *wait,
 235                                   struct vmci_handle handle)
 236{
 237        /* Note that it is safe to use a single packet across all CPUs since
 238         * two tasklets of the same type are guaranteed to not ever run
 239         * simultaneously. If that ever changes, or VMCI stops using tasklets,
 240         * we can use per-cpu packets.
 241         */
 242        static struct vmci_transport_packet pkt;
 243
 244        return __vmci_transport_send_control_pkt(&pkt, src, dst, type,
 245                                                 size, mode, wait,
 246                                                 VSOCK_PROTO_INVALID, handle,
 247                                                 false);
 248}
 249
 250static int
 251vmci_transport_alloc_send_control_pkt(struct sockaddr_vm *src,
 252                                      struct sockaddr_vm *dst,
 253                                      enum vmci_transport_packet_type type,
 254                                      u64 size,
 255                                      u64 mode,
 256                                      struct vmci_transport_waiting_info *wait,
 257                                      u16 proto,
 258                                      struct vmci_handle handle)
 259{
 260        struct vmci_transport_packet *pkt;
 261        int err;
 262
 263        pkt = kmalloc(sizeof(*pkt), GFP_KERNEL);
 264        if (!pkt)
 265                return -ENOMEM;
 266
 267        err = __vmci_transport_send_control_pkt(pkt, src, dst, type, size,
 268                                                mode, wait, proto, handle,
 269                                                true);
 270        kfree(pkt);
 271
 272        return err;
 273}
 274
 275static int
 276vmci_transport_send_control_pkt(struct sock *sk,
 277                                enum vmci_transport_packet_type type,
 278                                u64 size,
 279                                u64 mode,
 280                                struct vmci_transport_waiting_info *wait,
 281                                u16 proto,
 282                                struct vmci_handle handle)
 283{
 284        struct vsock_sock *vsk;
 285
 286        vsk = vsock_sk(sk);
 287
 288        if (!vsock_addr_bound(&vsk->local_addr))
 289                return -EINVAL;
 290
 291        if (!vsock_addr_bound(&vsk->remote_addr))
 292                return -EINVAL;
 293
 294        return vmci_transport_alloc_send_control_pkt(&vsk->local_addr,
 295                                                     &vsk->remote_addr,
 296                                                     type, size, mode,
 297                                                     wait, proto, handle);
 298}
 299
 300static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst,
 301                                        struct sockaddr_vm *src,
 302                                        struct vmci_transport_packet *pkt)
 303{
 304        if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
 305                return 0;
 306        return vmci_transport_send_control_pkt_bh(
 307                                        dst, src,
 308                                        VMCI_TRANSPORT_PACKET_TYPE_RST, 0,
 309                                        0, NULL, VMCI_INVALID_HANDLE);
 310}
 311
 312static int vmci_transport_send_reset(struct sock *sk,
 313                                     struct vmci_transport_packet *pkt)
 314{
 315        struct sockaddr_vm *dst_ptr;
 316        struct sockaddr_vm dst;
 317        struct vsock_sock *vsk;
 318
 319        if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
 320                return 0;
 321
 322        vsk = vsock_sk(sk);
 323
 324        if (!vsock_addr_bound(&vsk->local_addr))
 325                return -EINVAL;
 326
 327        if (vsock_addr_bound(&vsk->remote_addr)) {
 328                dst_ptr = &vsk->remote_addr;
 329        } else {
 330                vsock_addr_init(&dst, pkt->dg.src.context,
 331                                pkt->src_port);
 332                dst_ptr = &dst;
 333        }
 334        return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr,
 335                                             VMCI_TRANSPORT_PACKET_TYPE_RST,
 336                                             0, 0, NULL, VSOCK_PROTO_INVALID,
 337                                             VMCI_INVALID_HANDLE);
 338}
 339
 340static int vmci_transport_send_negotiate(struct sock *sk, size_t size)
 341{
 342        return vmci_transport_send_control_pkt(
 343                                        sk,
 344                                        VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE,
 345                                        size, 0, NULL,
 346                                        VSOCK_PROTO_INVALID,
 347                                        VMCI_INVALID_HANDLE);
 348}
 349
 350static int vmci_transport_send_negotiate2(struct sock *sk, size_t size,
 351                                          u16 version)
 352{
 353        return vmci_transport_send_control_pkt(
 354                                        sk,
 355                                        VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2,
 356                                        size, 0, NULL, version,
 357                                        VMCI_INVALID_HANDLE);
 358}
 359
 360static int vmci_transport_send_qp_offer(struct sock *sk,
 361                                        struct vmci_handle handle)
 362{
 363        return vmci_transport_send_control_pkt(
 364                                        sk, VMCI_TRANSPORT_PACKET_TYPE_OFFER, 0,
 365                                        0, NULL,
 366                                        VSOCK_PROTO_INVALID, handle);
 367}
 368
 369static int vmci_transport_send_attach(struct sock *sk,
 370                                      struct vmci_handle handle)
 371{
 372        return vmci_transport_send_control_pkt(
 373                                        sk, VMCI_TRANSPORT_PACKET_TYPE_ATTACH,
 374                                        0, 0, NULL, VSOCK_PROTO_INVALID,
 375                                        handle);
 376}
 377
 378static int vmci_transport_reply_reset(struct vmci_transport_packet *pkt)
 379{
 380        return vmci_transport_reply_control_pkt_fast(
 381                                                pkt,
 382                                                VMCI_TRANSPORT_PACKET_TYPE_RST,
 383                                                0, 0, NULL,
 384                                                VMCI_INVALID_HANDLE);
 385}
 386
 387static int vmci_transport_send_invalid_bh(struct sockaddr_vm *dst,
 388                                          struct sockaddr_vm *src)
 389{
 390        return vmci_transport_send_control_pkt_bh(
 391                                        dst, src,
 392                                        VMCI_TRANSPORT_PACKET_TYPE_INVALID,
 393                                        0, 0, NULL, VMCI_INVALID_HANDLE);
 394}
 395
 396int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst,
 397                                 struct sockaddr_vm *src)
 398{
 399        return vmci_transport_send_control_pkt_bh(
 400                                        dst, src,
 401                                        VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0,
 402                                        0, NULL, VMCI_INVALID_HANDLE);
 403}
 404
 405int vmci_transport_send_read_bh(struct sockaddr_vm *dst,
 406                                struct sockaddr_vm *src)
 407{
 408        return vmci_transport_send_control_pkt_bh(
 409                                        dst, src,
 410                                        VMCI_TRANSPORT_PACKET_TYPE_READ, 0,
 411                                        0, NULL, VMCI_INVALID_HANDLE);
 412}
 413
 414int vmci_transport_send_wrote(struct sock *sk)
 415{
 416        return vmci_transport_send_control_pkt(
 417                                        sk, VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0,
 418                                        0, NULL, VSOCK_PROTO_INVALID,
 419                                        VMCI_INVALID_HANDLE);
 420}
 421
 422int vmci_transport_send_read(struct sock *sk)
 423{
 424        return vmci_transport_send_control_pkt(
 425                                        sk, VMCI_TRANSPORT_PACKET_TYPE_READ, 0,
 426                                        0, NULL, VSOCK_PROTO_INVALID,
 427                                        VMCI_INVALID_HANDLE);
 428}
 429
 430int vmci_transport_send_waiting_write(struct sock *sk,
 431                                      struct vmci_transport_waiting_info *wait)
 432{
 433        return vmci_transport_send_control_pkt(
 434                                sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE,
 435                                0, 0, wait, VSOCK_PROTO_INVALID,
 436                                VMCI_INVALID_HANDLE);
 437}
 438
 439int vmci_transport_send_waiting_read(struct sock *sk,
 440                                     struct vmci_transport_waiting_info *wait)
 441{
 442        return vmci_transport_send_control_pkt(
 443                                sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ,
 444                                0, 0, wait, VSOCK_PROTO_INVALID,
 445                                VMCI_INVALID_HANDLE);
 446}
 447
 448static int vmci_transport_shutdown(struct vsock_sock *vsk, int mode)
 449{
 450        return vmci_transport_send_control_pkt(
 451                                        &vsk->sk,
 452                                        VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN,
 453                                        0, mode, NULL,
 454                                        VSOCK_PROTO_INVALID,
 455                                        VMCI_INVALID_HANDLE);
 456}
 457
 458static int vmci_transport_send_conn_request(struct sock *sk, size_t size)
 459{
 460        return vmci_transport_send_control_pkt(sk,
 461                                        VMCI_TRANSPORT_PACKET_TYPE_REQUEST,
 462                                        size, 0, NULL,
 463                                        VSOCK_PROTO_INVALID,
 464                                        VMCI_INVALID_HANDLE);
 465}
 466
 467static int vmci_transport_send_conn_request2(struct sock *sk, size_t size,
 468                                             u16 version)
 469{
 470        return vmci_transport_send_control_pkt(
 471                                        sk, VMCI_TRANSPORT_PACKET_TYPE_REQUEST2,
 472                                        size, 0, NULL, version,
 473                                        VMCI_INVALID_HANDLE);
 474}
 475
 476static struct sock *vmci_transport_get_pending(
 477                                        struct sock *listener,
 478                                        struct vmci_transport_packet *pkt)
 479{
 480        struct vsock_sock *vlistener;
 481        struct vsock_sock *vpending;
 482        struct sock *pending;
 483        struct sockaddr_vm src;
 484
 485        vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 486
 487        vlistener = vsock_sk(listener);
 488
 489        list_for_each_entry(vpending, &vlistener->pending_links,
 490                            pending_links) {
 491                if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
 492                    pkt->dst_port == vpending->local_addr.svm_port) {
 493                        pending = sk_vsock(vpending);
 494                        sock_hold(pending);
 495                        goto found;
 496                }
 497        }
 498
 499        pending = NULL;
 500found:
 501        return pending;
 502
 503}
 504
 505static void vmci_transport_release_pending(struct sock *pending)
 506{
 507        sock_put(pending);
 508}
 509
 510/* We allow two kinds of sockets to communicate with a restricted VM: 1)
 511 * trusted sockets 2) sockets from applications running as the same user as the
 512 * VM (this is only true for the host side and only when using hosted products)
 513 */
 514
 515static bool vmci_transport_is_trusted(struct vsock_sock *vsock, u32 peer_cid)
 516{
 517        return vsock->trusted ||
 518               vmci_is_context_owner(peer_cid, vsock->owner->uid);
 519}
 520
 521/* We allow sending datagrams to and receiving datagrams from a restricted VM
 522 * only if it is trusted as described in vmci_transport_is_trusted.
 523 */
 524
 525static bool vmci_transport_allow_dgram(struct vsock_sock *vsock, u32 peer_cid)
 526{
 527        if (VMADDR_CID_HYPERVISOR == peer_cid)
 528                return true;
 529
 530        if (vsock->cached_peer != peer_cid) {
 531                vsock->cached_peer = peer_cid;
 532                if (!vmci_transport_is_trusted(vsock, peer_cid) &&
 533                    (vmci_context_get_priv_flags(peer_cid) &
 534                     VMCI_PRIVILEGE_FLAG_RESTRICTED)) {
 535                        vsock->cached_peer_allow_dgram = false;
 536                } else {
 537                        vsock->cached_peer_allow_dgram = true;
 538                }
 539        }
 540
 541        return vsock->cached_peer_allow_dgram;
 542}
 543
 544static int
 545vmci_transport_queue_pair_alloc(struct vmci_qp **qpair,
 546                                struct vmci_handle *handle,
 547                                u64 produce_size,
 548                                u64 consume_size,
 549                                u32 peer, u32 flags, bool trusted)
 550{
 551        int err = 0;
 552
 553        if (trusted) {
 554                /* Try to allocate our queue pair as trusted. This will only
 555                 * work if vsock is running in the host.
 556                 */
 557
 558                err = vmci_qpair_alloc(qpair, handle, produce_size,
 559                                       consume_size,
 560                                       peer, flags,
 561                                       VMCI_PRIVILEGE_FLAG_TRUSTED);
 562                if (err != VMCI_ERROR_NO_ACCESS)
 563                        goto out;
 564
 565        }
 566
 567        err = vmci_qpair_alloc(qpair, handle, produce_size, consume_size,
 568                               peer, flags, VMCI_NO_PRIVILEGE_FLAGS);
 569out:
 570        if (err < 0) {
 571                pr_err("Could not attach to queue pair with %d\n",
 572                       err);
 573                err = vmci_transport_error_to_vsock_error(err);
 574        }
 575
 576        return err;
 577}
 578
 579static int
 580vmci_transport_datagram_create_hnd(u32 resource_id,
 581                                   u32 flags,
 582                                   vmci_datagram_recv_cb recv_cb,
 583                                   void *client_data,
 584                                   struct vmci_handle *out_handle)
 585{
 586        int err = 0;
 587
 588        /* Try to allocate our datagram handler as trusted. This will only work
 589         * if vsock is running in the host.
 590         */
 591
 592        err = vmci_datagram_create_handle_priv(resource_id, flags,
 593                                               VMCI_PRIVILEGE_FLAG_TRUSTED,
 594                                               recv_cb,
 595                                               client_data, out_handle);
 596
 597        if (err == VMCI_ERROR_NO_ACCESS)
 598                err = vmci_datagram_create_handle(resource_id, flags,
 599                                                  recv_cb, client_data,
 600                                                  out_handle);
 601
 602        return err;
 603}
 604
 605/* This is invoked as part of a tasklet that's scheduled when the VMCI
 606 * interrupt fires.  This is run in bottom-half context and if it ever needs to
 607 * sleep it should defer that work to a work queue.
 608 */
 609
 610static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg)
 611{
 612        struct sock *sk;
 613        size_t size;
 614        struct sk_buff *skb;
 615        struct vsock_sock *vsk;
 616
 617        sk = (struct sock *)data;
 618
 619        /* This handler is privileged when this module is running on the host.
 620         * We will get datagrams from all endpoints (even VMs that are in a
 621         * restricted context). If we get one from a restricted context then
 622         * the destination socket must be trusted.
 623         *
 624         * NOTE: We access the socket struct without holding the lock here.
 625         * This is ok because the field we are interested is never modified
 626         * outside of the create and destruct socket functions.
 627         */
 628        vsk = vsock_sk(sk);
 629        if (!vmci_transport_allow_dgram(vsk, dg->src.context))
 630                return VMCI_ERROR_NO_ACCESS;
 631
 632        size = VMCI_DG_SIZE(dg);
 633
 634        /* Attach the packet to the socket's receive queue as an sk_buff. */
 635        skb = alloc_skb(size, GFP_ATOMIC);
 636        if (!skb)
 637                return VMCI_ERROR_NO_MEM;
 638
 639        /* sk_receive_skb() will do a sock_put(), so hold here. */
 640        sock_hold(sk);
 641        skb_put(skb, size);
 642        memcpy(skb->data, dg, size);
 643        sk_receive_skb(sk, skb, 0);
 644
 645        return VMCI_SUCCESS;
 646}
 647
 648static bool vmci_transport_stream_allow(u32 cid, u32 port)
 649{
 650        static const u32 non_socket_contexts[] = {
 651                VMADDR_CID_LOCAL,
 652        };
 653        int i;
 654
 655        BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts));
 656
 657        for (i = 0; i < ARRAY_SIZE(non_socket_contexts); i++) {
 658                if (cid == non_socket_contexts[i])
 659                        return false;
 660        }
 661
 662        return true;
 663}
 664
 665/* This is invoked as part of a tasklet that's scheduled when the VMCI
 666 * interrupt fires.  This is run in bottom-half context but it defers most of
 667 * its work to the packet handling work queue.
 668 */
 669
 670static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
 671{
 672        struct sock *sk;
 673        struct sockaddr_vm dst;
 674        struct sockaddr_vm src;
 675        struct vmci_transport_packet *pkt;
 676        struct vsock_sock *vsk;
 677        bool bh_process_pkt;
 678        int err;
 679
 680        sk = NULL;
 681        err = VMCI_SUCCESS;
 682        bh_process_pkt = false;
 683
 684        /* Ignore incoming packets from contexts without sockets, or resources
 685         * that aren't vsock implementations.
 686         */
 687
 688        if (!vmci_transport_stream_allow(dg->src.context, -1)
 689            || vmci_transport_peer_rid(dg->src.context) != dg->src.resource)
 690                return VMCI_ERROR_NO_ACCESS;
 691
 692        if (VMCI_DG_SIZE(dg) < sizeof(*pkt))
 693                /* Drop datagrams that do not contain full VSock packets. */
 694                return VMCI_ERROR_INVALID_ARGS;
 695
 696        pkt = (struct vmci_transport_packet *)dg;
 697
 698        /* Find the socket that should handle this packet.  First we look for a
 699         * connected socket and if there is none we look for a socket bound to
 700         * the destintation address.
 701         */
 702        vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 703        vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);
 704
 705        sk = vsock_find_connected_socket(&src, &dst);
 706        if (!sk) {
 707                sk = vsock_find_bound_socket(&dst);
 708                if (!sk) {
 709                        /* We could not find a socket for this specified
 710                         * address.  If this packet is a RST, we just drop it.
 711                         * If it is another packet, we send a RST.  Note that
 712                         * we do not send a RST reply to RSTs so that we do not
 713                         * continually send RSTs between two endpoints.
 714                         *
 715                         * Note that since this is a reply, dst is src and src
 716                         * is dst.
 717                         */
 718                        if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0)
 719                                pr_err("unable to send reset\n");
 720
 721                        err = VMCI_ERROR_NOT_FOUND;
 722                        goto out;
 723                }
 724        }
 725
 726        /* If the received packet type is beyond all types known to this
 727         * implementation, reply with an invalid message.  Hopefully this will
 728         * help when implementing backwards compatibility in the future.
 729         */
 730        if (pkt->type >= VMCI_TRANSPORT_PACKET_TYPE_MAX) {
 731                vmci_transport_send_invalid_bh(&dst, &src);
 732                err = VMCI_ERROR_INVALID_ARGS;
 733                goto out;
 734        }
 735
 736        /* This handler is privileged when this module is running on the host.
 737         * We will get datagram connect requests from all endpoints (even VMs
 738         * that are in a restricted context). If we get one from a restricted
 739         * context then the destination socket must be trusted.
 740         *
 741         * NOTE: We access the socket struct without holding the lock here.
 742         * This is ok because the field we are interested is never modified
 743         * outside of the create and destruct socket functions.
 744         */
 745        vsk = vsock_sk(sk);
 746        if (!vmci_transport_allow_dgram(vsk, pkt->dg.src.context)) {
 747                err = VMCI_ERROR_NO_ACCESS;
 748                goto out;
 749        }
 750
 751        /* We do most everything in a work queue, but let's fast path the
 752         * notification of reads and writes to help data transfer performance.
 753         * We can only do this if there is no process context code executing
 754         * for this socket since that may change the state.
 755         */
 756        bh_lock_sock(sk);
 757
 758        if (!sock_owned_by_user(sk)) {
 759                /* The local context ID may be out of date, update it. */
 760                vsk->local_addr.svm_cid = dst.svm_cid;
 761
 762                if (sk->sk_state == TCP_ESTABLISHED)
 763                        vmci_trans(vsk)->notify_ops->handle_notify_pkt(
 764                                        sk, pkt, true, &dst, &src,
 765                                        &bh_process_pkt);
 766        }
 767
 768        bh_unlock_sock(sk);
 769
 770        if (!bh_process_pkt) {
 771                struct vmci_transport_recv_pkt_info *recv_pkt_info;
 772
 773                recv_pkt_info = kmalloc(sizeof(*recv_pkt_info), GFP_ATOMIC);
 774                if (!recv_pkt_info) {
 775                        if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0)
 776                                pr_err("unable to send reset\n");
 777
 778                        err = VMCI_ERROR_NO_MEM;
 779                        goto out;
 780                }
 781
 782                recv_pkt_info->sk = sk;
 783                memcpy(&recv_pkt_info->pkt, pkt, sizeof(recv_pkt_info->pkt));
 784                INIT_WORK(&recv_pkt_info->work, vmci_transport_recv_pkt_work);
 785
 786                schedule_work(&recv_pkt_info->work);
 787                /* Clear sk so that the reference count incremented by one of
 788                 * the Find functions above is not decremented below.  We need
 789                 * that reference count for the packet handler we've scheduled
 790                 * to run.
 791                 */
 792                sk = NULL;
 793        }
 794
 795out:
 796        if (sk)
 797                sock_put(sk);
 798
 799        return err;
 800}
 801
 802static void vmci_transport_handle_detach(struct sock *sk)
 803{
 804        struct vsock_sock *vsk;
 805
 806        vsk = vsock_sk(sk);
 807        if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) {
 808                sock_set_flag(sk, SOCK_DONE);
 809
 810                /* On a detach the peer will not be sending or receiving
 811                 * anymore.
 812                 */
 813                vsk->peer_shutdown = SHUTDOWN_MASK;
 814
 815                /* We should not be sending anymore since the peer won't be
 816                 * there to receive, but we can still receive if there is data
 817                 * left in our consume queue. If the local endpoint is a host,
 818                 * we can't call vsock_stream_has_data, since that may block,
 819                 * but a host endpoint can't read data once the VM has
 820                 * detached, so there is no available data in that case.
 821                 */
 822                if (vsk->local_addr.svm_cid == VMADDR_CID_HOST ||
 823                    vsock_stream_has_data(vsk) <= 0) {
 824                        if (sk->sk_state == TCP_SYN_SENT) {
 825                                /* The peer may detach from a queue pair while
 826                                 * we are still in the connecting state, i.e.,
 827                                 * if the peer VM is killed after attaching to
 828                                 * a queue pair, but before we complete the
 829                                 * handshake. In that case, we treat the detach
 830                                 * event like a reset.
 831                                 */
 832
 833                                sk->sk_state = TCP_CLOSE;
 834                                sk->sk_err = ECONNRESET;
 835                                sk->sk_error_report(sk);
 836                                return;
 837                        }
 838                        sk->sk_state = TCP_CLOSE;
 839                }
 840                sk->sk_state_change(sk);
 841        }
 842}
 843
 844static void vmci_transport_peer_detach_cb(u32 sub_id,
 845                                          const struct vmci_event_data *e_data,
 846                                          void *client_data)
 847{
 848        struct vmci_transport *trans = client_data;
 849        const struct vmci_event_payload_qp *e_payload;
 850
 851        e_payload = vmci_event_data_const_payload(e_data);
 852
 853        /* XXX This is lame, we should provide a way to lookup sockets by
 854         * qp_handle.
 855         */
 856        if (vmci_handle_is_invalid(e_payload->handle) ||
 857            !vmci_handle_is_equal(trans->qp_handle, e_payload->handle))
 858                return;
 859
 860        /* We don't ask for delayed CBs when we subscribe to this event (we
 861         * pass 0 as flags to vmci_event_subscribe()).  VMCI makes no
 862         * guarantees in that case about what context we might be running in,
 863         * so it could be BH or process, blockable or non-blockable.  So we
 864         * need to account for all possible contexts here.
 865         */
 866        spin_lock_bh(&trans->lock);
 867        if (!trans->sk)
 868                goto out;
 869
 870        /* Apart from here, trans->lock is only grabbed as part of sk destruct,
 871         * where trans->sk isn't locked.
 872         */
 873        bh_lock_sock(trans->sk);
 874
 875        vmci_transport_handle_detach(trans->sk);
 876
 877        bh_unlock_sock(trans->sk);
 878 out:
 879        spin_unlock_bh(&trans->lock);
 880}
 881
 882static void vmci_transport_qp_resumed_cb(u32 sub_id,
 883                                         const struct vmci_event_data *e_data,
 884                                         void *client_data)
 885{
 886        vsock_for_each_connected_socket(vmci_transport_handle_detach);
 887}
 888
 889static void vmci_transport_recv_pkt_work(struct work_struct *work)
 890{
 891        struct vmci_transport_recv_pkt_info *recv_pkt_info;
 892        struct vmci_transport_packet *pkt;
 893        struct sock *sk;
 894
 895        recv_pkt_info =
 896                container_of(work, struct vmci_transport_recv_pkt_info, work);
 897        sk = recv_pkt_info->sk;
 898        pkt = &recv_pkt_info->pkt;
 899
 900        lock_sock(sk);
 901
 902        /* The local context ID may be out of date. */
 903        vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context;
 904
 905        switch (sk->sk_state) {
 906        case TCP_LISTEN:
 907                vmci_transport_recv_listen(sk, pkt);
 908                break;
 909        case TCP_SYN_SENT:
 910                /* Processing of pending connections for servers goes through
 911                 * the listening socket, so see vmci_transport_recv_listen()
 912                 * for that path.
 913                 */
 914                vmci_transport_recv_connecting_client(sk, pkt);
 915                break;
 916        case TCP_ESTABLISHED:
 917                vmci_transport_recv_connected(sk, pkt);
 918                break;
 919        default:
 920                /* Because this function does not run in the same context as
 921                 * vmci_transport_recv_stream_cb it is possible that the
 922                 * socket has closed. We need to let the other side know or it
 923                 * could be sitting in a connect and hang forever. Send a
 924                 * reset to prevent that.
 925                 */
 926                vmci_transport_send_reset(sk, pkt);
 927                break;
 928        }
 929
 930        release_sock(sk);
 931        kfree(recv_pkt_info);
 932        /* Release reference obtained in the stream callback when we fetched
 933         * this socket out of the bound or connected list.
 934         */
 935        sock_put(sk);
 936}
 937
 938static int vmci_transport_recv_listen(struct sock *sk,
 939                                      struct vmci_transport_packet *pkt)
 940{
 941        struct sock *pending;
 942        struct vsock_sock *vpending;
 943        int err;
 944        u64 qp_size;
 945        bool old_request = false;
 946        bool old_pkt_proto = false;
 947
 948        err = 0;
 949
 950        /* Because we are in the listen state, we could be receiving a packet
 951         * for ourself or any previous connection requests that we received.
 952         * If it's the latter, we try to find a socket in our list of pending
 953         * connections and, if we do, call the appropriate handler for the
 954         * state that that socket is in.  Otherwise we try to service the
 955         * connection request.
 956         */
 957        pending = vmci_transport_get_pending(sk, pkt);
 958        if (pending) {
 959                lock_sock(pending);
 960
 961                /* The local context ID may be out of date. */
 962                vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context;
 963
 964                switch (pending->sk_state) {
 965                case TCP_SYN_SENT:
 966                        err = vmci_transport_recv_connecting_server(sk,
 967                                                                    pending,
 968                                                                    pkt);
 969                        break;
 970                default:
 971                        vmci_transport_send_reset(pending, pkt);
 972                        err = -EINVAL;
 973                }
 974
 975                if (err < 0)
 976                        vsock_remove_pending(sk, pending);
 977
 978                release_sock(pending);
 979                vmci_transport_release_pending(pending);
 980
 981                return err;
 982        }
 983
 984        /* The listen state only accepts connection requests.  Reply with a
 985         * reset unless we received a reset.
 986         */
 987
 988        if (!(pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST ||
 989              pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)) {
 990                vmci_transport_reply_reset(pkt);
 991                return -EINVAL;
 992        }
 993
 994        if (pkt->u.size == 0) {
 995                vmci_transport_reply_reset(pkt);
 996                return -EINVAL;
 997        }
 998
 999        /* If this socket can't accommodate this connection request, we send a
1000         * reset.  Otherwise we create and initialize a child socket and reply
1001         * with a connection negotiation.
1002         */
1003        if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) {
1004                vmci_transport_reply_reset(pkt);
1005                return -ECONNREFUSED;
1006        }
1007
1008        pending = vsock_create_connected(sk);
1009        if (!pending) {
1010                vmci_transport_send_reset(sk, pkt);
1011                return -ENOMEM;
1012        }
1013
1014        vpending = vsock_sk(pending);
1015
1016        vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
1017                        pkt->dst_port);
1018        vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
1019                        pkt->src_port);
1020
1021        err = vsock_assign_transport(vpending, vsock_sk(sk));
1022        /* Transport assigned (looking at remote_addr) must be the same
1023         * where we received the request.
1024         */
1025        if (err || !vmci_check_transport(vpending)) {
1026                vmci_transport_send_reset(sk, pkt);
1027                sock_put(pending);
1028                return err;
1029        }
1030
1031        /* If the proposed size fits within our min/max, accept it. Otherwise
1032         * propose our own size.
1033         */
1034        if (pkt->u.size >= vpending->buffer_min_size &&
1035            pkt->u.size <= vpending->buffer_max_size) {
1036                qp_size = pkt->u.size;
1037        } else {
1038                qp_size = vpending->buffer_size;
1039        }
1040
1041        /* Figure out if we are using old or new requests based on the
1042         * overrides pkt types sent by our peer.
1043         */
1044        if (vmci_transport_old_proto_override(&old_pkt_proto)) {
1045                old_request = old_pkt_proto;
1046        } else {
1047                if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST)
1048                        old_request = true;
1049                else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)
1050                        old_request = false;
1051
1052        }
1053
1054        if (old_request) {
1055                /* Handle a REQUEST (or override) */
1056                u16 version = VSOCK_PROTO_INVALID;
1057                if (vmci_transport_proto_to_notify_struct(
1058                        pending, &version, true))
1059                        err = vmci_transport_send_negotiate(pending, qp_size);
1060                else
1061                        err = -EINVAL;
1062
1063        } else {
1064                /* Handle a REQUEST2 (or override) */
1065                int proto_int = pkt->proto;
1066                int pos;
1067                u16 active_proto_version = 0;
1068
1069                /* The list of possible protocols is the intersection of all
1070                 * protocols the client supports ... plus all the protocols we
1071                 * support.
1072                 */
1073                proto_int &= vmci_transport_new_proto_supported_versions();
1074
1075                /* We choose the highest possible protocol version and use that
1076                 * one.
1077                 */
1078                pos = fls(proto_int);
1079                if (pos) {
1080                        active_proto_version = (1 << (pos - 1));
1081                        if (vmci_transport_proto_to_notify_struct(
1082                                pending, &active_proto_version, false))
1083                                err = vmci_transport_send_negotiate2(pending,
1084                                                        qp_size,
1085                                                        active_proto_version);
1086                        else
1087                                err = -EINVAL;
1088
1089                } else {
1090                        err = -EINVAL;
1091                }
1092        }
1093
1094        if (err < 0) {
1095                vmci_transport_send_reset(sk, pkt);
1096                sock_put(pending);
1097                err = vmci_transport_error_to_vsock_error(err);
1098                goto out;
1099        }
1100
1101        vsock_add_pending(sk, pending);
1102        sk_acceptq_added(sk);
1103
1104        pending->sk_state = TCP_SYN_SENT;
1105        vmci_trans(vpending)->produce_size =
1106                vmci_trans(vpending)->consume_size = qp_size;
1107        vpending->buffer_size = qp_size;
1108
1109        vmci_trans(vpending)->notify_ops->process_request(pending);
1110
1111        /* We might never receive another message for this socket and it's not
1112         * connected to any process, so we have to ensure it gets cleaned up
1113         * ourself.  Our delayed work function will take care of that.  Note
1114         * that we do not ever cancel this function since we have few
1115         * guarantees about its state when calling cancel_delayed_work().
1116         * Instead we hold a reference on the socket for that function and make
1117         * it capable of handling cases where it needs to do nothing but
1118         * release that reference.
1119         */
1120        vpending->listener = sk;
1121        sock_hold(sk);
1122        sock_hold(pending);
1123        schedule_delayed_work(&vpending->pending_work, HZ);
1124
1125out:
1126        return err;
1127}
1128
1129static int
1130vmci_transport_recv_connecting_server(struct sock *listener,
1131                                      struct sock *pending,
1132                                      struct vmci_transport_packet *pkt)
1133{
1134        struct vsock_sock *vpending;
1135        struct vmci_handle handle;
1136        struct vmci_qp *qpair;
1137        bool is_local;
1138        u32 flags;
1139        u32 detach_sub_id;
1140        int err;
1141        int skerr;
1142
1143        vpending = vsock_sk(pending);
1144        detach_sub_id = VMCI_INVALID_ID;
1145
1146        switch (pkt->type) {
1147        case VMCI_TRANSPORT_PACKET_TYPE_OFFER:
1148                if (vmci_handle_is_invalid(pkt->u.handle)) {
1149                        vmci_transport_send_reset(pending, pkt);
1150                        skerr = EPROTO;
1151                        err = -EINVAL;
1152                        goto destroy;
1153                }
1154                break;
1155        default:
1156                /* Close and cleanup the connection. */
1157                vmci_transport_send_reset(pending, pkt);
1158                skerr = EPROTO;
1159                err = pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST ? 0 : -EINVAL;
1160                goto destroy;
1161        }
1162
1163        /* In order to complete the connection we need to attach to the offered
1164         * queue pair and send an attach notification.  We also subscribe to the
1165         * detach event so we know when our peer goes away, and we do that
1166         * before attaching so we don't miss an event.  If all this succeeds,
1167         * we update our state and wakeup anything waiting in accept() for a
1168         * connection.
1169         */
1170
1171        /* We don't care about attach since we ensure the other side has
1172         * attached by specifying the ATTACH_ONLY flag below.
1173         */
1174        err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
1175                                   vmci_transport_peer_detach_cb,
1176                                   vmci_trans(vpending), &detach_sub_id);
1177        if (err < VMCI_SUCCESS) {
1178                vmci_transport_send_reset(pending, pkt);
1179                err = vmci_transport_error_to_vsock_error(err);
1180                skerr = -err;
1181                goto destroy;
1182        }
1183
1184        vmci_trans(vpending)->detach_sub_id = detach_sub_id;
1185
1186        /* Now attach to the queue pair the client created. */
1187        handle = pkt->u.handle;
1188
1189        /* vpending->local_addr always has a context id so we do not need to
1190         * worry about VMADDR_CID_ANY in this case.
1191         */
1192        is_local =
1193            vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
1194        flags = VMCI_QPFLAG_ATTACH_ONLY;
1195        flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;
1196
1197        err = vmci_transport_queue_pair_alloc(
1198                                        &qpair,
1199                                        &handle,
1200                                        vmci_trans(vpending)->produce_size,
1201                                        vmci_trans(vpending)->consume_size,
1202                                        pkt->dg.src.context,
1203                                        flags,
1204                                        vmci_transport_is_trusted(
1205                                                vpending,
1206                                                vpending->remote_addr.svm_cid));
1207        if (err < 0) {
1208                vmci_transport_send_reset(pending, pkt);
1209                skerr = -err;
1210                goto destroy;
1211        }
1212
1213        vmci_trans(vpending)->qp_handle = handle;
1214        vmci_trans(vpending)->qpair = qpair;
1215
1216        /* When we send the attach message, we must be ready to handle incoming
1217         * control messages on the newly connected socket. So we move the
1218         * pending socket to the connected state before sending the attach
1219         * message. Otherwise, an incoming packet triggered by the attach being
1220         * received by the peer may be processed concurrently with what happens
1221         * below after sending the attach message, and that incoming packet
1222         * will find the listening socket instead of the (currently) pending
1223         * socket. Note that enqueueing the socket increments the reference
1224         * count, so even if a reset comes before the connection is accepted,
1225         * the socket will be valid until it is removed from the queue.
1226         *
1227         * If we fail sending the attach below, we remove the socket from the
1228         * connected list and move the socket to TCP_CLOSE before
1229         * releasing the lock, so a pending slow path processing of an incoming
1230         * packet will not see the socket in the connected state in that case.
1231         */
1232        pending->sk_state = TCP_ESTABLISHED;
1233
1234        vsock_insert_connected(vpending);
1235
1236        /* Notify our peer of our attach. */
1237        err = vmci_transport_send_attach(pending, handle);
1238        if (err < 0) {
1239                vsock_remove_connected(vpending);
1240                pr_err("Could not send attach\n");
1241                vmci_transport_send_reset(pending, pkt);
1242                err = vmci_transport_error_to_vsock_error(err);
1243                skerr = -err;
1244                goto destroy;
1245        }
1246
1247        /* We have a connection. Move the now connected socket from the
1248         * listener's pending list to the accept queue so callers of accept()
1249         * can find it.
1250         */
1251        vsock_remove_pending(listener, pending);
1252        vsock_enqueue_accept(listener, pending);
1253
1254        /* Callers of accept() will be be waiting on the listening socket, not
1255         * the pending socket.
1256         */
1257        listener->sk_data_ready(listener);
1258
1259        return 0;
1260
1261destroy:
1262        pending->sk_err = skerr;
1263        pending->sk_state = TCP_CLOSE;
1264        /* As long as we drop our reference, all necessary cleanup will handle
1265         * when the cleanup function drops its reference and our destruct
1266         * implementation is called.  Note that since the listen handler will
1267         * remove pending from the pending list upon our failure, the cleanup
1268         * function won't drop the additional reference, which is why we do it
1269         * here.
1270         */
1271        sock_put(pending);
1272
1273        return err;
1274}
1275
1276static int
1277vmci_transport_recv_connecting_client(struct sock *sk,
1278                                      struct vmci_transport_packet *pkt)
1279{
1280        struct vsock_sock *vsk;
1281        int err;
1282        int skerr;
1283
1284        vsk = vsock_sk(sk);
1285
1286        switch (pkt->type) {
1287        case VMCI_TRANSPORT_PACKET_TYPE_ATTACH:
1288                if (vmci_handle_is_invalid(pkt->u.handle) ||
1289                    !vmci_handle_is_equal(pkt->u.handle,
1290                                          vmci_trans(vsk)->qp_handle)) {
1291                        skerr = EPROTO;
1292                        err = -EINVAL;
1293                        goto destroy;
1294                }
1295
1296                /* Signify the socket is connected and wakeup the waiter in
1297                 * connect(). Also place the socket in the connected table for
1298                 * accounting (it can already be found since it's in the bound
1299                 * table).
1300                 */
1301                sk->sk_state = TCP_ESTABLISHED;
1302                sk->sk_socket->state = SS_CONNECTED;
1303                vsock_insert_connected(vsk);
1304                sk->sk_state_change(sk);
1305
1306                break;
1307        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
1308        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
1309                if (pkt->u.size == 0
1310                    || pkt->dg.src.context != vsk->remote_addr.svm_cid
1311                    || pkt->src_port != vsk->remote_addr.svm_port
1312                    || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
1313                    || vmci_trans(vsk)->qpair
1314                    || vmci_trans(vsk)->produce_size != 0
1315                    || vmci_trans(vsk)->consume_size != 0
1316                    || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
1317                        skerr = EPROTO;
1318                        err = -EINVAL;
1319
1320                        goto destroy;
1321                }
1322
1323                err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
1324                if (err) {
1325                        skerr = -err;
1326                        goto destroy;
1327                }
1328
1329                break;
1330        case VMCI_TRANSPORT_PACKET_TYPE_INVALID:
1331                err = vmci_transport_recv_connecting_client_invalid(sk, pkt);
1332                if (err) {
1333                        skerr = -err;
1334                        goto destroy;
1335                }
1336
1337                break;
1338        case VMCI_TRANSPORT_PACKET_TYPE_RST:
1339                /* Older versions of the linux code (WS 6.5 / ESX 4.0) used to
1340                 * continue processing here after they sent an INVALID packet.
1341                 * This meant that we got a RST after the INVALID. We ignore a
1342                 * RST after an INVALID. The common code doesn't send the RST
1343                 * ... so we can hang if an old version of the common code
1344                 * fails between getting a REQUEST and sending an OFFER back.
1345                 * Not much we can do about it... except hope that it doesn't
1346                 * happen.
1347                 */
1348                if (vsk->ignore_connecting_rst) {
1349                        vsk->ignore_connecting_rst = false;
1350                } else {
1351                        skerr = ECONNRESET;
1352                        err = 0;
1353                        goto destroy;
1354                }
1355
1356                break;
1357        default:
1358                /* Close and cleanup the connection. */
1359                skerr = EPROTO;
1360                err = -EINVAL;
1361                goto destroy;
1362        }
1363
1364        return 0;
1365
1366destroy:
1367        vmci_transport_send_reset(sk, pkt);
1368
1369        sk->sk_state = TCP_CLOSE;
1370        sk->sk_err = skerr;
1371        sk->sk_error_report(sk);
1372        return err;
1373}
1374
1375static int vmci_transport_recv_connecting_client_negotiate(
1376                                        struct sock *sk,
1377                                        struct vmci_transport_packet *pkt)
1378{
1379        int err;
1380        struct vsock_sock *vsk;
1381        struct vmci_handle handle;
1382        struct vmci_qp *qpair;
1383        u32 detach_sub_id;
1384        bool is_local;
1385        u32 flags;
1386        bool old_proto = true;
1387        bool old_pkt_proto;
1388        u16 version;
1389
1390        vsk = vsock_sk(sk);
1391        handle = VMCI_INVALID_HANDLE;
1392        detach_sub_id = VMCI_INVALID_ID;
1393
1394        /* If we have gotten here then we should be past the point where old
1395         * linux vsock could have sent the bogus rst.
1396         */
1397        vsk->sent_request = false;
1398        vsk->ignore_connecting_rst = false;
1399
1400        /* Verify that we're OK with the proposed queue pair size */
1401        if (pkt->u.size < vsk->buffer_min_size ||
1402            pkt->u.size > vsk->buffer_max_size) {
1403                err = -EINVAL;
1404                goto destroy;
1405        }
1406
1407        /* At this point we know the CID the peer is using to talk to us. */
1408
1409        if (vsk->local_addr.svm_cid == VMADDR_CID_ANY)
1410                vsk->local_addr.svm_cid = pkt->dg.dst.context;
1411
1412        /* Setup the notify ops to be the highest supported version that both
1413         * the server and the client support.
1414         */
1415
1416        if (vmci_transport_old_proto_override(&old_pkt_proto)) {
1417                old_proto = old_pkt_proto;
1418        } else {
1419                if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE)
1420                        old_proto = true;
1421                else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2)
1422                        old_proto = false;
1423
1424        }
1425
1426        if (old_proto)
1427                version = VSOCK_PROTO_INVALID;
1428        else
1429                version = pkt->proto;
1430
1431        if (!vmci_transport_proto_to_notify_struct(sk, &version, old_proto)) {
1432                err = -EINVAL;
1433                goto destroy;
1434        }
1435
1436        /* Subscribe to detach events first.
1437         *
1438         * XXX We attach once for each queue pair created for now so it is easy
1439         * to find the socket (it's provided), but later we should only
1440         * subscribe once and add a way to lookup sockets by queue pair handle.
1441         */
1442        err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
1443                                   vmci_transport_peer_detach_cb,
1444                                   vmci_trans(vsk), &detach_sub_id);
1445        if (err < VMCI_SUCCESS) {
1446                err = vmci_transport_error_to_vsock_error(err);
1447                goto destroy;
1448        }
1449
1450        /* Make VMCI select the handle for us. */
1451        handle = VMCI_INVALID_HANDLE;
1452        is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
1453        flags = is_local ? VMCI_QPFLAG_LOCAL : 0;
1454
1455        err = vmci_transport_queue_pair_alloc(&qpair,
1456                                              &handle,
1457                                              pkt->u.size,
1458                                              pkt->u.size,
1459                                              vsk->remote_addr.svm_cid,
1460                                              flags,
1461                                              vmci_transport_is_trusted(
1462                                                  vsk,
1463                                                  vsk->
1464                                                  remote_addr.svm_cid));
1465        if (err < 0)
1466                goto destroy;
1467
1468        err = vmci_transport_send_qp_offer(sk, handle);
1469        if (err < 0) {
1470                err = vmci_transport_error_to_vsock_error(err);
1471                goto destroy;
1472        }
1473
1474        vmci_trans(vsk)->qp_handle = handle;
1475        vmci_trans(vsk)->qpair = qpair;
1476
1477        vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size =
1478                pkt->u.size;
1479
1480        vmci_trans(vsk)->detach_sub_id = detach_sub_id;
1481
1482        vmci_trans(vsk)->notify_ops->process_negotiate(sk);
1483
1484        return 0;
1485
1486destroy:
1487        if (detach_sub_id != VMCI_INVALID_ID)
1488                vmci_event_unsubscribe(detach_sub_id);
1489
1490        if (!vmci_handle_is_invalid(handle))
1491                vmci_qpair_detach(&qpair);
1492
1493        return err;
1494}
1495
1496static int
1497vmci_transport_recv_connecting_client_invalid(struct sock *sk,
1498                                              struct vmci_transport_packet *pkt)
1499{
1500        int err = 0;
1501        struct vsock_sock *vsk = vsock_sk(sk);
1502
1503        if (vsk->sent_request) {
1504                vsk->sent_request = false;
1505                vsk->ignore_connecting_rst = true;
1506
1507                err = vmci_transport_send_conn_request(sk, vsk->buffer_size);
1508                if (err < 0)
1509                        err = vmci_transport_error_to_vsock_error(err);
1510                else
1511                        err = 0;
1512
1513        }
1514
1515        return err;
1516}
1517
1518static int vmci_transport_recv_connected(struct sock *sk,
1519                                         struct vmci_transport_packet *pkt)
1520{
1521        struct vsock_sock *vsk;
1522        bool pkt_processed = false;
1523
1524        /* In cases where we are closing the connection, it's sufficient to
1525         * mark the state change (and maybe error) and wake up any waiting
1526         * threads. Since this is a connected socket, it's owned by a user
1527         * process and will be cleaned up when the failure is passed back on
1528         * the current or next system call.  Our system call implementations
1529         * must therefore check for error and state changes on entry and when
1530         * being awoken.
1531         */
1532        switch (pkt->type) {
1533        case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN:
1534                if (pkt->u.mode) {
1535                        vsk = vsock_sk(sk);
1536
1537                        vsk->peer_shutdown |= pkt->u.mode;
1538                        sk->sk_state_change(sk);
1539                }
1540                break;
1541
1542        case VMCI_TRANSPORT_PACKET_TYPE_RST:
1543                vsk = vsock_sk(sk);
1544                /* It is possible that we sent our peer a message (e.g a
1545                 * WAITING_READ) right before we got notified that the peer had
1546                 * detached. If that happens then we can get a RST pkt back
1547                 * from our peer even though there is data available for us to
1548                 * read. In that case, don't shutdown the socket completely but
1549                 * instead allow the local client to finish reading data off
1550                 * the queuepair. Always treat a RST pkt in connected mode like
1551                 * a clean shutdown.
1552                 */
1553                sock_set_flag(sk, SOCK_DONE);
1554                vsk->peer_shutdown = SHUTDOWN_MASK;
1555                if (vsock_stream_has_data(vsk) <= 0)
1556                        sk->sk_state = TCP_CLOSING;
1557
1558                sk->sk_state_change(sk);
1559                break;
1560
1561        default:
1562                vsk = vsock_sk(sk);
1563                vmci_trans(vsk)->notify_ops->handle_notify_pkt(
1564                                sk, pkt, false, NULL, NULL,
1565                                &pkt_processed);
1566                if (!pkt_processed)
1567                        return -EINVAL;
1568
1569                break;
1570        }
1571
1572        return 0;
1573}
1574
1575static int vmci_transport_socket_init(struct vsock_sock *vsk,
1576                                      struct vsock_sock *psk)
1577{
1578        vsk->trans = kmalloc(sizeof(struct vmci_transport), GFP_KERNEL);
1579        if (!vsk->trans)
1580                return -ENOMEM;
1581
1582        vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
1583        vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE;
1584        vmci_trans(vsk)->qpair = NULL;
1585        vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = 0;
1586        vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID;
1587        vmci_trans(vsk)->notify_ops = NULL;
1588        INIT_LIST_HEAD(&vmci_trans(vsk)->elem);
1589        vmci_trans(vsk)->sk = &vsk->sk;
1590        spin_lock_init(&vmci_trans(vsk)->lock);
1591
1592        return 0;
1593}
1594
1595static void vmci_transport_free_resources(struct list_head *transport_list)
1596{
1597        while (!list_empty(transport_list)) {
1598                struct vmci_transport *transport =
1599                    list_first_entry(transport_list, struct vmci_transport,
1600                                     elem);
1601                list_del(&transport->elem);
1602
1603                if (transport->detach_sub_id != VMCI_INVALID_ID) {
1604                        vmci_event_unsubscribe(transport->detach_sub_id);
1605                        transport->detach_sub_id = VMCI_INVALID_ID;
1606                }
1607
1608                if (!vmci_handle_is_invalid(transport->qp_handle)) {
1609                        vmci_qpair_detach(&transport->qpair);
1610                        transport->qp_handle = VMCI_INVALID_HANDLE;
1611                        transport->produce_size = 0;
1612                        transport->consume_size = 0;
1613                }
1614
1615                kfree(transport);
1616        }
1617}
1618
1619static void vmci_transport_cleanup(struct work_struct *work)
1620{
1621        LIST_HEAD(pending);
1622
1623        spin_lock_bh(&vmci_transport_cleanup_lock);
1624        list_replace_init(&vmci_transport_cleanup_list, &pending);
1625        spin_unlock_bh(&vmci_transport_cleanup_lock);
1626        vmci_transport_free_resources(&pending);
1627}
1628
1629static void vmci_transport_destruct(struct vsock_sock *vsk)
1630{
1631        /* transport can be NULL if we hit a failure at init() time */
1632        if (!vmci_trans(vsk))
1633                return;
1634
1635        /* Ensure that the detach callback doesn't use the sk/vsk
1636         * we are about to destruct.
1637         */
1638        spin_lock_bh(&vmci_trans(vsk)->lock);
1639        vmci_trans(vsk)->sk = NULL;
1640        spin_unlock_bh(&vmci_trans(vsk)->lock);
1641
1642        if (vmci_trans(vsk)->notify_ops)
1643                vmci_trans(vsk)->notify_ops->socket_destruct(vsk);
1644
1645        spin_lock_bh(&vmci_transport_cleanup_lock);
1646        list_add(&vmci_trans(vsk)->elem, &vmci_transport_cleanup_list);
1647        spin_unlock_bh(&vmci_transport_cleanup_lock);
1648        schedule_work(&vmci_transport_cleanup_work);
1649
1650        vsk->trans = NULL;
1651}
1652
1653static void vmci_transport_release(struct vsock_sock *vsk)
1654{
1655        vsock_remove_sock(vsk);
1656
1657        if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) {
1658                vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle);
1659                vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
1660        }
1661}
1662
1663static int vmci_transport_dgram_bind(struct vsock_sock *vsk,
1664                                     struct sockaddr_vm *addr)
1665{
1666        u32 port;
1667        u32 flags;
1668        int err;
1669
1670        /* VMCI will select a resource ID for us if we provide
1671         * VMCI_INVALID_ID.
1672         */
1673        port = addr->svm_port == VMADDR_PORT_ANY ?
1674                        VMCI_INVALID_ID : addr->svm_port;
1675
1676        if (port <= LAST_RESERVED_PORT && !capable(CAP_NET_BIND_SERVICE))
1677                return -EACCES;
1678
1679        flags = addr->svm_cid == VMADDR_CID_ANY ?
1680                                VMCI_FLAG_ANYCID_DG_HND : 0;
1681
1682        err = vmci_transport_datagram_create_hnd(port, flags,
1683                                                 vmci_transport_recv_dgram_cb,
1684                                                 &vsk->sk,
1685                                                 &vmci_trans(vsk)->dg_handle);
1686        if (err < VMCI_SUCCESS)
1687                return vmci_transport_error_to_vsock_error(err);
1688        vsock_addr_init(&vsk->local_addr, addr->svm_cid,
1689                        vmci_trans(vsk)->dg_handle.resource);
1690
1691        return 0;
1692}
1693
1694static int vmci_transport_dgram_enqueue(
1695        struct vsock_sock *vsk,
1696        struct sockaddr_vm *remote_addr,
1697        struct msghdr *msg,
1698        size_t len)
1699{
1700        int err;
1701        struct vmci_datagram *dg;
1702
1703        if (len > VMCI_MAX_DG_PAYLOAD_SIZE)
1704                return -EMSGSIZE;
1705
1706        if (!vmci_transport_allow_dgram(vsk, remote_addr->svm_cid))
1707                return -EPERM;
1708
1709        /* Allocate a buffer for the user's message and our packet header. */
1710        dg = kmalloc(len + sizeof(*dg), GFP_KERNEL);
1711        if (!dg)
1712                return -ENOMEM;
1713
1714        memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len);
1715
1716        dg->dst = vmci_make_handle(remote_addr->svm_cid,
1717                                   remote_addr->svm_port);
1718        dg->src = vmci_make_handle(vsk->local_addr.svm_cid,
1719                                   vsk->local_addr.svm_port);
1720        dg->payload_size = len;
1721
1722        err = vmci_datagram_send(dg);
1723        kfree(dg);
1724        if (err < 0)
1725                return vmci_transport_error_to_vsock_error(err);
1726
1727        return err - sizeof(*dg);
1728}
1729
1730static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk,
1731                                        struct msghdr *msg, size_t len,
1732                                        int flags)
1733{
1734        int err;
1735        int noblock;
1736        struct vmci_datagram *dg;
1737        size_t payload_len;
1738        struct sk_buff *skb;
1739
1740        noblock = flags & MSG_DONTWAIT;
1741
1742        if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
1743                return -EOPNOTSUPP;
1744
1745        /* Retrieve the head sk_buff from the socket's receive queue. */
1746        err = 0;
1747        skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err);
1748        if (!skb)
1749                return err;
1750
1751        dg = (struct vmci_datagram *)skb->data;
1752        if (!dg)
1753                /* err is 0, meaning we read zero bytes. */
1754                goto out;
1755
1756        payload_len = dg->payload_size;
1757        /* Ensure the sk_buff matches the payload size claimed in the packet. */
1758        if (payload_len != skb->len - sizeof(*dg)) {
1759                err = -EINVAL;
1760                goto out;
1761        }
1762
1763        if (payload_len > len) {
1764                payload_len = len;
1765                msg->msg_flags |= MSG_TRUNC;
1766        }
1767
1768        /* Place the datagram payload in the user's iovec. */
1769        err = skb_copy_datagram_msg(skb, sizeof(*dg), msg, payload_len);
1770        if (err)
1771                goto out;
1772
1773        if (msg->msg_name) {
1774                /* Provide the address of the sender. */
1775                DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name);
1776                vsock_addr_init(vm_addr, dg->src.context, dg->src.resource);
1777                msg->msg_namelen = sizeof(*vm_addr);
1778        }
1779        err = payload_len;
1780
1781out:
1782        skb_free_datagram(&vsk->sk, skb);
1783        return err;
1784}
1785
1786static bool vmci_transport_dgram_allow(u32 cid, u32 port)
1787{
1788        if (cid == VMADDR_CID_HYPERVISOR) {
1789                /* Registrations of PBRPC Servers do not modify VMX/Hypervisor
1790                 * state and are allowed.
1791                 */
1792                return port == VMCI_UNITY_PBRPC_REGISTER;
1793        }
1794
1795        return true;
1796}
1797
1798static int vmci_transport_connect(struct vsock_sock *vsk)
1799{
1800        int err;
1801        bool old_pkt_proto = false;
1802        struct sock *sk = &vsk->sk;
1803
1804        if (vmci_transport_old_proto_override(&old_pkt_proto) &&
1805                old_pkt_proto) {
1806                err = vmci_transport_send_conn_request(sk, vsk->buffer_size);
1807                if (err < 0) {
1808                        sk->sk_state = TCP_CLOSE;
1809                        return err;
1810                }
1811        } else {
1812                int supported_proto_versions =
1813                        vmci_transport_new_proto_supported_versions();
1814                err = vmci_transport_send_conn_request2(sk, vsk->buffer_size,
1815                                supported_proto_versions);
1816                if (err < 0) {
1817                        sk->sk_state = TCP_CLOSE;
1818                        return err;
1819                }
1820
1821                vsk->sent_request = true;
1822        }
1823
1824        return err;
1825}
1826
1827static ssize_t vmci_transport_stream_dequeue(
1828        struct vsock_sock *vsk,
1829        struct msghdr *msg,
1830        size_t len,
1831        int flags)
1832{
1833        if (flags & MSG_PEEK)
1834                return vmci_qpair_peekv(vmci_trans(vsk)->qpair, msg, len, 0);
1835        else
1836                return vmci_qpair_dequev(vmci_trans(vsk)->qpair, msg, len, 0);
1837}
1838
1839static ssize_t vmci_transport_stream_enqueue(
1840        struct vsock_sock *vsk,
1841        struct msghdr *msg,
1842        size_t len)
1843{
1844        return vmci_qpair_enquev(vmci_trans(vsk)->qpair, msg, len, 0);
1845}
1846
1847static s64 vmci_transport_stream_has_data(struct vsock_sock *vsk)
1848{
1849        return vmci_qpair_consume_buf_ready(vmci_trans(vsk)->qpair);
1850}
1851
1852static s64 vmci_transport_stream_has_space(struct vsock_sock *vsk)
1853{
1854        return vmci_qpair_produce_free_space(vmci_trans(vsk)->qpair);
1855}
1856
1857static u64 vmci_transport_stream_rcvhiwat(struct vsock_sock *vsk)
1858{
1859        return vmci_trans(vsk)->consume_size;
1860}
1861
1862static bool vmci_transport_stream_is_active(struct vsock_sock *vsk)
1863{
1864        return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle);
1865}
1866
1867static int vmci_transport_notify_poll_in(
1868        struct vsock_sock *vsk,
1869        size_t target,
1870        bool *data_ready_now)
1871{
1872        return vmci_trans(vsk)->notify_ops->poll_in(
1873                        &vsk->sk, target, data_ready_now);
1874}
1875
1876static int vmci_transport_notify_poll_out(
1877        struct vsock_sock *vsk,
1878        size_t target,
1879        bool *space_available_now)
1880{
1881        return vmci_trans(vsk)->notify_ops->poll_out(
1882                        &vsk->sk, target, space_available_now);
1883}
1884
1885static int vmci_transport_notify_recv_init(
1886        struct vsock_sock *vsk,
1887        size_t target,
1888        struct vsock_transport_recv_notify_data *data)
1889{
1890        return vmci_trans(vsk)->notify_ops->recv_init(
1891                        &vsk->sk, target,
1892                        (struct vmci_transport_recv_notify_data *)data);
1893}
1894
1895static int vmci_transport_notify_recv_pre_block(
1896        struct vsock_sock *vsk,
1897        size_t target,
1898        struct vsock_transport_recv_notify_data *data)
1899{
1900        return vmci_trans(vsk)->notify_ops->recv_pre_block(
1901                        &vsk->sk, target,
1902                        (struct vmci_transport_recv_notify_data *)data);
1903}
1904
1905static int vmci_transport_notify_recv_pre_dequeue(
1906        struct vsock_sock *vsk,
1907        size_t target,
1908        struct vsock_transport_recv_notify_data *data)
1909{
1910        return vmci_trans(vsk)->notify_ops->recv_pre_dequeue(
1911                        &vsk->sk, target,
1912                        (struct vmci_transport_recv_notify_data *)data);
1913}
1914
1915static int vmci_transport_notify_recv_post_dequeue(
1916        struct vsock_sock *vsk,
1917        size_t target,
1918        ssize_t copied,
1919        bool data_read,
1920        struct vsock_transport_recv_notify_data *data)
1921{
1922        return vmci_trans(vsk)->notify_ops->recv_post_dequeue(
1923                        &vsk->sk, target, copied, data_read,
1924                        (struct vmci_transport_recv_notify_data *)data);
1925}
1926
1927static int vmci_transport_notify_send_init(
1928        struct vsock_sock *vsk,
1929        struct vsock_transport_send_notify_data *data)
1930{
1931        return vmci_trans(vsk)->notify_ops->send_init(
1932                        &vsk->sk,
1933                        (struct vmci_transport_send_notify_data *)data);
1934}
1935
1936static int vmci_transport_notify_send_pre_block(
1937        struct vsock_sock *vsk,
1938        struct vsock_transport_send_notify_data *data)
1939{
1940        return vmci_trans(vsk)->notify_ops->send_pre_block(
1941                        &vsk->sk,
1942                        (struct vmci_transport_send_notify_data *)data);
1943}
1944
1945static int vmci_transport_notify_send_pre_enqueue(
1946        struct vsock_sock *vsk,
1947        struct vsock_transport_send_notify_data *data)
1948{
1949        return vmci_trans(vsk)->notify_ops->send_pre_enqueue(
1950                        &vsk->sk,
1951                        (struct vmci_transport_send_notify_data *)data);
1952}
1953
1954static int vmci_transport_notify_send_post_enqueue(
1955        struct vsock_sock *vsk,
1956        ssize_t written,
1957        struct vsock_transport_send_notify_data *data)
1958{
1959        return vmci_trans(vsk)->notify_ops->send_post_enqueue(
1960                        &vsk->sk, written,
1961                        (struct vmci_transport_send_notify_data *)data);
1962}
1963
1964static bool vmci_transport_old_proto_override(bool *old_pkt_proto)
1965{
1966        if (PROTOCOL_OVERRIDE != -1) {
1967                if (PROTOCOL_OVERRIDE == 0)
1968                        *old_pkt_proto = true;
1969                else
1970                        *old_pkt_proto = false;
1971
1972                pr_info("Proto override in use\n");
1973                return true;
1974        }
1975
1976        return false;
1977}
1978
1979static bool vmci_transport_proto_to_notify_struct(struct sock *sk,
1980                                                  u16 *proto,
1981                                                  bool old_pkt_proto)
1982{
1983        struct vsock_sock *vsk = vsock_sk(sk);
1984
1985        if (old_pkt_proto) {
1986                if (*proto != VSOCK_PROTO_INVALID) {
1987                        pr_err("Can't set both an old and new protocol\n");
1988                        return false;
1989                }
1990                vmci_trans(vsk)->notify_ops = &vmci_transport_notify_pkt_ops;
1991                goto exit;
1992        }
1993
1994        switch (*proto) {
1995        case VSOCK_PROTO_PKT_ON_NOTIFY:
1996                vmci_trans(vsk)->notify_ops =
1997                        &vmci_transport_notify_pkt_q_state_ops;
1998                break;
1999        default:
2000                pr_err("Unknown notify protocol version\n");
2001                return false;
2002        }
2003
2004exit:
2005        vmci_trans(vsk)->notify_ops->socket_init(sk);
2006        return true;
2007}
2008
2009static u16 vmci_transport_new_proto_supported_versions(void)
2010{
2011        if (PROTOCOL_OVERRIDE != -1)
2012                return PROTOCOL_OVERRIDE;
2013
2014        return VSOCK_PROTO_ALL_SUPPORTED;
2015}
2016
2017static u32 vmci_transport_get_local_cid(void)
2018{
2019        return vmci_get_context_id();
2020}
2021
2022static struct vsock_transport vmci_transport = {
2023        .module = THIS_MODULE,
2024        .init = vmci_transport_socket_init,
2025        .destruct = vmci_transport_destruct,
2026        .release = vmci_transport_release,
2027        .connect = vmci_transport_connect,
2028        .dgram_bind = vmci_transport_dgram_bind,
2029        .dgram_dequeue = vmci_transport_dgram_dequeue,
2030        .dgram_enqueue = vmci_transport_dgram_enqueue,
2031        .dgram_allow = vmci_transport_dgram_allow,
2032        .stream_dequeue = vmci_transport_stream_dequeue,
2033        .stream_enqueue = vmci_transport_stream_enqueue,
2034        .stream_has_data = vmci_transport_stream_has_data,
2035        .stream_has_space = vmci_transport_stream_has_space,
2036        .stream_rcvhiwat = vmci_transport_stream_rcvhiwat,
2037        .stream_is_active = vmci_transport_stream_is_active,
2038        .stream_allow = vmci_transport_stream_allow,
2039        .notify_poll_in = vmci_transport_notify_poll_in,
2040        .notify_poll_out = vmci_transport_notify_poll_out,
2041        .notify_recv_init = vmci_transport_notify_recv_init,
2042        .notify_recv_pre_block = vmci_transport_notify_recv_pre_block,
2043        .notify_recv_pre_dequeue = vmci_transport_notify_recv_pre_dequeue,
2044        .notify_recv_post_dequeue = vmci_transport_notify_recv_post_dequeue,
2045        .notify_send_init = vmci_transport_notify_send_init,
2046        .notify_send_pre_block = vmci_transport_notify_send_pre_block,
2047        .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue,
2048        .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue,
2049        .shutdown = vmci_transport_shutdown,
2050        .get_local_cid = vmci_transport_get_local_cid,
2051};
2052
2053static bool vmci_check_transport(struct vsock_sock *vsk)
2054{
2055        return vsk->transport == &vmci_transport;
2056}
2057
2058static void vmci_vsock_transport_cb(bool is_host)
2059{
2060        int features;
2061
2062        if (is_host)
2063                features = VSOCK_TRANSPORT_F_H2G;
2064        else
2065                features = VSOCK_TRANSPORT_F_G2H;
2066
2067        vsock_core_register(&vmci_transport, features);
2068}
2069
2070static int __init vmci_transport_init(void)
2071{
2072        int err;
2073
2074        /* Create the datagram handle that we will use to send and receive all
2075         * VSocket control messages for this context.
2076         */
2077        err = vmci_transport_datagram_create_hnd(VMCI_TRANSPORT_PACKET_RID,
2078                                                 VMCI_FLAG_ANYCID_DG_HND,
2079                                                 vmci_transport_recv_stream_cb,
2080                                                 NULL,
2081                                                 &vmci_transport_stream_handle);
2082        if (err < VMCI_SUCCESS) {
2083                pr_err("Unable to create datagram handle. (%d)\n", err);
2084                return vmci_transport_error_to_vsock_error(err);
2085        }
2086        err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED,
2087                                   vmci_transport_qp_resumed_cb,
2088                                   NULL, &vmci_transport_qp_resumed_sub_id);
2089        if (err < VMCI_SUCCESS) {
2090                pr_err("Unable to subscribe to resumed event. (%d)\n", err);
2091                err = vmci_transport_error_to_vsock_error(err);
2092                vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
2093                goto err_destroy_stream_handle;
2094        }
2095
2096        /* Register only with dgram feature, other features (H2G, G2H) will be
2097         * registered when the first host or guest becomes active.
2098         */
2099        err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM);
2100        if (err < 0)
2101                goto err_unsubscribe;
2102
2103        err = vmci_register_vsock_callback(vmci_vsock_transport_cb);
2104        if (err < 0)
2105                goto err_unregister;
2106
2107        return 0;
2108
2109err_unregister:
2110        vsock_core_unregister(&vmci_transport);
2111err_unsubscribe:
2112        vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
2113err_destroy_stream_handle:
2114        vmci_datagram_destroy_handle(vmci_transport_stream_handle);
2115        return err;
2116}
2117module_init(vmci_transport_init);
2118
2119static void __exit vmci_transport_exit(void)
2120{
2121        cancel_work_sync(&vmci_transport_cleanup_work);
2122        vmci_transport_free_resources(&vmci_transport_cleanup_list);
2123
2124        if (!vmci_handle_is_invalid(vmci_transport_stream_handle)) {
2125                if (vmci_datagram_destroy_handle(
2126                        vmci_transport_stream_handle) != VMCI_SUCCESS)
2127                        pr_err("Couldn't destroy datagram handle\n");
2128                vmci_transport_stream_handle = VMCI_INVALID_HANDLE;
2129        }
2130
2131        if (vmci_transport_qp_resumed_sub_id != VMCI_INVALID_ID) {
2132                vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
2133                vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
2134        }
2135
2136        vmci_register_vsock_callback(NULL);
2137        vsock_core_unregister(&vmci_transport);
2138}
2139module_exit(vmci_transport_exit);
2140
2141MODULE_AUTHOR("VMware, Inc.");
2142MODULE_DESCRIPTION("VMCI transport for Virtual Sockets");
2143MODULE_VERSION("1.0.5.0-k");
2144MODULE_LICENSE("GPL v2");
2145MODULE_ALIAS("vmware_vsock");
2146MODULE_ALIAS_NETPROTO(PF_VSOCK);
2147