linux/net/vmw_vsock/vmci_transport.c
<<
>>
Prefs
   1/*
   2 * VMware vSockets Driver
   3 *
   4 * Copyright (C) 2007-2013 VMware, Inc. All rights reserved.
   5 *
   6 * This program is free software; you can redistribute it and/or modify it
   7 * under the terms of the GNU General Public License as published by the Free
   8 * Software Foundation version 2 and no later version.
   9 *
  10 * This program is distributed in the hope that it will be useful, but WITHOUT
  11 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
  12 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
  13 * more details.
  14 */
  15
  16#include <linux/types.h>
  17#include <linux/bitops.h>
  18#include <linux/cred.h>
  19#include <linux/init.h>
  20#include <linux/io.h>
  21#include <linux/kernel.h>
  22#include <linux/kmod.h>
  23#include <linux/list.h>
  24#include <linux/miscdevice.h>
  25#include <linux/module.h>
  26#include <linux/mutex.h>
  27#include <linux/net.h>
  28#include <linux/poll.h>
  29#include <linux/skbuff.h>
  30#include <linux/smp.h>
  31#include <linux/socket.h>
  32#include <linux/stddef.h>
  33#include <linux/unistd.h>
  34#include <linux/wait.h>
  35#include <linux/workqueue.h>
  36#include <net/sock.h>
  37#include <net/af_vsock.h>
  38
  39#include "vmci_transport_notify.h"
  40
  41static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg);
  42static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg);
  43static void vmci_transport_peer_detach_cb(u32 sub_id,
  44                                          const struct vmci_event_data *ed,
  45                                          void *client_data);
  46static void vmci_transport_recv_pkt_work(struct work_struct *work);
  47static void vmci_transport_cleanup(struct work_struct *work);
  48static int vmci_transport_recv_listen(struct sock *sk,
  49                                      struct vmci_transport_packet *pkt);
  50static int vmci_transport_recv_connecting_server(
  51                                        struct sock *sk,
  52                                        struct sock *pending,
  53                                        struct vmci_transport_packet *pkt);
  54static int vmci_transport_recv_connecting_client(
  55                                        struct sock *sk,
  56                                        struct vmci_transport_packet *pkt);
  57static int vmci_transport_recv_connecting_client_negotiate(
  58                                        struct sock *sk,
  59                                        struct vmci_transport_packet *pkt);
  60static int vmci_transport_recv_connecting_client_invalid(
  61                                        struct sock *sk,
  62                                        struct vmci_transport_packet *pkt);
  63static int vmci_transport_recv_connected(struct sock *sk,
  64                                         struct vmci_transport_packet *pkt);
  65static bool vmci_transport_old_proto_override(bool *old_pkt_proto);
  66static u16 vmci_transport_new_proto_supported_versions(void);
  67static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
  68                                                  bool old_pkt_proto);
  69
  70struct vmci_transport_recv_pkt_info {
  71        struct work_struct work;
  72        struct sock *sk;
  73        struct vmci_transport_packet pkt;
  74};
  75
  76static LIST_HEAD(vmci_transport_cleanup_list);
  77static DEFINE_SPINLOCK(vmci_transport_cleanup_lock);
  78static DECLARE_WORK(vmci_transport_cleanup_work, vmci_transport_cleanup);
  79
  80static struct vmci_handle vmci_transport_stream_handle = { VMCI_INVALID_ID,
  81                                                           VMCI_INVALID_ID };
  82static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
  83
  84static int PROTOCOL_OVERRIDE = -1;
  85
  86#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN   128
  87#define VMCI_TRANSPORT_DEFAULT_QP_SIZE       262144
  88#define VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX   262144
  89
  90/* The default peer timeout indicates how long we will wait for a peer response
  91 * to a control message.
  92 */
  93#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ)
  94
  95/* Helper function to convert from a VMCI error code to a VSock error code. */
  96
  97static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
  98{
  99        int err;
 100
 101        switch (vmci_error) {
 102        case VMCI_ERROR_NO_MEM:
 103                err = ENOMEM;
 104                break;
 105        case VMCI_ERROR_DUPLICATE_ENTRY:
 106        case VMCI_ERROR_ALREADY_EXISTS:
 107                err = EADDRINUSE;
 108                break;
 109        case VMCI_ERROR_NO_ACCESS:
 110                err = EPERM;
 111                break;
 112        case VMCI_ERROR_NO_RESOURCES:
 113                err = ENOBUFS;
 114                break;
 115        case VMCI_ERROR_INVALID_RESOURCE:
 116                err = EHOSTUNREACH;
 117                break;
 118        case VMCI_ERROR_INVALID_ARGS:
 119        default:
 120                err = EINVAL;
 121        }
 122
 123        return err > 0 ? -err : err;
 124}
 125
 126static u32 vmci_transport_peer_rid(u32 peer_cid)
 127{
 128        if (VMADDR_CID_HYPERVISOR == peer_cid)
 129                return VMCI_TRANSPORT_HYPERVISOR_PACKET_RID;
 130
 131        return VMCI_TRANSPORT_PACKET_RID;
 132}
 133
 134static inline void
 135vmci_transport_packet_init(struct vmci_transport_packet *pkt,
 136                           struct sockaddr_vm *src,
 137                           struct sockaddr_vm *dst,
 138                           u8 type,
 139                           u64 size,
 140                           u64 mode,
 141                           struct vmci_transport_waiting_info *wait,
 142                           u16 proto,
 143                           struct vmci_handle handle)
 144{
 145        /* We register the stream control handler as an any cid handle so we
 146         * must always send from a source address of VMADDR_CID_ANY
 147         */
 148        pkt->dg.src = vmci_make_handle(VMADDR_CID_ANY,
 149                                       VMCI_TRANSPORT_PACKET_RID);
 150        pkt->dg.dst = vmci_make_handle(dst->svm_cid,
 151                                       vmci_transport_peer_rid(dst->svm_cid));
 152        pkt->dg.payload_size = sizeof(*pkt) - sizeof(pkt->dg);
 153        pkt->version = VMCI_TRANSPORT_PACKET_VERSION;
 154        pkt->type = type;
 155        pkt->src_port = src->svm_port;
 156        pkt->dst_port = dst->svm_port;
 157        memset(&pkt->proto, 0, sizeof(pkt->proto));
 158        memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2));
 159
 160        switch (pkt->type) {
 161        case VMCI_TRANSPORT_PACKET_TYPE_INVALID:
 162                pkt->u.size = 0;
 163                break;
 164
 165        case VMCI_TRANSPORT_PACKET_TYPE_REQUEST:
 166        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
 167                pkt->u.size = size;
 168                break;
 169
 170        case VMCI_TRANSPORT_PACKET_TYPE_OFFER:
 171        case VMCI_TRANSPORT_PACKET_TYPE_ATTACH:
 172                pkt->u.handle = handle;
 173                break;
 174
 175        case VMCI_TRANSPORT_PACKET_TYPE_WROTE:
 176        case VMCI_TRANSPORT_PACKET_TYPE_READ:
 177        case VMCI_TRANSPORT_PACKET_TYPE_RST:
 178                pkt->u.size = 0;
 179                break;
 180
 181        case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN:
 182                pkt->u.mode = mode;
 183                break;
 184
 185        case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ:
 186        case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE:
 187                memcpy(&pkt->u.wait, wait, sizeof(pkt->u.wait));
 188                break;
 189
 190        case VMCI_TRANSPORT_PACKET_TYPE_REQUEST2:
 191        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
 192                pkt->u.size = size;
 193                pkt->proto = proto;
 194                break;
 195        }
 196}
 197
 198static inline void
 199vmci_transport_packet_get_addresses(struct vmci_transport_packet *pkt,
 200                                    struct sockaddr_vm *local,
 201                                    struct sockaddr_vm *remote)
 202{
 203        vsock_addr_init(local, pkt->dg.dst.context, pkt->dst_port);
 204        vsock_addr_init(remote, pkt->dg.src.context, pkt->src_port);
 205}
 206
 207static int
 208__vmci_transport_send_control_pkt(struct vmci_transport_packet *pkt,
 209                                  struct sockaddr_vm *src,
 210                                  struct sockaddr_vm *dst,
 211                                  enum vmci_transport_packet_type type,
 212                                  u64 size,
 213                                  u64 mode,
 214                                  struct vmci_transport_waiting_info *wait,
 215                                  u16 proto,
 216                                  struct vmci_handle handle,
 217                                  bool convert_error)
 218{
 219        int err;
 220
 221        vmci_transport_packet_init(pkt, src, dst, type, size, mode, wait,
 222                                   proto, handle);
 223        err = vmci_datagram_send(&pkt->dg);
 224        if (convert_error && (err < 0))
 225                return vmci_transport_error_to_vsock_error(err);
 226
 227        return err;
 228}
 229
 230static int
 231vmci_transport_reply_control_pkt_fast(struct vmci_transport_packet *pkt,
 232                                      enum vmci_transport_packet_type type,
 233                                      u64 size,
 234                                      u64 mode,
 235                                      struct vmci_transport_waiting_info *wait,
 236                                      struct vmci_handle handle)
 237{
 238        struct vmci_transport_packet reply;
 239        struct sockaddr_vm src, dst;
 240
 241        if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) {
 242                return 0;
 243        } else {
 244                vmci_transport_packet_get_addresses(pkt, &src, &dst);
 245                return __vmci_transport_send_control_pkt(&reply, &src, &dst,
 246                                                         type,
 247                                                         size, mode, wait,
 248                                                         VSOCK_PROTO_INVALID,
 249                                                         handle, true);
 250        }
 251}
 252
 253static int
 254vmci_transport_send_control_pkt_bh(struct sockaddr_vm *src,
 255                                   struct sockaddr_vm *dst,
 256                                   enum vmci_transport_packet_type type,
 257                                   u64 size,
 258                                   u64 mode,
 259                                   struct vmci_transport_waiting_info *wait,
 260                                   struct vmci_handle handle)
 261{
 262        /* Note that it is safe to use a single packet across all CPUs since
 263         * two tasklets of the same type are guaranteed to not ever run
 264         * simultaneously. If that ever changes, or VMCI stops using tasklets,
 265         * we can use per-cpu packets.
 266         */
 267        static struct vmci_transport_packet pkt;
 268
 269        return __vmci_transport_send_control_pkt(&pkt, src, dst, type,
 270                                                 size, mode, wait,
 271                                                 VSOCK_PROTO_INVALID, handle,
 272                                                 false);
 273}
 274
 275static int
 276vmci_transport_send_control_pkt(struct sock *sk,
 277                                enum vmci_transport_packet_type type,
 278                                u64 size,
 279                                u64 mode,
 280                                struct vmci_transport_waiting_info *wait,
 281                                u16 proto,
 282                                struct vmci_handle handle)
 283{
 284        struct vmci_transport_packet *pkt;
 285        struct vsock_sock *vsk;
 286        int err;
 287
 288        vsk = vsock_sk(sk);
 289
 290        if (!vsock_addr_bound(&vsk->local_addr))
 291                return -EINVAL;
 292
 293        if (!vsock_addr_bound(&vsk->remote_addr))
 294                return -EINVAL;
 295
 296        pkt = kmalloc(sizeof(*pkt), GFP_KERNEL);
 297        if (!pkt)
 298                return -ENOMEM;
 299
 300        err = __vmci_transport_send_control_pkt(pkt, &vsk->local_addr,
 301                                                &vsk->remote_addr, type, size,
 302                                                mode, wait, proto, handle,
 303                                                true);
 304        kfree(pkt);
 305
 306        return err;
 307}
 308
 309static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst,
 310                                        struct sockaddr_vm *src,
 311                                        struct vmci_transport_packet *pkt)
 312{
 313        if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
 314                return 0;
 315        return vmci_transport_send_control_pkt_bh(
 316                                        dst, src,
 317                                        VMCI_TRANSPORT_PACKET_TYPE_RST, 0,
 318                                        0, NULL, VMCI_INVALID_HANDLE);
 319}
 320
 321static int vmci_transport_send_reset(struct sock *sk,
 322                                     struct vmci_transport_packet *pkt)
 323{
 324        if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST)
 325                return 0;
 326        return vmci_transport_send_control_pkt(sk,
 327                                        VMCI_TRANSPORT_PACKET_TYPE_RST,
 328                                        0, 0, NULL, VSOCK_PROTO_INVALID,
 329                                        VMCI_INVALID_HANDLE);
 330}
 331
 332static int vmci_transport_send_negotiate(struct sock *sk, size_t size)
 333{
 334        return vmci_transport_send_control_pkt(
 335                                        sk,
 336                                        VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE,
 337                                        size, 0, NULL,
 338                                        VSOCK_PROTO_INVALID,
 339                                        VMCI_INVALID_HANDLE);
 340}
 341
 342static int vmci_transport_send_negotiate2(struct sock *sk, size_t size,
 343                                          u16 version)
 344{
 345        return vmci_transport_send_control_pkt(
 346                                        sk,
 347                                        VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2,
 348                                        size, 0, NULL, version,
 349                                        VMCI_INVALID_HANDLE);
 350}
 351
 352static int vmci_transport_send_qp_offer(struct sock *sk,
 353                                        struct vmci_handle handle)
 354{
 355        return vmci_transport_send_control_pkt(
 356                                        sk, VMCI_TRANSPORT_PACKET_TYPE_OFFER, 0,
 357                                        0, NULL,
 358                                        VSOCK_PROTO_INVALID, handle);
 359}
 360
 361static int vmci_transport_send_attach(struct sock *sk,
 362                                      struct vmci_handle handle)
 363{
 364        return vmci_transport_send_control_pkt(
 365                                        sk, VMCI_TRANSPORT_PACKET_TYPE_ATTACH,
 366                                        0, 0, NULL, VSOCK_PROTO_INVALID,
 367                                        handle);
 368}
 369
 370static int vmci_transport_reply_reset(struct vmci_transport_packet *pkt)
 371{
 372        return vmci_transport_reply_control_pkt_fast(
 373                                                pkt,
 374                                                VMCI_TRANSPORT_PACKET_TYPE_RST,
 375                                                0, 0, NULL,
 376                                                VMCI_INVALID_HANDLE);
 377}
 378
 379static int vmci_transport_send_invalid_bh(struct sockaddr_vm *dst,
 380                                          struct sockaddr_vm *src)
 381{
 382        return vmci_transport_send_control_pkt_bh(
 383                                        dst, src,
 384                                        VMCI_TRANSPORT_PACKET_TYPE_INVALID,
 385                                        0, 0, NULL, VMCI_INVALID_HANDLE);
 386}
 387
 388int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst,
 389                                 struct sockaddr_vm *src)
 390{
 391        return vmci_transport_send_control_pkt_bh(
 392                                        dst, src,
 393                                        VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0,
 394                                        0, NULL, VMCI_INVALID_HANDLE);
 395}
 396
 397int vmci_transport_send_read_bh(struct sockaddr_vm *dst,
 398                                struct sockaddr_vm *src)
 399{
 400        return vmci_transport_send_control_pkt_bh(
 401                                        dst, src,
 402                                        VMCI_TRANSPORT_PACKET_TYPE_READ, 0,
 403                                        0, NULL, VMCI_INVALID_HANDLE);
 404}
 405
 406int vmci_transport_send_wrote(struct sock *sk)
 407{
 408        return vmci_transport_send_control_pkt(
 409                                        sk, VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0,
 410                                        0, NULL, VSOCK_PROTO_INVALID,
 411                                        VMCI_INVALID_HANDLE);
 412}
 413
 414int vmci_transport_send_read(struct sock *sk)
 415{
 416        return vmci_transport_send_control_pkt(
 417                                        sk, VMCI_TRANSPORT_PACKET_TYPE_READ, 0,
 418                                        0, NULL, VSOCK_PROTO_INVALID,
 419                                        VMCI_INVALID_HANDLE);
 420}
 421
 422int vmci_transport_send_waiting_write(struct sock *sk,
 423                                      struct vmci_transport_waiting_info *wait)
 424{
 425        return vmci_transport_send_control_pkt(
 426                                sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE,
 427                                0, 0, wait, VSOCK_PROTO_INVALID,
 428                                VMCI_INVALID_HANDLE);
 429}
 430
 431int vmci_transport_send_waiting_read(struct sock *sk,
 432                                     struct vmci_transport_waiting_info *wait)
 433{
 434        return vmci_transport_send_control_pkt(
 435                                sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ,
 436                                0, 0, wait, VSOCK_PROTO_INVALID,
 437                                VMCI_INVALID_HANDLE);
 438}
 439
 440static int vmci_transport_shutdown(struct vsock_sock *vsk, int mode)
 441{
 442        return vmci_transport_send_control_pkt(
 443                                        &vsk->sk,
 444                                        VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN,
 445                                        0, mode, NULL,
 446                                        VSOCK_PROTO_INVALID,
 447                                        VMCI_INVALID_HANDLE);
 448}
 449
 450static int vmci_transport_send_conn_request(struct sock *sk, size_t size)
 451{
 452        return vmci_transport_send_control_pkt(sk,
 453                                        VMCI_TRANSPORT_PACKET_TYPE_REQUEST,
 454                                        size, 0, NULL,
 455                                        VSOCK_PROTO_INVALID,
 456                                        VMCI_INVALID_HANDLE);
 457}
 458
 459static int vmci_transport_send_conn_request2(struct sock *sk, size_t size,
 460                                             u16 version)
 461{
 462        return vmci_transport_send_control_pkt(
 463                                        sk, VMCI_TRANSPORT_PACKET_TYPE_REQUEST2,
 464                                        size, 0, NULL, version,
 465                                        VMCI_INVALID_HANDLE);
 466}
 467
 468static struct sock *vmci_transport_get_pending(
 469                                        struct sock *listener,
 470                                        struct vmci_transport_packet *pkt)
 471{
 472        struct vsock_sock *vlistener;
 473        struct vsock_sock *vpending;
 474        struct sock *pending;
 475        struct sockaddr_vm src;
 476
 477        vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 478
 479        vlistener = vsock_sk(listener);
 480
 481        list_for_each_entry(vpending, &vlistener->pending_links,
 482                            pending_links) {
 483                if (vsock_addr_equals_addr(&src, &vpending->remote_addr) &&
 484                    pkt->dst_port == vpending->local_addr.svm_port) {
 485                        pending = sk_vsock(vpending);
 486                        sock_hold(pending);
 487                        goto found;
 488                }
 489        }
 490
 491        pending = NULL;
 492found:
 493        return pending;
 494
 495}
 496
 497static void vmci_transport_release_pending(struct sock *pending)
 498{
 499        sock_put(pending);
 500}
 501
 502/* We allow two kinds of sockets to communicate with a restricted VM: 1)
 503 * trusted sockets 2) sockets from applications running as the same user as the
 504 * VM (this is only true for the host side and only when using hosted products)
 505 */
 506
 507static bool vmci_transport_is_trusted(struct vsock_sock *vsock, u32 peer_cid)
 508{
 509        return vsock->trusted ||
 510               vmci_is_context_owner(peer_cid, vsock->owner->uid);
 511}
 512
 513/* We allow sending datagrams to and receiving datagrams from a restricted VM
 514 * only if it is trusted as described in vmci_transport_is_trusted.
 515 */
 516
 517static bool vmci_transport_allow_dgram(struct vsock_sock *vsock, u32 peer_cid)
 518{
 519        if (VMADDR_CID_HYPERVISOR == peer_cid)
 520                return true;
 521
 522        if (vsock->cached_peer != peer_cid) {
 523                vsock->cached_peer = peer_cid;
 524                if (!vmci_transport_is_trusted(vsock, peer_cid) &&
 525                    (vmci_context_get_priv_flags(peer_cid) &
 526                     VMCI_PRIVILEGE_FLAG_RESTRICTED)) {
 527                        vsock->cached_peer_allow_dgram = false;
 528                } else {
 529                        vsock->cached_peer_allow_dgram = true;
 530                }
 531        }
 532
 533        return vsock->cached_peer_allow_dgram;
 534}
 535
 536static int
 537vmci_transport_queue_pair_alloc(struct vmci_qp **qpair,
 538                                struct vmci_handle *handle,
 539                                u64 produce_size,
 540                                u64 consume_size,
 541                                u32 peer, u32 flags, bool trusted)
 542{
 543        int err = 0;
 544
 545        if (trusted) {
 546                /* Try to allocate our queue pair as trusted. This will only
 547                 * work if vsock is running in the host.
 548                 */
 549
 550                err = vmci_qpair_alloc(qpair, handle, produce_size,
 551                                       consume_size,
 552                                       peer, flags,
 553                                       VMCI_PRIVILEGE_FLAG_TRUSTED);
 554                if (err != VMCI_ERROR_NO_ACCESS)
 555                        goto out;
 556
 557        }
 558
 559        err = vmci_qpair_alloc(qpair, handle, produce_size, consume_size,
 560                               peer, flags, VMCI_NO_PRIVILEGE_FLAGS);
 561out:
 562        if (err < 0) {
 563                pr_err("Could not attach to queue pair with %d\n",
 564                       err);
 565                err = vmci_transport_error_to_vsock_error(err);
 566        }
 567
 568        return err;
 569}
 570
 571static int
 572vmci_transport_datagram_create_hnd(u32 resource_id,
 573                                   u32 flags,
 574                                   vmci_datagram_recv_cb recv_cb,
 575                                   void *client_data,
 576                                   struct vmci_handle *out_handle)
 577{
 578        int err = 0;
 579
 580        /* Try to allocate our datagram handler as trusted. This will only work
 581         * if vsock is running in the host.
 582         */
 583
 584        err = vmci_datagram_create_handle_priv(resource_id, flags,
 585                                               VMCI_PRIVILEGE_FLAG_TRUSTED,
 586                                               recv_cb,
 587                                               client_data, out_handle);
 588
 589        if (err == VMCI_ERROR_NO_ACCESS)
 590                err = vmci_datagram_create_handle(resource_id, flags,
 591                                                  recv_cb, client_data,
 592                                                  out_handle);
 593
 594        return err;
 595}
 596
 597/* This is invoked as part of a tasklet that's scheduled when the VMCI
 598 * interrupt fires.  This is run in bottom-half context and if it ever needs to
 599 * sleep it should defer that work to a work queue.
 600 */
 601
 602static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg)
 603{
 604        struct sock *sk;
 605        size_t size;
 606        struct sk_buff *skb;
 607        struct vsock_sock *vsk;
 608
 609        sk = (struct sock *)data;
 610
 611        /* This handler is privileged when this module is running on the host.
 612         * We will get datagrams from all endpoints (even VMs that are in a
 613         * restricted context). If we get one from a restricted context then
 614         * the destination socket must be trusted.
 615         *
 616         * NOTE: We access the socket struct without holding the lock here.
 617         * This is ok because the field we are interested is never modified
 618         * outside of the create and destruct socket functions.
 619         */
 620        vsk = vsock_sk(sk);
 621        if (!vmci_transport_allow_dgram(vsk, dg->src.context))
 622                return VMCI_ERROR_NO_ACCESS;
 623
 624        size = VMCI_DG_SIZE(dg);
 625
 626        /* Attach the packet to the socket's receive queue as an sk_buff. */
 627        skb = alloc_skb(size, GFP_ATOMIC);
 628        if (!skb)
 629                return VMCI_ERROR_NO_MEM;
 630
 631        /* sk_receive_skb() will do a sock_put(), so hold here. */
 632        sock_hold(sk);
 633        skb_put(skb, size);
 634        memcpy(skb->data, dg, size);
 635        sk_receive_skb(sk, skb, 0);
 636
 637        return VMCI_SUCCESS;
 638}
 639
 640static bool vmci_transport_stream_allow(u32 cid, u32 port)
 641{
 642        static const u32 non_socket_contexts[] = {
 643                VMADDR_CID_RESERVED,
 644        };
 645        int i;
 646
 647        BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts));
 648
 649        for (i = 0; i < ARRAY_SIZE(non_socket_contexts); i++) {
 650                if (cid == non_socket_contexts[i])
 651                        return false;
 652        }
 653
 654        return true;
 655}
 656
 657/* This is invoked as part of a tasklet that's scheduled when the VMCI
 658 * interrupt fires.  This is run in bottom-half context but it defers most of
 659 * its work to the packet handling work queue.
 660 */
 661
 662static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
 663{
 664        struct sock *sk;
 665        struct sockaddr_vm dst;
 666        struct sockaddr_vm src;
 667        struct vmci_transport_packet *pkt;
 668        struct vsock_sock *vsk;
 669        bool bh_process_pkt;
 670        int err;
 671
 672        sk = NULL;
 673        err = VMCI_SUCCESS;
 674        bh_process_pkt = false;
 675
 676        /* Ignore incoming packets from contexts without sockets, or resources
 677         * that aren't vsock implementations.
 678         */
 679
 680        if (!vmci_transport_stream_allow(dg->src.context, -1)
 681            || vmci_transport_peer_rid(dg->src.context) != dg->src.resource)
 682                return VMCI_ERROR_NO_ACCESS;
 683
 684        if (VMCI_DG_SIZE(dg) < sizeof(*pkt))
 685                /* Drop datagrams that do not contain full VSock packets. */
 686                return VMCI_ERROR_INVALID_ARGS;
 687
 688        pkt = (struct vmci_transport_packet *)dg;
 689
 690        /* Find the socket that should handle this packet.  First we look for a
 691         * connected socket and if there is none we look for a socket bound to
 692         * the destintation address.
 693         */
 694        vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
 695        vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);
 696
 697        sk = vsock_find_connected_socket(&src, &dst);
 698        if (!sk) {
 699                sk = vsock_find_bound_socket(&dst);
 700                if (!sk) {
 701                        /* We could not find a socket for this specified
 702                         * address.  If this packet is a RST, we just drop it.
 703                         * If it is another packet, we send a RST.  Note that
 704                         * we do not send a RST reply to RSTs so that we do not
 705                         * continually send RSTs between two endpoints.
 706                         *
 707                         * Note that since this is a reply, dst is src and src
 708                         * is dst.
 709                         */
 710                        if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0)
 711                                pr_err("unable to send reset\n");
 712
 713                        err = VMCI_ERROR_NOT_FOUND;
 714                        goto out;
 715                }
 716        }
 717
 718        /* If the received packet type is beyond all types known to this
 719         * implementation, reply with an invalid message.  Hopefully this will
 720         * help when implementing backwards compatibility in the future.
 721         */
 722        if (pkt->type >= VMCI_TRANSPORT_PACKET_TYPE_MAX) {
 723                vmci_transport_send_invalid_bh(&dst, &src);
 724                err = VMCI_ERROR_INVALID_ARGS;
 725                goto out;
 726        }
 727
 728        /* This handler is privileged when this module is running on the host.
 729         * We will get datagram connect requests from all endpoints (even VMs
 730         * that are in a restricted context). If we get one from a restricted
 731         * context then the destination socket must be trusted.
 732         *
 733         * NOTE: We access the socket struct without holding the lock here.
 734         * This is ok because the field we are interested is never modified
 735         * outside of the create and destruct socket functions.
 736         */
 737        vsk = vsock_sk(sk);
 738        if (!vmci_transport_allow_dgram(vsk, pkt->dg.src.context)) {
 739                err = VMCI_ERROR_NO_ACCESS;
 740                goto out;
 741        }
 742
 743        /* We do most everything in a work queue, but let's fast path the
 744         * notification of reads and writes to help data transfer performance.
 745         * We can only do this if there is no process context code executing
 746         * for this socket since that may change the state.
 747         */
 748        bh_lock_sock(sk);
 749
 750        if (!sock_owned_by_user(sk)) {
 751                /* The local context ID may be out of date, update it. */
 752                vsk->local_addr.svm_cid = dst.svm_cid;
 753
 754                if (sk->sk_state == SS_CONNECTED)
 755                        vmci_trans(vsk)->notify_ops->handle_notify_pkt(
 756                                        sk, pkt, true, &dst, &src,
 757                                        &bh_process_pkt);
 758        }
 759
 760        bh_unlock_sock(sk);
 761
 762        if (!bh_process_pkt) {
 763                struct vmci_transport_recv_pkt_info *recv_pkt_info;
 764
 765                recv_pkt_info = kmalloc(sizeof(*recv_pkt_info), GFP_ATOMIC);
 766                if (!recv_pkt_info) {
 767                        if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0)
 768                                pr_err("unable to send reset\n");
 769
 770                        err = VMCI_ERROR_NO_MEM;
 771                        goto out;
 772                }
 773
 774                recv_pkt_info->sk = sk;
 775                memcpy(&recv_pkt_info->pkt, pkt, sizeof(recv_pkt_info->pkt));
 776                INIT_WORK(&recv_pkt_info->work, vmci_transport_recv_pkt_work);
 777
 778                schedule_work(&recv_pkt_info->work);
 779                /* Clear sk so that the reference count incremented by one of
 780                 * the Find functions above is not decremented below.  We need
 781                 * that reference count for the packet handler we've scheduled
 782                 * to run.
 783                 */
 784                sk = NULL;
 785        }
 786
 787out:
 788        if (sk)
 789                sock_put(sk);
 790
 791        return err;
 792}
 793
 794static void vmci_transport_handle_detach(struct sock *sk)
 795{
 796        struct vsock_sock *vsk;
 797
 798        vsk = vsock_sk(sk);
 799        if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) {
 800                sock_set_flag(sk, SOCK_DONE);
 801
 802                /* On a detach the peer will not be sending or receiving
 803                 * anymore.
 804                 */
 805                vsk->peer_shutdown = SHUTDOWN_MASK;
 806
 807                /* We should not be sending anymore since the peer won't be
 808                 * there to receive, but we can still receive if there is data
 809                 * left in our consume queue.
 810                 */
 811                if (vsock_stream_has_data(vsk) <= 0) {
 812                        if (sk->sk_state == SS_CONNECTING) {
 813                                /* The peer may detach from a queue pair while
 814                                 * we are still in the connecting state, i.e.,
 815                                 * if the peer VM is killed after attaching to
 816                                 * a queue pair, but before we complete the
 817                                 * handshake. In that case, we treat the detach
 818                                 * event like a reset.
 819                                 */
 820
 821                                sk->sk_state = SS_UNCONNECTED;
 822                                sk->sk_err = ECONNRESET;
 823                                sk->sk_error_report(sk);
 824                                return;
 825                        }
 826                        sk->sk_state = SS_UNCONNECTED;
 827                }
 828                sk->sk_state_change(sk);
 829        }
 830}
 831
 832static void vmci_transport_peer_detach_cb(u32 sub_id,
 833                                          const struct vmci_event_data *e_data,
 834                                          void *client_data)
 835{
 836        struct vmci_transport *trans = client_data;
 837        const struct vmci_event_payload_qp *e_payload;
 838
 839        e_payload = vmci_event_data_const_payload(e_data);
 840
 841        /* XXX This is lame, we should provide a way to lookup sockets by
 842         * qp_handle.
 843         */
 844        if (vmci_handle_is_invalid(e_payload->handle) ||
 845            !vmci_handle_is_equal(trans->qp_handle, e_payload->handle))
 846                return;
 847
 848        /* We don't ask for delayed CBs when we subscribe to this event (we
 849         * pass 0 as flags to vmci_event_subscribe()).  VMCI makes no
 850         * guarantees in that case about what context we might be running in,
 851         * so it could be BH or process, blockable or non-blockable.  So we
 852         * need to account for all possible contexts here.
 853         */
 854        spin_lock_bh(&trans->lock);
 855        if (!trans->sk)
 856                goto out;
 857
 858        /* Apart from here, trans->lock is only grabbed as part of sk destruct,
 859         * where trans->sk isn't locked.
 860         */
 861        bh_lock_sock(trans->sk);
 862
 863        vmci_transport_handle_detach(trans->sk);
 864
 865        bh_unlock_sock(trans->sk);
 866 out:
 867        spin_unlock_bh(&trans->lock);
 868}
 869
 870static void vmci_transport_qp_resumed_cb(u32 sub_id,
 871                                         const struct vmci_event_data *e_data,
 872                                         void *client_data)
 873{
 874        vsock_for_each_connected_socket(vmci_transport_handle_detach);
 875}
 876
 877static void vmci_transport_recv_pkt_work(struct work_struct *work)
 878{
 879        struct vmci_transport_recv_pkt_info *recv_pkt_info;
 880        struct vmci_transport_packet *pkt;
 881        struct sock *sk;
 882
 883        recv_pkt_info =
 884                container_of(work, struct vmci_transport_recv_pkt_info, work);
 885        sk = recv_pkt_info->sk;
 886        pkt = &recv_pkt_info->pkt;
 887
 888        lock_sock(sk);
 889
 890        /* The local context ID may be out of date. */
 891        vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context;
 892
 893        switch (sk->sk_state) {
 894        case VSOCK_SS_LISTEN:
 895                vmci_transport_recv_listen(sk, pkt);
 896                break;
 897        case SS_CONNECTING:
 898                /* Processing of pending connections for servers goes through
 899                 * the listening socket, so see vmci_transport_recv_listen()
 900                 * for that path.
 901                 */
 902                vmci_transport_recv_connecting_client(sk, pkt);
 903                break;
 904        case SS_CONNECTED:
 905                vmci_transport_recv_connected(sk, pkt);
 906                break;
 907        default:
 908                /* Because this function does not run in the same context as
 909                 * vmci_transport_recv_stream_cb it is possible that the
 910                 * socket has closed. We need to let the other side know or it
 911                 * could be sitting in a connect and hang forever. Send a
 912                 * reset to prevent that.
 913                 */
 914                vmci_transport_send_reset(sk, pkt);
 915                break;
 916        }
 917
 918        release_sock(sk);
 919        kfree(recv_pkt_info);
 920        /* Release reference obtained in the stream callback when we fetched
 921         * this socket out of the bound or connected list.
 922         */
 923        sock_put(sk);
 924}
 925
 926static int vmci_transport_recv_listen(struct sock *sk,
 927                                      struct vmci_transport_packet *pkt)
 928{
 929        struct sock *pending;
 930        struct vsock_sock *vpending;
 931        int err;
 932        u64 qp_size;
 933        bool old_request = false;
 934        bool old_pkt_proto = false;
 935
 936        err = 0;
 937
 938        /* Because we are in the listen state, we could be receiving a packet
 939         * for ourself or any previous connection requests that we received.
 940         * If it's the latter, we try to find a socket in our list of pending
 941         * connections and, if we do, call the appropriate handler for the
 942         * state that that socket is in.  Otherwise we try to service the
 943         * connection request.
 944         */
 945        pending = vmci_transport_get_pending(sk, pkt);
 946        if (pending) {
 947                lock_sock(pending);
 948
 949                /* The local context ID may be out of date. */
 950                vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context;
 951
 952                switch (pending->sk_state) {
 953                case SS_CONNECTING:
 954                        err = vmci_transport_recv_connecting_server(sk,
 955                                                                    pending,
 956                                                                    pkt);
 957                        break;
 958                default:
 959                        vmci_transport_send_reset(pending, pkt);
 960                        err = -EINVAL;
 961                }
 962
 963                if (err < 0)
 964                        vsock_remove_pending(sk, pending);
 965
 966                release_sock(pending);
 967                vmci_transport_release_pending(pending);
 968
 969                return err;
 970        }
 971
 972        /* The listen state only accepts connection requests.  Reply with a
 973         * reset unless we received a reset.
 974         */
 975
 976        if (!(pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST ||
 977              pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)) {
 978                vmci_transport_reply_reset(pkt);
 979                return -EINVAL;
 980        }
 981
 982        if (pkt->u.size == 0) {
 983                vmci_transport_reply_reset(pkt);
 984                return -EINVAL;
 985        }
 986
 987        /* If this socket can't accommodate this connection request, we send a
 988         * reset.  Otherwise we create and initialize a child socket and reply
 989         * with a connection negotiation.
 990         */
 991        if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) {
 992                vmci_transport_reply_reset(pkt);
 993                return -ECONNREFUSED;
 994        }
 995
 996        pending = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
 997                                 sk->sk_type, 0);
 998        if (!pending) {
 999                vmci_transport_send_reset(sk, pkt);
1000                return -ENOMEM;
1001        }
1002
1003        vpending = vsock_sk(pending);
1004
1005        vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context,
1006                        pkt->dst_port);
1007        vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
1008                        pkt->src_port);
1009
1010        /* If the proposed size fits within our min/max, accept it. Otherwise
1011         * propose our own size.
1012         */
1013        if (pkt->u.size >= vmci_trans(vpending)->queue_pair_min_size &&
1014            pkt->u.size <= vmci_trans(vpending)->queue_pair_max_size) {
1015                qp_size = pkt->u.size;
1016        } else {
1017                qp_size = vmci_trans(vpending)->queue_pair_size;
1018        }
1019
1020        /* Figure out if we are using old or new requests based on the
1021         * overrides pkt types sent by our peer.
1022         */
1023        if (vmci_transport_old_proto_override(&old_pkt_proto)) {
1024                old_request = old_pkt_proto;
1025        } else {
1026                if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST)
1027                        old_request = true;
1028                else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)
1029                        old_request = false;
1030
1031        }
1032
1033        if (old_request) {
1034                /* Handle a REQUEST (or override) */
1035                u16 version = VSOCK_PROTO_INVALID;
1036                if (vmci_transport_proto_to_notify_struct(
1037                        pending, &version, true))
1038                        err = vmci_transport_send_negotiate(pending, qp_size);
1039                else
1040                        err = -EINVAL;
1041
1042        } else {
1043                /* Handle a REQUEST2 (or override) */
1044                int proto_int = pkt->proto;
1045                int pos;
1046                u16 active_proto_version = 0;
1047
1048                /* The list of possible protocols is the intersection of all
1049                 * protocols the client supports ... plus all the protocols we
1050                 * support.
1051                 */
1052                proto_int &= vmci_transport_new_proto_supported_versions();
1053
1054                /* We choose the highest possible protocol version and use that
1055                 * one.
1056                 */
1057                pos = fls(proto_int);
1058                if (pos) {
1059                        active_proto_version = (1 << (pos - 1));
1060                        if (vmci_transport_proto_to_notify_struct(
1061                                pending, &active_proto_version, false))
1062                                err = vmci_transport_send_negotiate2(pending,
1063                                                        qp_size,
1064                                                        active_proto_version);
1065                        else
1066                                err = -EINVAL;
1067
1068                } else {
1069                        err = -EINVAL;
1070                }
1071        }
1072
1073        if (err < 0) {
1074                vmci_transport_send_reset(sk, pkt);
1075                sock_put(pending);
1076                err = vmci_transport_error_to_vsock_error(err);
1077                goto out;
1078        }
1079
1080        vsock_add_pending(sk, pending);
1081        sk->sk_ack_backlog++;
1082
1083        pending->sk_state = SS_CONNECTING;
1084        vmci_trans(vpending)->produce_size =
1085                vmci_trans(vpending)->consume_size = qp_size;
1086        vmci_trans(vpending)->queue_pair_size = qp_size;
1087
1088        vmci_trans(vpending)->notify_ops->process_request(pending);
1089
1090        /* We might never receive another message for this socket and it's not
1091         * connected to any process, so we have to ensure it gets cleaned up
1092         * ourself.  Our delayed work function will take care of that.  Note
1093         * that we do not ever cancel this function since we have few
1094         * guarantees about its state when calling cancel_delayed_work().
1095         * Instead we hold a reference on the socket for that function and make
1096         * it capable of handling cases where it needs to do nothing but
1097         * release that reference.
1098         */
1099        vpending->listener = sk;
1100        sock_hold(sk);
1101        sock_hold(pending);
1102        INIT_DELAYED_WORK(&vpending->dwork, vsock_pending_work);
1103        schedule_delayed_work(&vpending->dwork, HZ);
1104
1105out:
1106        return err;
1107}
1108
1109static int
1110vmci_transport_recv_connecting_server(struct sock *listener,
1111                                      struct sock *pending,
1112                                      struct vmci_transport_packet *pkt)
1113{
1114        struct vsock_sock *vpending;
1115        struct vmci_handle handle;
1116        struct vmci_qp *qpair;
1117        bool is_local;
1118        u32 flags;
1119        u32 detach_sub_id;
1120        int err;
1121        int skerr;
1122
1123        vpending = vsock_sk(pending);
1124        detach_sub_id = VMCI_INVALID_ID;
1125
1126        switch (pkt->type) {
1127        case VMCI_TRANSPORT_PACKET_TYPE_OFFER:
1128                if (vmci_handle_is_invalid(pkt->u.handle)) {
1129                        vmci_transport_send_reset(pending, pkt);
1130                        skerr = EPROTO;
1131                        err = -EINVAL;
1132                        goto destroy;
1133                }
1134                break;
1135        default:
1136                /* Close and cleanup the connection. */
1137                vmci_transport_send_reset(pending, pkt);
1138                skerr = EPROTO;
1139                err = pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST ? 0 : -EINVAL;
1140                goto destroy;
1141        }
1142
1143        /* In order to complete the connection we need to attach to the offered
1144         * queue pair and send an attach notification.  We also subscribe to the
1145         * detach event so we know when our peer goes away, and we do that
1146         * before attaching so we don't miss an event.  If all this succeeds,
1147         * we update our state and wakeup anything waiting in accept() for a
1148         * connection.
1149         */
1150
1151        /* We don't care about attach since we ensure the other side has
1152         * attached by specifying the ATTACH_ONLY flag below.
1153         */
1154        err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
1155                                   vmci_transport_peer_detach_cb,
1156                                   vmci_trans(vpending), &detach_sub_id);
1157        if (err < VMCI_SUCCESS) {
1158                vmci_transport_send_reset(pending, pkt);
1159                err = vmci_transport_error_to_vsock_error(err);
1160                skerr = -err;
1161                goto destroy;
1162        }
1163
1164        vmci_trans(vpending)->detach_sub_id = detach_sub_id;
1165
1166        /* Now attach to the queue pair the client created. */
1167        handle = pkt->u.handle;
1168
1169        /* vpending->local_addr always has a context id so we do not need to
1170         * worry about VMADDR_CID_ANY in this case.
1171         */
1172        is_local =
1173            vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid;
1174        flags = VMCI_QPFLAG_ATTACH_ONLY;
1175        flags |= is_local ? VMCI_QPFLAG_LOCAL : 0;
1176
1177        err = vmci_transport_queue_pair_alloc(
1178                                        &qpair,
1179                                        &handle,
1180                                        vmci_trans(vpending)->produce_size,
1181                                        vmci_trans(vpending)->consume_size,
1182                                        pkt->dg.src.context,
1183                                        flags,
1184                                        vmci_transport_is_trusted(
1185                                                vpending,
1186                                                vpending->remote_addr.svm_cid));
1187        if (err < 0) {
1188                vmci_transport_send_reset(pending, pkt);
1189                skerr = -err;
1190                goto destroy;
1191        }
1192
1193        vmci_trans(vpending)->qp_handle = handle;
1194        vmci_trans(vpending)->qpair = qpair;
1195
1196        /* When we send the attach message, we must be ready to handle incoming
1197         * control messages on the newly connected socket. So we move the
1198         * pending socket to the connected state before sending the attach
1199         * message. Otherwise, an incoming packet triggered by the attach being
1200         * received by the peer may be processed concurrently with what happens
1201         * below after sending the attach message, and that incoming packet
1202         * will find the listening socket instead of the (currently) pending
1203         * socket. Note that enqueueing the socket increments the reference
1204         * count, so even if a reset comes before the connection is accepted,
1205         * the socket will be valid until it is removed from the queue.
1206         *
1207         * If we fail sending the attach below, we remove the socket from the
1208         * connected list and move the socket to SS_UNCONNECTED before
1209         * releasing the lock, so a pending slow path processing of an incoming
1210         * packet will not see the socket in the connected state in that case.
1211         */
1212        pending->sk_state = SS_CONNECTED;
1213
1214        vsock_insert_connected(vpending);
1215
1216        /* Notify our peer of our attach. */
1217        err = vmci_transport_send_attach(pending, handle);
1218        if (err < 0) {
1219                vsock_remove_connected(vpending);
1220                pr_err("Could not send attach\n");
1221                vmci_transport_send_reset(pending, pkt);
1222                err = vmci_transport_error_to_vsock_error(err);
1223                skerr = -err;
1224                goto destroy;
1225        }
1226
1227        /* We have a connection. Move the now connected socket from the
1228         * listener's pending list to the accept queue so callers of accept()
1229         * can find it.
1230         */
1231        vsock_remove_pending(listener, pending);
1232        vsock_enqueue_accept(listener, pending);
1233
1234        /* Callers of accept() will be be waiting on the listening socket, not
1235         * the pending socket.
1236         */
1237        listener->sk_data_ready(listener);
1238
1239        return 0;
1240
1241destroy:
1242        pending->sk_err = skerr;
1243        pending->sk_state = SS_UNCONNECTED;
1244        /* As long as we drop our reference, all necessary cleanup will handle
1245         * when the cleanup function drops its reference and our destruct
1246         * implementation is called.  Note that since the listen handler will
1247         * remove pending from the pending list upon our failure, the cleanup
1248         * function won't drop the additional reference, which is why we do it
1249         * here.
1250         */
1251        sock_put(pending);
1252
1253        return err;
1254}
1255
1256static int
1257vmci_transport_recv_connecting_client(struct sock *sk,
1258                                      struct vmci_transport_packet *pkt)
1259{
1260        struct vsock_sock *vsk;
1261        int err;
1262        int skerr;
1263
1264        vsk = vsock_sk(sk);
1265
1266        switch (pkt->type) {
1267        case VMCI_TRANSPORT_PACKET_TYPE_ATTACH:
1268                if (vmci_handle_is_invalid(pkt->u.handle) ||
1269                    !vmci_handle_is_equal(pkt->u.handle,
1270                                          vmci_trans(vsk)->qp_handle)) {
1271                        skerr = EPROTO;
1272                        err = -EINVAL;
1273                        goto destroy;
1274                }
1275
1276                /* Signify the socket is connected and wakeup the waiter in
1277                 * connect(). Also place the socket in the connected table for
1278                 * accounting (it can already be found since it's in the bound
1279                 * table).
1280                 */
1281                sk->sk_state = SS_CONNECTED;
1282                sk->sk_socket->state = SS_CONNECTED;
1283                vsock_insert_connected(vsk);
1284                sk->sk_state_change(sk);
1285
1286                break;
1287        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE:
1288        case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2:
1289                if (pkt->u.size == 0
1290                    || pkt->dg.src.context != vsk->remote_addr.svm_cid
1291                    || pkt->src_port != vsk->remote_addr.svm_port
1292                    || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)
1293                    || vmci_trans(vsk)->qpair
1294                    || vmci_trans(vsk)->produce_size != 0
1295                    || vmci_trans(vsk)->consume_size != 0
1296                    || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) {
1297                        skerr = EPROTO;
1298                        err = -EINVAL;
1299
1300                        goto destroy;
1301                }
1302
1303                err = vmci_transport_recv_connecting_client_negotiate(sk, pkt);
1304                if (err) {
1305                        skerr = -err;
1306                        goto destroy;
1307                }
1308
1309                break;
1310        case VMCI_TRANSPORT_PACKET_TYPE_INVALID:
1311                err = vmci_transport_recv_connecting_client_invalid(sk, pkt);
1312                if (err) {
1313                        skerr = -err;
1314                        goto destroy;
1315                }
1316
1317                break;
1318        case VMCI_TRANSPORT_PACKET_TYPE_RST:
1319                /* Older versions of the linux code (WS 6.5 / ESX 4.0) used to
1320                 * continue processing here after they sent an INVALID packet.
1321                 * This meant that we got a RST after the INVALID. We ignore a
1322                 * RST after an INVALID. The common code doesn't send the RST
1323                 * ... so we can hang if an old version of the common code
1324                 * fails between getting a REQUEST and sending an OFFER back.
1325                 * Not much we can do about it... except hope that it doesn't
1326                 * happen.
1327                 */
1328                if (vsk->ignore_connecting_rst) {
1329                        vsk->ignore_connecting_rst = false;
1330                } else {
1331                        skerr = ECONNRESET;
1332                        err = 0;
1333                        goto destroy;
1334                }
1335
1336                break;
1337        default:
1338                /* Close and cleanup the connection. */
1339                skerr = EPROTO;
1340                err = -EINVAL;
1341                goto destroy;
1342        }
1343
1344        return 0;
1345
1346destroy:
1347        vmci_transport_send_reset(sk, pkt);
1348
1349        sk->sk_state = SS_UNCONNECTED;
1350        sk->sk_err = skerr;
1351        sk->sk_error_report(sk);
1352        return err;
1353}
1354
1355static int vmci_transport_recv_connecting_client_negotiate(
1356                                        struct sock *sk,
1357                                        struct vmci_transport_packet *pkt)
1358{
1359        int err;
1360        struct vsock_sock *vsk;
1361        struct vmci_handle handle;
1362        struct vmci_qp *qpair;
1363        u32 detach_sub_id;
1364        bool is_local;
1365        u32 flags;
1366        bool old_proto = true;
1367        bool old_pkt_proto;
1368        u16 version;
1369
1370        vsk = vsock_sk(sk);
1371        handle = VMCI_INVALID_HANDLE;
1372        detach_sub_id = VMCI_INVALID_ID;
1373
1374        /* If we have gotten here then we should be past the point where old
1375         * linux vsock could have sent the bogus rst.
1376         */
1377        vsk->sent_request = false;
1378        vsk->ignore_connecting_rst = false;
1379
1380        /* Verify that we're OK with the proposed queue pair size */
1381        if (pkt->u.size < vmci_trans(vsk)->queue_pair_min_size ||
1382            pkt->u.size > vmci_trans(vsk)->queue_pair_max_size) {
1383                err = -EINVAL;
1384                goto destroy;
1385        }
1386
1387        /* At this point we know the CID the peer is using to talk to us. */
1388
1389        if (vsk->local_addr.svm_cid == VMADDR_CID_ANY)
1390                vsk->local_addr.svm_cid = pkt->dg.dst.context;
1391
1392        /* Setup the notify ops to be the highest supported version that both
1393         * the server and the client support.
1394         */
1395
1396        if (vmci_transport_old_proto_override(&old_pkt_proto)) {
1397                old_proto = old_pkt_proto;
1398        } else {
1399                if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE)
1400                        old_proto = true;
1401                else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2)
1402                        old_proto = false;
1403
1404        }
1405
1406        if (old_proto)
1407                version = VSOCK_PROTO_INVALID;
1408        else
1409                version = pkt->proto;
1410
1411        if (!vmci_transport_proto_to_notify_struct(sk, &version, old_proto)) {
1412                err = -EINVAL;
1413                goto destroy;
1414        }
1415
1416        /* Subscribe to detach events first.
1417         *
1418         * XXX We attach once for each queue pair created for now so it is easy
1419         * to find the socket (it's provided), but later we should only
1420         * subscribe once and add a way to lookup sockets by queue pair handle.
1421         */
1422        err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH,
1423                                   vmci_transport_peer_detach_cb,
1424                                   vmci_trans(vsk), &detach_sub_id);
1425        if (err < VMCI_SUCCESS) {
1426                err = vmci_transport_error_to_vsock_error(err);
1427                goto destroy;
1428        }
1429
1430        /* Make VMCI select the handle for us. */
1431        handle = VMCI_INVALID_HANDLE;
1432        is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid;
1433        flags = is_local ? VMCI_QPFLAG_LOCAL : 0;
1434
1435        err = vmci_transport_queue_pair_alloc(&qpair,
1436                                              &handle,
1437                                              pkt->u.size,
1438                                              pkt->u.size,
1439                                              vsk->remote_addr.svm_cid,
1440                                              flags,
1441                                              vmci_transport_is_trusted(
1442                                                  vsk,
1443                                                  vsk->
1444                                                  remote_addr.svm_cid));
1445        if (err < 0)
1446                goto destroy;
1447
1448        err = vmci_transport_send_qp_offer(sk, handle);
1449        if (err < 0) {
1450                err = vmci_transport_error_to_vsock_error(err);
1451                goto destroy;
1452        }
1453
1454        vmci_trans(vsk)->qp_handle = handle;
1455        vmci_trans(vsk)->qpair = qpair;
1456
1457        vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size =
1458                pkt->u.size;
1459
1460        vmci_trans(vsk)->detach_sub_id = detach_sub_id;
1461
1462        vmci_trans(vsk)->notify_ops->process_negotiate(sk);
1463
1464        return 0;
1465
1466destroy:
1467        if (detach_sub_id != VMCI_INVALID_ID)
1468                vmci_event_unsubscribe(detach_sub_id);
1469
1470        if (!vmci_handle_is_invalid(handle))
1471                vmci_qpair_detach(&qpair);
1472
1473        return err;
1474}
1475
1476static int
1477vmci_transport_recv_connecting_client_invalid(struct sock *sk,
1478                                              struct vmci_transport_packet *pkt)
1479{
1480        int err = 0;
1481        struct vsock_sock *vsk = vsock_sk(sk);
1482
1483        if (vsk->sent_request) {
1484                vsk->sent_request = false;
1485                vsk->ignore_connecting_rst = true;
1486
1487                err = vmci_transport_send_conn_request(
1488                        sk, vmci_trans(vsk)->queue_pair_size);
1489                if (err < 0)
1490                        err = vmci_transport_error_to_vsock_error(err);
1491                else
1492                        err = 0;
1493
1494        }
1495
1496        return err;
1497}
1498
1499static int vmci_transport_recv_connected(struct sock *sk,
1500                                         struct vmci_transport_packet *pkt)
1501{
1502        struct vsock_sock *vsk;
1503        bool pkt_processed = false;
1504
1505        /* In cases where we are closing the connection, it's sufficient to
1506         * mark the state change (and maybe error) and wake up any waiting
1507         * threads. Since this is a connected socket, it's owned by a user
1508         * process and will be cleaned up when the failure is passed back on
1509         * the current or next system call.  Our system call implementations
1510         * must therefore check for error and state changes on entry and when
1511         * being awoken.
1512         */
1513        switch (pkt->type) {
1514        case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN:
1515                if (pkt->u.mode) {
1516                        vsk = vsock_sk(sk);
1517
1518                        vsk->peer_shutdown |= pkt->u.mode;
1519                        sk->sk_state_change(sk);
1520                }
1521                break;
1522
1523        case VMCI_TRANSPORT_PACKET_TYPE_RST:
1524                vsk = vsock_sk(sk);
1525                /* It is possible that we sent our peer a message (e.g a
1526                 * WAITING_READ) right before we got notified that the peer had
1527                 * detached. If that happens then we can get a RST pkt back
1528                 * from our peer even though there is data available for us to
1529                 * read. In that case, don't shutdown the socket completely but
1530                 * instead allow the local client to finish reading data off
1531                 * the queuepair. Always treat a RST pkt in connected mode like
1532                 * a clean shutdown.
1533                 */
1534                sock_set_flag(sk, SOCK_DONE);
1535                vsk->peer_shutdown = SHUTDOWN_MASK;
1536                if (vsock_stream_has_data(vsk) <= 0)
1537                        sk->sk_state = SS_DISCONNECTING;
1538
1539                sk->sk_state_change(sk);
1540                break;
1541
1542        default:
1543                vsk = vsock_sk(sk);
1544                vmci_trans(vsk)->notify_ops->handle_notify_pkt(
1545                                sk, pkt, false, NULL, NULL,
1546                                &pkt_processed);
1547                if (!pkt_processed)
1548                        return -EINVAL;
1549
1550                break;
1551        }
1552
1553        return 0;
1554}
1555
1556static int vmci_transport_socket_init(struct vsock_sock *vsk,
1557                                      struct vsock_sock *psk)
1558{
1559        vsk->trans = kmalloc(sizeof(struct vmci_transport), GFP_KERNEL);
1560        if (!vsk->trans)
1561                return -ENOMEM;
1562
1563        vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
1564        vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE;
1565        vmci_trans(vsk)->qpair = NULL;
1566        vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = 0;
1567        vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID;
1568        vmci_trans(vsk)->notify_ops = NULL;
1569        INIT_LIST_HEAD(&vmci_trans(vsk)->elem);
1570        vmci_trans(vsk)->sk = &vsk->sk;
1571        spin_lock_init(&vmci_trans(vsk)->lock);
1572        if (psk) {
1573                vmci_trans(vsk)->queue_pair_size =
1574                        vmci_trans(psk)->queue_pair_size;
1575                vmci_trans(vsk)->queue_pair_min_size =
1576                        vmci_trans(psk)->queue_pair_min_size;
1577                vmci_trans(vsk)->queue_pair_max_size =
1578                        vmci_trans(psk)->queue_pair_max_size;
1579        } else {
1580                vmci_trans(vsk)->queue_pair_size =
1581                        VMCI_TRANSPORT_DEFAULT_QP_SIZE;
1582                vmci_trans(vsk)->queue_pair_min_size =
1583                         VMCI_TRANSPORT_DEFAULT_QP_SIZE_MIN;
1584                vmci_trans(vsk)->queue_pair_max_size =
1585                        VMCI_TRANSPORT_DEFAULT_QP_SIZE_MAX;
1586        }
1587
1588        return 0;
1589}
1590
1591static void vmci_transport_free_resources(struct list_head *transport_list)
1592{
1593        while (!list_empty(transport_list)) {
1594                struct vmci_transport *transport =
1595                    list_first_entry(transport_list, struct vmci_transport,
1596                                     elem);
1597                list_del(&transport->elem);
1598
1599                if (transport->detach_sub_id != VMCI_INVALID_ID) {
1600                        vmci_event_unsubscribe(transport->detach_sub_id);
1601                        transport->detach_sub_id = VMCI_INVALID_ID;
1602                }
1603
1604                if (!vmci_handle_is_invalid(transport->qp_handle)) {
1605                        vmci_qpair_detach(&transport->qpair);
1606                        transport->qp_handle = VMCI_INVALID_HANDLE;
1607                        transport->produce_size = 0;
1608                        transport->consume_size = 0;
1609                }
1610
1611                kfree(transport);
1612        }
1613}
1614
1615static void vmci_transport_cleanup(struct work_struct *work)
1616{
1617        LIST_HEAD(pending);
1618
1619        spin_lock_bh(&vmci_transport_cleanup_lock);
1620        list_replace_init(&vmci_transport_cleanup_list, &pending);
1621        spin_unlock_bh(&vmci_transport_cleanup_lock);
1622        vmci_transport_free_resources(&pending);
1623}
1624
1625static void vmci_transport_destruct(struct vsock_sock *vsk)
1626{
1627        /* Ensure that the detach callback doesn't use the sk/vsk
1628         * we are about to destruct.
1629         */
1630        spin_lock_bh(&vmci_trans(vsk)->lock);
1631        vmci_trans(vsk)->sk = NULL;
1632        spin_unlock_bh(&vmci_trans(vsk)->lock);
1633
1634        if (vmci_trans(vsk)->notify_ops)
1635                vmci_trans(vsk)->notify_ops->socket_destruct(vsk);
1636
1637        spin_lock_bh(&vmci_transport_cleanup_lock);
1638        list_add(&vmci_trans(vsk)->elem, &vmci_transport_cleanup_list);
1639        spin_unlock_bh(&vmci_transport_cleanup_lock);
1640        schedule_work(&vmci_transport_cleanup_work);
1641
1642        vsk->trans = NULL;
1643}
1644
1645static void vmci_transport_release(struct vsock_sock *vsk)
1646{
1647        vsock_remove_sock(vsk);
1648
1649        if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) {
1650                vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle);
1651                vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE;
1652        }
1653}
1654
1655static int vmci_transport_dgram_bind(struct vsock_sock *vsk,
1656                                     struct sockaddr_vm *addr)
1657{
1658        u32 port;
1659        u32 flags;
1660        int err;
1661
1662        /* VMCI will select a resource ID for us if we provide
1663         * VMCI_INVALID_ID.
1664         */
1665        port = addr->svm_port == VMADDR_PORT_ANY ?
1666                        VMCI_INVALID_ID : addr->svm_port;
1667
1668        if (port <= LAST_RESERVED_PORT && !capable(CAP_NET_BIND_SERVICE))
1669                return -EACCES;
1670
1671        flags = addr->svm_cid == VMADDR_CID_ANY ?
1672                                VMCI_FLAG_ANYCID_DG_HND : 0;
1673
1674        err = vmci_transport_datagram_create_hnd(port, flags,
1675                                                 vmci_transport_recv_dgram_cb,
1676                                                 &vsk->sk,
1677                                                 &vmci_trans(vsk)->dg_handle);
1678        if (err < VMCI_SUCCESS)
1679                return vmci_transport_error_to_vsock_error(err);
1680        vsock_addr_init(&vsk->local_addr, addr->svm_cid,
1681                        vmci_trans(vsk)->dg_handle.resource);
1682
1683        return 0;
1684}
1685
1686static int vmci_transport_dgram_enqueue(
1687        struct vsock_sock *vsk,
1688        struct sockaddr_vm *remote_addr,
1689        struct msghdr *msg,
1690        size_t len)
1691{
1692        int err;
1693        struct vmci_datagram *dg;
1694
1695        if (len > VMCI_MAX_DG_PAYLOAD_SIZE)
1696                return -EMSGSIZE;
1697
1698        if (!vmci_transport_allow_dgram(vsk, remote_addr->svm_cid))
1699                return -EPERM;
1700
1701        /* Allocate a buffer for the user's message and our packet header. */
1702        dg = kmalloc(len + sizeof(*dg), GFP_KERNEL);
1703        if (!dg)
1704                return -ENOMEM;
1705
1706        memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len);
1707
1708        dg->dst = vmci_make_handle(remote_addr->svm_cid,
1709                                   remote_addr->svm_port);
1710        dg->src = vmci_make_handle(vsk->local_addr.svm_cid,
1711                                   vsk->local_addr.svm_port);
1712        dg->payload_size = len;
1713
1714        err = vmci_datagram_send(dg);
1715        kfree(dg);
1716        if (err < 0)
1717                return vmci_transport_error_to_vsock_error(err);
1718
1719        return err - sizeof(*dg);
1720}
1721
1722static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk,
1723                                        struct msghdr *msg, size_t len,
1724                                        int flags)
1725{
1726        int err;
1727        int noblock;
1728        struct vmci_datagram *dg;
1729        size_t payload_len;
1730        struct sk_buff *skb;
1731
1732        noblock = flags & MSG_DONTWAIT;
1733
1734        if (flags & MSG_OOB || flags & MSG_ERRQUEUE)
1735                return -EOPNOTSUPP;
1736
1737        /* Retrieve the head sk_buff from the socket's receive queue. */
1738        err = 0;
1739        skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err);
1740        if (!skb)
1741                return err;
1742
1743        dg = (struct vmci_datagram *)skb->data;
1744        if (!dg)
1745                /* err is 0, meaning we read zero bytes. */
1746                goto out;
1747
1748        payload_len = dg->payload_size;
1749        /* Ensure the sk_buff matches the payload size claimed in the packet. */
1750        if (payload_len != skb->len - sizeof(*dg)) {
1751                err = -EINVAL;
1752                goto out;
1753        }
1754
1755        if (payload_len > len) {
1756                payload_len = len;
1757                msg->msg_flags |= MSG_TRUNC;
1758        }
1759
1760        /* Place the datagram payload in the user's iovec. */
1761        err = skb_copy_datagram_msg(skb, sizeof(*dg), msg, payload_len);
1762        if (err)
1763                goto out;
1764
1765        if (msg->msg_name) {
1766                /* Provide the address of the sender. */
1767                DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name);
1768                vsock_addr_init(vm_addr, dg->src.context, dg->src.resource);
1769                msg->msg_namelen = sizeof(*vm_addr);
1770        }
1771        err = payload_len;
1772
1773out:
1774        skb_free_datagram(&vsk->sk, skb);
1775        return err;
1776}
1777
1778static bool vmci_transport_dgram_allow(u32 cid, u32 port)
1779{
1780        if (cid == VMADDR_CID_HYPERVISOR) {
1781                /* Registrations of PBRPC Servers do not modify VMX/Hypervisor
1782                 * state and are allowed.
1783                 */
1784                return port == VMCI_UNITY_PBRPC_REGISTER;
1785        }
1786
1787        return true;
1788}
1789
1790static int vmci_transport_connect(struct vsock_sock *vsk)
1791{
1792        int err;
1793        bool old_pkt_proto = false;
1794        struct sock *sk = &vsk->sk;
1795
1796        if (vmci_transport_old_proto_override(&old_pkt_proto) &&
1797                old_pkt_proto) {
1798                err = vmci_transport_send_conn_request(
1799                        sk, vmci_trans(vsk)->queue_pair_size);
1800                if (err < 0) {
1801                        sk->sk_state = SS_UNCONNECTED;
1802                        return err;
1803                }
1804        } else {
1805                int supported_proto_versions =
1806                        vmci_transport_new_proto_supported_versions();
1807                err = vmci_transport_send_conn_request2(
1808                                sk, vmci_trans(vsk)->queue_pair_size,
1809                                supported_proto_versions);
1810                if (err < 0) {
1811                        sk->sk_state = SS_UNCONNECTED;
1812                        return err;
1813                }
1814
1815                vsk->sent_request = true;
1816        }
1817
1818        return err;
1819}
1820
1821static ssize_t vmci_transport_stream_dequeue(
1822        struct vsock_sock *vsk,
1823        struct msghdr *msg,
1824        size_t len,
1825        int flags)
1826{
1827        if (flags & MSG_PEEK)
1828                return vmci_qpair_peekv(vmci_trans(vsk)->qpair, msg, len, 0);
1829        else
1830                return vmci_qpair_dequev(vmci_trans(vsk)->qpair, msg, len, 0);
1831}
1832
1833static ssize_t vmci_transport_stream_enqueue(
1834        struct vsock_sock *vsk,
1835        struct msghdr *msg,
1836        size_t len)
1837{
1838        return vmci_qpair_enquev(vmci_trans(vsk)->qpair, msg, len, 0);
1839}
1840
1841static s64 vmci_transport_stream_has_data(struct vsock_sock *vsk)
1842{
1843        return vmci_qpair_consume_buf_ready(vmci_trans(vsk)->qpair);
1844}
1845
1846static s64 vmci_transport_stream_has_space(struct vsock_sock *vsk)
1847{
1848        return vmci_qpair_produce_free_space(vmci_trans(vsk)->qpair);
1849}
1850
1851static u64 vmci_transport_stream_rcvhiwat(struct vsock_sock *vsk)
1852{
1853        return vmci_trans(vsk)->consume_size;
1854}
1855
1856static bool vmci_transport_stream_is_active(struct vsock_sock *vsk)
1857{
1858        return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle);
1859}
1860
1861static u64 vmci_transport_get_buffer_size(struct vsock_sock *vsk)
1862{
1863        return vmci_trans(vsk)->queue_pair_size;
1864}
1865
1866static u64 vmci_transport_get_min_buffer_size(struct vsock_sock *vsk)
1867{
1868        return vmci_trans(vsk)->queue_pair_min_size;
1869}
1870
1871static u64 vmci_transport_get_max_buffer_size(struct vsock_sock *vsk)
1872{
1873        return vmci_trans(vsk)->queue_pair_max_size;
1874}
1875
1876static void vmci_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
1877{
1878        if (val < vmci_trans(vsk)->queue_pair_min_size)
1879                vmci_trans(vsk)->queue_pair_min_size = val;
1880        if (val > vmci_trans(vsk)->queue_pair_max_size)
1881                vmci_trans(vsk)->queue_pair_max_size = val;
1882        vmci_trans(vsk)->queue_pair_size = val;
1883}
1884
1885static void vmci_transport_set_min_buffer_size(struct vsock_sock *vsk,
1886                                               u64 val)
1887{
1888        if (val > vmci_trans(vsk)->queue_pair_size)
1889                vmci_trans(vsk)->queue_pair_size = val;
1890        vmci_trans(vsk)->queue_pair_min_size = val;
1891}
1892
1893static void vmci_transport_set_max_buffer_size(struct vsock_sock *vsk,
1894                                               u64 val)
1895{
1896        if (val < vmci_trans(vsk)->queue_pair_size)
1897                vmci_trans(vsk)->queue_pair_size = val;
1898        vmci_trans(vsk)->queue_pair_max_size = val;
1899}
1900
1901static int vmci_transport_notify_poll_in(
1902        struct vsock_sock *vsk,
1903        size_t target,
1904        bool *data_ready_now)
1905{
1906        return vmci_trans(vsk)->notify_ops->poll_in(
1907                        &vsk->sk, target, data_ready_now);
1908}
1909
1910static int vmci_transport_notify_poll_out(
1911        struct vsock_sock *vsk,
1912        size_t target,
1913        bool *space_available_now)
1914{
1915        return vmci_trans(vsk)->notify_ops->poll_out(
1916                        &vsk->sk, target, space_available_now);
1917}
1918
1919static int vmci_transport_notify_recv_init(
1920        struct vsock_sock *vsk,
1921        size_t target,
1922        struct vsock_transport_recv_notify_data *data)
1923{
1924        return vmci_trans(vsk)->notify_ops->recv_init(
1925                        &vsk->sk, target,
1926                        (struct vmci_transport_recv_notify_data *)data);
1927}
1928
1929static int vmci_transport_notify_recv_pre_block(
1930        struct vsock_sock *vsk,
1931        size_t target,
1932        struct vsock_transport_recv_notify_data *data)
1933{
1934        return vmci_trans(vsk)->notify_ops->recv_pre_block(
1935                        &vsk->sk, target,
1936                        (struct vmci_transport_recv_notify_data *)data);
1937}
1938
1939static int vmci_transport_notify_recv_pre_dequeue(
1940        struct vsock_sock *vsk,
1941        size_t target,
1942        struct vsock_transport_recv_notify_data *data)
1943{
1944        return vmci_trans(vsk)->notify_ops->recv_pre_dequeue(
1945                        &vsk->sk, target,
1946                        (struct vmci_transport_recv_notify_data *)data);
1947}
1948
1949static int vmci_transport_notify_recv_post_dequeue(
1950        struct vsock_sock *vsk,
1951        size_t target,
1952        ssize_t copied,
1953        bool data_read,
1954        struct vsock_transport_recv_notify_data *data)
1955{
1956        return vmci_trans(vsk)->notify_ops->recv_post_dequeue(
1957                        &vsk->sk, target, copied, data_read,
1958                        (struct vmci_transport_recv_notify_data *)data);
1959}
1960
1961static int vmci_transport_notify_send_init(
1962        struct vsock_sock *vsk,
1963        struct vsock_transport_send_notify_data *data)
1964{
1965        return vmci_trans(vsk)->notify_ops->send_init(
1966                        &vsk->sk,
1967                        (struct vmci_transport_send_notify_data *)data);
1968}
1969
1970static int vmci_transport_notify_send_pre_block(
1971        struct vsock_sock *vsk,
1972        struct vsock_transport_send_notify_data *data)
1973{
1974        return vmci_trans(vsk)->notify_ops->send_pre_block(
1975                        &vsk->sk,
1976                        (struct vmci_transport_send_notify_data *)data);
1977}
1978
1979static int vmci_transport_notify_send_pre_enqueue(
1980        struct vsock_sock *vsk,
1981        struct vsock_transport_send_notify_data *data)
1982{
1983        return vmci_trans(vsk)->notify_ops->send_pre_enqueue(
1984                        &vsk->sk,
1985                        (struct vmci_transport_send_notify_data *)data);
1986}
1987
1988static int vmci_transport_notify_send_post_enqueue(
1989        struct vsock_sock *vsk,
1990        ssize_t written,
1991        struct vsock_transport_send_notify_data *data)
1992{
1993        return vmci_trans(vsk)->notify_ops->send_post_enqueue(
1994                        &vsk->sk, written,
1995                        (struct vmci_transport_send_notify_data *)data);
1996}
1997
1998static bool vmci_transport_old_proto_override(bool *old_pkt_proto)
1999{
2000        if (PROTOCOL_OVERRIDE != -1) {
2001                if (PROTOCOL_OVERRIDE == 0)
2002                        *old_pkt_proto = true;
2003                else
2004                        *old_pkt_proto = false;
2005
2006                pr_info("Proto override in use\n");
2007                return true;
2008        }
2009
2010        return false;
2011}
2012
2013static bool vmci_transport_proto_to_notify_struct(struct sock *sk,
2014                                                  u16 *proto,
2015                                                  bool old_pkt_proto)
2016{
2017        struct vsock_sock *vsk = vsock_sk(sk);
2018
2019        if (old_pkt_proto) {
2020                if (*proto != VSOCK_PROTO_INVALID) {
2021                        pr_err("Can't set both an old and new protocol\n");
2022                        return false;
2023                }
2024                vmci_trans(vsk)->notify_ops = &vmci_transport_notify_pkt_ops;
2025                goto exit;
2026        }
2027
2028        switch (*proto) {
2029        case VSOCK_PROTO_PKT_ON_NOTIFY:
2030                vmci_trans(vsk)->notify_ops =
2031                        &vmci_transport_notify_pkt_q_state_ops;
2032                break;
2033        default:
2034                pr_err("Unknown notify protocol version\n");
2035                return false;
2036        }
2037
2038exit:
2039        vmci_trans(vsk)->notify_ops->socket_init(sk);
2040        return true;
2041}
2042
2043static u16 vmci_transport_new_proto_supported_versions(void)
2044{
2045        if (PROTOCOL_OVERRIDE != -1)
2046                return PROTOCOL_OVERRIDE;
2047
2048        return VSOCK_PROTO_ALL_SUPPORTED;
2049}
2050
2051static u32 vmci_transport_get_local_cid(void)
2052{
2053        return vmci_get_context_id();
2054}
2055
2056static const struct vsock_transport vmci_transport = {
2057        .init = vmci_transport_socket_init,
2058        .destruct = vmci_transport_destruct,
2059        .release = vmci_transport_release,
2060        .connect = vmci_transport_connect,
2061        .dgram_bind = vmci_transport_dgram_bind,
2062        .dgram_dequeue = vmci_transport_dgram_dequeue,
2063        .dgram_enqueue = vmci_transport_dgram_enqueue,
2064        .dgram_allow = vmci_transport_dgram_allow,
2065        .stream_dequeue = vmci_transport_stream_dequeue,
2066        .stream_enqueue = vmci_transport_stream_enqueue,
2067        .stream_has_data = vmci_transport_stream_has_data,
2068        .stream_has_space = vmci_transport_stream_has_space,
2069        .stream_rcvhiwat = vmci_transport_stream_rcvhiwat,
2070        .stream_is_active = vmci_transport_stream_is_active,
2071        .stream_allow = vmci_transport_stream_allow,
2072        .notify_poll_in = vmci_transport_notify_poll_in,
2073        .notify_poll_out = vmci_transport_notify_poll_out,
2074        .notify_recv_init = vmci_transport_notify_recv_init,
2075        .notify_recv_pre_block = vmci_transport_notify_recv_pre_block,
2076        .notify_recv_pre_dequeue = vmci_transport_notify_recv_pre_dequeue,
2077        .notify_recv_post_dequeue = vmci_transport_notify_recv_post_dequeue,
2078        .notify_send_init = vmci_transport_notify_send_init,
2079        .notify_send_pre_block = vmci_transport_notify_send_pre_block,
2080        .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue,
2081        .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue,
2082        .shutdown = vmci_transport_shutdown,
2083        .set_buffer_size = vmci_transport_set_buffer_size,
2084        .set_min_buffer_size = vmci_transport_set_min_buffer_size,
2085        .set_max_buffer_size = vmci_transport_set_max_buffer_size,
2086        .get_buffer_size = vmci_transport_get_buffer_size,
2087        .get_min_buffer_size = vmci_transport_get_min_buffer_size,
2088        .get_max_buffer_size = vmci_transport_get_max_buffer_size,
2089        .get_local_cid = vmci_transport_get_local_cid,
2090};
2091
2092static int __init vmci_transport_init(void)
2093{
2094        int err;
2095
2096        /* Create the datagram handle that we will use to send and receive all
2097         * VSocket control messages for this context.
2098         */
2099        err = vmci_transport_datagram_create_hnd(VMCI_TRANSPORT_PACKET_RID,
2100                                                 VMCI_FLAG_ANYCID_DG_HND,
2101                                                 vmci_transport_recv_stream_cb,
2102                                                 NULL,
2103                                                 &vmci_transport_stream_handle);
2104        if (err < VMCI_SUCCESS) {
2105                pr_err("Unable to create datagram handle. (%d)\n", err);
2106                return vmci_transport_error_to_vsock_error(err);
2107        }
2108
2109        err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED,
2110                                   vmci_transport_qp_resumed_cb,
2111                                   NULL, &vmci_transport_qp_resumed_sub_id);
2112        if (err < VMCI_SUCCESS) {
2113                pr_err("Unable to subscribe to resumed event. (%d)\n", err);
2114                err = vmci_transport_error_to_vsock_error(err);
2115                vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
2116                goto err_destroy_stream_handle;
2117        }
2118
2119        err = vsock_core_init(&vmci_transport);
2120        if (err < 0)
2121                goto err_unsubscribe;
2122
2123        return 0;
2124
2125err_unsubscribe:
2126        vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
2127err_destroy_stream_handle:
2128        vmci_datagram_destroy_handle(vmci_transport_stream_handle);
2129        return err;
2130}
2131module_init(vmci_transport_init);
2132
2133static void __exit vmci_transport_exit(void)
2134{
2135        cancel_work_sync(&vmci_transport_cleanup_work);
2136        vmci_transport_free_resources(&vmci_transport_cleanup_list);
2137
2138        if (!vmci_handle_is_invalid(vmci_transport_stream_handle)) {
2139                if (vmci_datagram_destroy_handle(
2140                        vmci_transport_stream_handle) != VMCI_SUCCESS)
2141                        pr_err("Couldn't destroy datagram handle\n");
2142                vmci_transport_stream_handle = VMCI_INVALID_HANDLE;
2143        }
2144
2145        if (vmci_transport_qp_resumed_sub_id != VMCI_INVALID_ID) {
2146                vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id);
2147                vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
2148        }
2149
2150        vsock_core_exit();
2151}
2152module_exit(vmci_transport_exit);
2153
/* Module metadata: author, description, version, licence and the aliases
 * this module declares.
 */
MODULE_AUTHOR("VMware, Inc.");
MODULE_DESCRIPTION("VMCI transport for Virtual Sockets");
MODULE_VERSION("1.0.4.0-k");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS("vmware_vsock");
MODULE_ALIAS_NETPROTO(PF_VSOCK);
2160