linux/drivers/xen/pvcalls-front.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

static struct proto pvcalls_proto = {
        .name   = "PVCalls",
        .owner  = THIS_MODULE,
        .obj_size = sizeof(struct sock),
};

struct pvcalls_bedata {
        struct xen_pvcalls_front_ring ring;
        grant_ref_t ref;
        int irq;

        struct list_head socket_mappings;
        spinlock_t socket_lock;

        wait_queue_head_t inflight_req;
        struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {               \
        atomic_inc(&pvcalls_refcount);      \
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {                \
        atomic_dec(&pvcalls_refcount);      \
}
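
/*
 * pvcalls_enter()/pvcalls_exit() only bump and drop a global refcount;
 * there is no lock. The point is quiescence: the teardown paths
 * (pvcalls_front_release(), pvcalls_front_remove()) spin with
 *
 *      while (atomic_read(&pvcalls_refcount) > 0)
 *              cpu_relax();
 *
 * until all in-flight pvcalls have exited, so every caller must pair
 * enter with exit and never sleep indefinitely in between.
 */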

struct sock_mapping {
        bool active_socket;
        struct list_head list;
        struct socket *sock;
        atomic_t refcount;
        union {
                struct {
                        int irq;
                        grant_ref_t ref;
                        struct pvcalls_data_intf *ring;
                        struct pvcalls_data data;
                        struct mutex in_mutex;
                        struct mutex out_mutex;

                        wait_queue_head_t inflight_conn_req;
                } active;
                struct {
                /*
                 * Socket status, needs to be 64-bit aligned due to the
                 * test_and_* functions which have this requirement on arm64.
                 */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
                        uint8_t status __attribute__((aligned(8)));
                /*
                 * Internal state-machine flags.
                 * Only one accept operation can be inflight for a socket.
                 * Only one poll operation can be inflight for a given socket.
                 * flags needs to be 64-bit aligned due to the test_and_*
                 * functions which have this requirement on arm64.
                 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
                        uint8_t flags __attribute__((aligned(8)));
                        uint32_t inflight_req_id;
                        struct sock_mapping *accept_map;
                        wait_queue_head_t inflight_accept_req;
                } passive;
        };
};
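
/*
 * One sock_mapping exists per frontend socket. Exactly one arm of the
 * union is live: "active" after create_active() (connect/accept),
 * "passive" after bind; active_socket records which. The mapping is
 * reachable from the struct socket through sock->sk->sk_send_head,
 * which is otherwise unused here (see pvcalls_front_socket()).
 */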

static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
        struct sock_mapping *map;

        if (!pvcalls_front_dev ||
                dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
                return ERR_PTR(-ENOTCONN);

        map = (struct sock_mapping *)sock->sk->sk_send_head;
        if (map == NULL)
                return ERR_PTR(-ENOTSOCK);

        pvcalls_enter();
        atomic_inc(&map->refcount);
        return map;
}

static inline void pvcalls_exit_sock(struct socket *sock)
{
        struct sock_mapping *map;

        map = (struct sock_mapping *)sock->sk->sk_send_head;
        atomic_dec(&map->refcount);
        pvcalls_exit();
}

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
        *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
        if (RING_FULL(&bedata->ring) ||
            bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
                return -EAGAIN;
        return 0;
}
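
/*
 * Worked example: req_id is req_prod_pvt masked down to a slot index,
 * so with a 64-slot ring the 65th request reuses slot 0. The same index
 * addresses bedata->rsp[]: a slot is only free once the previous
 * response has been consumed and its req_id reset to
 * PVCALLS_INVALID_ID, otherwise get_request() returns -EAGAIN.
 */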

static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->active.ring;
        RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        error = intf->out_error;
        if (error == -ENOTCONN)
                return false;
        if (error != 0)
                return true;

        cons = intf->out_cons;
        prod = intf->out_prod;
        return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->active.ring;
        RING_IDX cons, prod;
        int32_t error;

        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        return (error != 0 ||
                pvcalls_queued(prod, cons,
                               XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
        struct xenbus_device *dev = dev_id;
        struct pvcalls_bedata *bedata;
        struct xen_pvcalls_response *rsp;
        uint8_t *src, *dst;
        int req_id = 0, more = 0, done = 0;

        if (dev == NULL)
                return IRQ_HANDLED;

        pvcalls_enter();
        bedata = dev_get_drvdata(&dev->dev);
        if (bedata == NULL) {
                pvcalls_exit();
                return IRQ_HANDLED;
        }

again:
        while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
                rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

                req_id = rsp->req_id;
                if (rsp->cmd == PVCALLS_POLL) {
                        struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
                                                   rsp->u.poll.id;

                        clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                                  (void *)&map->passive.flags);
                        /*
                         * clear INFLIGHT, then set RET. It pairs with
                         * the checks at the beginning of
                         * pvcalls_front_poll_passive.
                         */
                        smp_wmb();
                        set_bit(PVCALLS_FLAG_POLL_RET,
                                (void *)&map->passive.flags);
                } else {
                        dst = (uint8_t *)&bedata->rsp[req_id] +
                              sizeof(rsp->req_id);
                        src = (uint8_t *)rsp + sizeof(rsp->req_id);
                        memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
                        /*
                         * First copy the rest of the data, then req_id. It is
                         * paired with the barrier when accessing bedata->rsp.
                         */
                        smp_wmb();
                        bedata->rsp[req_id].req_id = req_id;
                }

                done = 1;
                bedata->ring.rsp_cons++;
        }

        RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
        if (more)
                goto again;
        if (done)
                wake_up(&bedata->inflight_req);
        pvcalls_exit();
        return IRQ_HANDLED;
}
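
/*
 * Demultiplexing in the handler above, in short: PVCALLS_POLL replies
 * carry the sock_mapping pointer in rsp->u.poll.id and are signalled
 * through the passive flag bits; every other reply is copied into
 * bedata->rsp[], body first and req_id last (smp_wmb()), pairing with
 * the smp_rmb() after wait_event() in the submission paths below.
 */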

static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
                                   struct sock_mapping *map)
{
        int i;

        unbind_from_irqhandler(map->active.irq, map);

        spin_lock(&bedata->socket_lock);
        if (!list_empty(&map->list))
                list_del_init(&map->list);
        spin_unlock(&bedata->socket_lock);

        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
                gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
        gnttab_end_foreign_access(map->active.ref, 0, 0);
        free_page((unsigned long)map->active.ring);

        kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
        struct sock_mapping *map = sock_map;

        if (map == NULL)
                return IRQ_HANDLED;

        wake_up_interruptible(&map->active.inflight_conn_req);

        return IRQ_HANDLED;
}

int pvcalls_front_socket(struct socket *sock)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        /*
         * PVCalls only supports domain AF_INET,
         * type SOCK_STREAM and protocol 0 sockets for now.
         *
         * Check socket type here, AF_INET and protocol checks are done
         * by the caller.
         */
        if (sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        pvcalls_enter();
        if (!pvcalls_front_dev) {
                pvcalls_exit();
                return -EACCES;
        }
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        map = kzalloc(sizeof(*map), GFP_KERNEL);
        if (map == NULL) {
                pvcalls_exit();
                return -ENOMEM;
        }

        spin_lock(&bedata->socket_lock);

        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                kfree(map);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit();
                return ret;
        }

        /*
         * sock->sk->sk_send_head is not used for ip sockets: reuse the
         * field to store a pointer to the struct sock_mapping
         * corresponding to the socket. This way, we can easily get the
         * struct sock_mapping from the struct socket.
         */
        sock->sk->sk_send_head = (void *)map;
        list_add_tail(&map->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_SOCKET;
        req->u.socket.id = (uintptr_t) map;
        req->u.socket.domain = AF_INET;
        req->u.socket.type = SOCK_STREAM;
        req->u.socket.protocol = IPPROTO_IP;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        pvcalls_exit();
        return ret;
}
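
/*
 * Illustrative only, not a caller that exists in this file: a socket
 * front end built on top of these exports would chain them in the
 * usual BSD order, roughly
 *
 *      ret = pvcalls_front_socket(sock);
 *      if (ret == 0)
 *              ret = pvcalls_front_connect(sock, addr, addr_len, 0);
 *
 * or pvcalls_front_bind()/_listen()/_accept() on the passive side.
 */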

static void free_active_ring(struct sock_mapping *map)
{
        if (!map->active.ring)
                return;

        free_pages((unsigned long)map->active.data.in,
                        map->active.ring->ring_order);
        free_page((unsigned long)map->active.ring);
}

static int alloc_active_ring(struct sock_mapping *map)
{
        void *bytes;

        map->active.ring = (struct pvcalls_data_intf *)
                get_zeroed_page(GFP_KERNEL);
        if (!map->active.ring)
                goto out;

        map->active.ring->ring_order = PVCALLS_RING_ORDER;
        bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
                                        PVCALLS_RING_ORDER);
        if (!bytes)
                goto out;

        map->active.data.in = bytes;
        map->active.data.out = bytes +
                XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

        return 0;

out:
        free_active_ring(map);
        return -ENOMEM;
}

static int create_active(struct sock_mapping *map, int *evtchn)
{
        void *bytes;
        int ret = -ENOMEM, irq = -1, i;

        *evtchn = -1;
        init_waitqueue_head(&map->active.inflight_conn_req);

        bytes = map->active.data.in;
        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
                map->active.ring->ref[i] = gnttab_grant_foreign_access(
                        pvcalls_front_dev->otherend_id,
                        pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

        map->active.ref = gnttab_grant_foreign_access(
                pvcalls_front_dev->otherend_id,
                pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

        ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
        if (ret)
                goto out_error;
        irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
                                        0, "pvcalls-frontend", map);
        if (irq < 0) {
                ret = irq;
                goto out_error;
        }

        map->active.irq = irq;
        map->active_socket = true;
        mutex_init(&map->active.in_mutex);
        mutex_init(&map->active.out_mutex);

        return 0;

out_error:
        if (*evtchn >= 0)
                xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
        return ret;
}
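
/*
 * Grant layout sketch for an active socket: create_active() shares one
 * indirection page (the pvcalls_data_intf, grant map->active.ref) whose
 * ref[] array carries the grants of the 1 << PVCALLS_RING_ORDER data
 * pages; the first half of those pages backs the in ring, the second
 * half the out ring (see alloc_active_ring()). pvcalls_front_free_map()
 * revokes the grants in the same order.
 */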

int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
                                int addr_len, int flags)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret, evtchn;

        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
        ret = alloc_active_ring(map);
        if (ret < 0) {
                pvcalls_exit_sock(sock);
                return ret;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                free_active_ring(map);
                pvcalls_exit_sock(sock);
                return ret;
        }
        ret = create_active(map, &evtchn);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                free_active_ring(map);
                pvcalls_exit_sock(sock);
                return ret;
        }

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_CONNECT;
        req->u.connect.id = (uintptr_t)map;
        req->u.connect.len = addr_len;
        req->u.connect.flags = flags;
        req->u.connect.ref = map->active.ref;
        req->u.connect.evtchn = evtchn;
        memcpy(req->u.connect.addr, addr, sizeof(*addr));

        map->sock = sock;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);

        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        pvcalls_exit_sock(sock);
        return ret;
}

static int __write_ring(struct pvcalls_data_intf *intf,
                        struct pvcalls_data *data,
                        struct iov_iter *msg_iter,
                        int len)
{
        RING_IDX cons, prod, size, masked_prod, masked_cons;
        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        error = intf->out_error;
        if (error < 0)
                return error;
        cons = intf->out_cons;
        prod = intf->out_prod;
        /* read indexes before continuing */
        virt_mb();

        size = pvcalls_queued(prod, cons, array_size);
        if (size > array_size)
                return -EINVAL;
        if (size == array_size)
                return 0;
        if (len > array_size - size)
                len = array_size - size;

        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        if (masked_prod < masked_cons) {
                len = copy_from_iter(data->out + masked_prod, len, msg_iter);
        } else {
                if (len > array_size - masked_prod) {
                        int ret = copy_from_iter(data->out + masked_prod,
                                       array_size - masked_prod, msg_iter);
                        if (ret != array_size - masked_prod) {
                                len = ret;
                                goto out;
                        }
                        len = ret + copy_from_iter(data->out, len - ret, msg_iter);
                } else {
                        len = copy_from_iter(data->out + masked_prod, len, msg_iter);
                }
        }
out:
        /* write to ring before updating pointer */
        virt_wmb();
        intf->out_prod += len;

        return len;
}
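
/*
 * Wraparound example for __write_ring(), assuming array_size is 4096
 * bytes, masked_cons is 1000 and masked_prod is 3000: the free space is
 * split, so at most 1096 bytes (array_size - masked_prod) go to the
 * tail of data->out and the remainder continues at offset 0. If the
 * first copy_from_iter() comes up short, only the short count is
 * published to out_prod rather than leaving a gap in the stream.
 */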

int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
                          size_t len)
{
        struct sock_mapping *map;
        int sent, tot_sent = 0;
        int count = 0, flags;

        flags = msg->msg_flags;
        if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        mutex_lock(&map->active.out_mutex);
        if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
                mutex_unlock(&map->active.out_mutex);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }
        if (len > INT_MAX)
                len = INT_MAX;

again:
        count++;
        sent = __write_ring(map->active.ring,
                            &map->active.data, &msg->msg_iter,
                            len);
        if (sent > 0) {
                len -= sent;
                tot_sent += sent;
                notify_remote_via_irq(map->active.irq);
        }
        if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
                goto again;
        if (sent < 0)
                tot_sent = sent;

        mutex_unlock(&map->active.out_mutex);
        pvcalls_exit_sock(sock);
        return tot_sent;
}

static int __read_ring(struct pvcalls_data_intf *intf,
                       struct pvcalls_data *data,
                       struct iov_iter *msg_iter,
                       size_t len, int flags)
{
        RING_IDX cons, prod, size, masked_prod, masked_cons;
        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
        int32_t error;

        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        /* get pointers before reading from the ring */
        virt_rmb();

        size = pvcalls_queued(prod, cons, array_size);
        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        if (size == 0)
                return error ?: size;

        if (len > size)
                len = size;

        if (masked_prod > masked_cons) {
                len = copy_to_iter(data->in + masked_cons, len, msg_iter);
        } else {
                if (len > (array_size - masked_cons)) {
                        int ret = copy_to_iter(data->in + masked_cons,
                                     array_size - masked_cons, msg_iter);
                        if (ret != array_size - masked_cons) {
                                len = ret;
                                goto out;
                        }
                        len = ret + copy_to_iter(data->in, len - ret, msg_iter);
                } else {
                        len = copy_to_iter(data->in + masked_cons, len, msg_iter);
                }
        }
out:
        /* read data from the ring before increasing the index */
        virt_mb();
        if (!(flags & MSG_PEEK))
                intf->in_cons += len;

        return len;
}
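
/*
 * Note the ordering in __read_ring(): a pending in_error (e.g. the
 * -EBADF set by pvcalls_front_release()) is only returned once the ring
 * is drained, so data already queued by the backend still reaches
 * recvmsg() before the error. With MSG_PEEK the bytes are copied but
 * in_cons is not advanced, so a later read sees them again.
 */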

int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
                     int flags)
{
        int ret;
        struct sock_mapping *map;

        if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);

        mutex_lock(&map->active.in_mutex);
        if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
                len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

        while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
                wait_event_interruptible(map->active.inflight_conn_req,
                                         pvcalls_front_read_todo(map));
        }
        ret = __read_ring(map->active.ring, &map->active.data,
                          &msg->msg_iter, len, flags);

        if (ret > 0)
                notify_remote_via_irq(map->active.irq);
        if (ret == 0)
                ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
        if (ret == -ENOTCONN)
                ret = 0;

        mutex_unlock(&map->active.in_mutex);
        pvcalls_exit_sock(sock);
        return ret;
}

int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
                return -EOPNOTSUPP;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        map->sock = sock;
        req->cmd = PVCALLS_BIND;
        req->u.bind.id = (uintptr_t)map;
        memcpy(req->u.bind.addr, addr, sizeof(*addr));
        req->u.bind.len = addr_len;

        init_waitqueue_head(&map->passive.inflight_accept_req);

        map->active_socket = false;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        map->passive.status = PVCALLS_STATUS_BIND;
        pvcalls_exit_sock(sock);
        return 0;
}

int pvcalls_front_listen(struct socket *sock, int backlog)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_BIND) {
                pvcalls_exit_sock(sock);
                return -EOPNOTSUPP;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_LISTEN;
        req->u.listen.id = (uintptr_t) map;
        req->u.listen.backlog = backlog;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        /* read req_id, then the content */
        smp_rmb();
        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

        map->passive.status = PVCALLS_STATUS_LISTEN;
        pvcalls_exit_sock(sock);
        return ret;
}

int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct sock_mapping *map2 = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret, evtchn, nonblock;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_LISTEN) {
                pvcalls_exit_sock(sock);
                return -EINVAL;
        }

        nonblock = flags & SOCK_NONBLOCK;
        /*
         * Backend only supports 1 inflight accept request, will return
         * errors for the others
         */
        if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                             (void *)&map->passive.flags)) {
                req_id = READ_ONCE(map->passive.inflight_req_id);
                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
                        map2 = map->passive.accept_map;
                        goto received;
                }
                if (nonblock) {
                        pvcalls_exit_sock(sock);
                        return -EAGAIN;
                }
                if (wait_event_interruptible(map->passive.inflight_accept_req,
                        !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                                          (void *)&map->passive.flags))) {
                        pvcalls_exit_sock(sock);
                        return -EINTR;
                }
        }

        map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
        if (map2 == NULL) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        ret = alloc_active_ring(map2);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                                (void *)&map->passive.flags);
                kfree(map2);
                pvcalls_exit_sock(sock);
                return ret;
        }
        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                free_active_ring(map2);
                kfree(map2);
                pvcalls_exit_sock(sock);
                return ret;
        }

        ret = create_active(map2, &evtchn);
        if (ret < 0) {
                free_active_ring(map2);
                kfree(map2);
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        list_add_tail(&map2->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_ACCEPT;
        req->u.accept.id = (uintptr_t) map;
        req->u.accept.ref = map2->active.ref;
        req->u.accept.id_new = (uintptr_t) map2;
        req->u.accept.evtchn = evtchn;
        map->passive.accept_map = map2;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);
        /* We could check if we have received a response before returning. */
        if (nonblock) {
                WRITE_ONCE(map->passive.inflight_req_id, req_id);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }

        if (wait_event_interruptible(bedata->inflight_req,
                READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
                pvcalls_exit_sock(sock);
                return -EINTR;
        }
        /* read req_id, then the content */
        smp_rmb();

received:
        map2->sock = newsock;
        newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL, &pvcalls_proto, false);
        if (!newsock->sk) {
                bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
                map->passive.inflight_req_id = PVCALLS_INVALID_ID;
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_front_free_map(bedata, map2);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        newsock->sk->sk_send_head = (void *)map2;

        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        map->passive.inflight_req_id = PVCALLS_INVALID_ID;

        clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
        wake_up(&map->passive.inflight_accept_req);

        pvcalls_exit_sock(sock);
        return ret;
}
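
/*
 * Only one accept may be inflight per listening socket (the
 * PVCALLS_FLAG_ACCEPT_INFLIGHT bit). A nonblocking accept that has
 * already queued a request parks its req_id in passive.inflight_req_id
 * and returns -EAGAIN; a later call (or pvcalls_front_poll_passive())
 * finds the response through that id and takes the "received" path.
 */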

static __poll_t pvcalls_front_poll_passive(struct file *file,
                                               struct pvcalls_bedata *bedata,
                                               struct sock_mapping *map,
                                               poll_table *wait)
{
        int notify, req_id, ret;
        struct xen_pvcalls_request *req;

        if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                     (void *)&map->passive.flags)) {
                uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
                        return EPOLLIN | EPOLLRDNORM;

                poll_wait(file, &map->passive.inflight_accept_req, wait);
                return 0;
        }

        if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
                               (void *)&map->passive.flags))
                return EPOLLIN | EPOLLRDNORM;

        /*
         * First check RET, then INFLIGHT. No barriers necessary to
         * ensure execution ordering because of the conditional
         * instructions creating control dependencies.
         */

        if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                             (void *)&map->passive.flags)) {
                poll_wait(file, &bedata->inflight_req, wait);
                return 0;
        }

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                return ret;
        }
        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_POLL;
        req->u.poll.id = (uintptr_t) map;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        poll_wait(file, &bedata->inflight_req, wait);
        return 0;
}

static __poll_t pvcalls_front_poll_active(struct file *file,
                                              struct pvcalls_bedata *bedata,
                                              struct sock_mapping *map,
                                              poll_table *wait)
{
        __poll_t mask = 0;
        int32_t in_error, out_error;
        struct pvcalls_data_intf *intf = map->active.ring;

        out_error = intf->out_error;
        in_error = intf->in_error;

        poll_wait(file, &map->active.inflight_conn_req, wait);
        if (pvcalls_front_write_todo(map))
                mask |= EPOLLOUT | EPOLLWRNORM;
        if (pvcalls_front_read_todo(map))
                mask |= EPOLLIN | EPOLLRDNORM;
        if (in_error != 0 || out_error != 0)
                mask |= EPOLLERR;

        return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
                               poll_table *wait)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        __poll_t ret;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return EPOLLNVAL;
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->active_socket)
                ret = pvcalls_front_poll_active(file, bedata, map, wait);
        else
                ret = pvcalls_front_poll_passive(file, bedata, map, wait);
        pvcalls_exit_sock(sock);
        return ret;
}

int pvcalls_front_release(struct socket *sock)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        int req_id, notify, ret;
        struct xen_pvcalls_request *req;

        if (sock->sk == NULL)
                return 0;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map)) {
                if (PTR_ERR(map) == -ENOTCONN)
                        return -EIO;
                else
                        return 0;
        }
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        sock->sk->sk_send_head = NULL;

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_RELEASE;
        req->u.release.id = (uintptr_t)map;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);

        wait_event(bedata->inflight_req,
                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

        if (map->active_socket) {
                /*
                 * Set in_error and wake up inflight_conn_req to force
                 * recvmsg waiters to exit.
                 */
                map->active.ring->in_error = -EBADF;
                wake_up_interruptible(&map->active.inflight_conn_req);

                /*
                 * We need to make sure that sendmsg/recvmsg on this socket have
                 * not started before we've cleared sk_send_head here. The
                 * easiest way to guarantee this is to see that no pvcalls
                 * (other than us) is in progress on this socket.
                 */
                while (atomic_read(&map->refcount) > 1)
                        cpu_relax();

                pvcalls_front_free_map(bedata, map);
        } else {
                wake_up(&bedata->inflight_req);
                wake_up(&map->passive.inflight_accept_req);

                while (atomic_read(&map->refcount) > 1)
                        cpu_relax();

                spin_lock(&bedata->socket_lock);
                list_del(&map->list);
                spin_unlock(&bedata->socket_lock);
                if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID &&
                        READ_ONCE(map->passive.inflight_req_id) != 0) {
                        pvcalls_front_free_map(bedata,
                                               map->passive.accept_map);
                }
                kfree(map);
        }
        WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

        pvcalls_exit();
        return 0;
}

static const struct xenbus_device_id pvcalls_front_ids[] = {
        { "pvcalls" },
        { "" }
};

static int pvcalls_front_remove(struct xenbus_device *dev)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map = NULL, *n;

        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
        dev_set_drvdata(&dev->dev, NULL);
        pvcalls_front_dev = NULL;
        if (bedata->irq >= 0)
                unbind_from_irqhandler(bedata->irq, dev);

        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
                map->sock->sk->sk_send_head = NULL;
                if (map->active_socket) {
                        map->active.ring->in_error = -EBADF;
                        wake_up_interruptible(&map->active.inflight_conn_req);
                }
        }

        smp_mb();
        while (atomic_read(&pvcalls_refcount) > 0)
                cpu_relax();
        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
                if (map->active_socket) {
                        /* No need to lock, refcount is 0 */
                        pvcalls_front_free_map(bedata, map);
                } else {
                        list_del(&map->list);
                        kfree(map);
                }
        }
        if (bedata->ref != -1)
                gnttab_end_foreign_access(bedata->ref, 0, 0);
        kfree(bedata->ring.sring);
        kfree(bedata);
        xenbus_switch_state(dev, XenbusStateClosed);
        return 0;
}

static int pvcalls_front_probe(struct xenbus_device *dev,
                          const struct xenbus_device_id *id)
{
        int ret = -ENOMEM, evtchn, i;
        unsigned int max_page_order, function_calls, len;
        char *versions;
        grant_ref_t gref_head = 0;
        struct xenbus_transaction xbt;
        struct pvcalls_bedata *bedata = NULL;
        struct xen_pvcalls_sring *sring;

        if (pvcalls_front_dev != NULL) {
                dev_err(&dev->dev, "only one PV Calls connection supported\n");
                return -EINVAL;
        }

        versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
        if (IS_ERR(versions))
                return PTR_ERR(versions);
        if (!len)
                return -EINVAL;
        if (strcmp(versions, "1")) {
                kfree(versions);
                return -EINVAL;
        }
        kfree(versions);
        max_page_order = xenbus_read_unsigned(dev->otherend,
                                              "max-page-order", 0);
        if (max_page_order < PVCALLS_RING_ORDER)
                return -ENODEV;
        function_calls = xenbus_read_unsigned(dev->otherend,
                                              "function-calls", 0);
        /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
        if (function_calls != 1)
                return -ENODEV;
        pr_info("%s max-page-order is %u\n", __func__, max_page_order);

        bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
        if (!bedata)
                return -ENOMEM;

        dev_set_drvdata(&dev->dev, bedata);
        pvcalls_front_dev = dev;
        init_waitqueue_head(&bedata->inflight_req);
        INIT_LIST_HEAD(&bedata->socket_mappings);
        spin_lock_init(&bedata->socket_lock);
        bedata->irq = -1;
        bedata->ref = -1;

        for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
                bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

        sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
                                                             __GFP_ZERO);
        if (!sring)
                goto error;
        SHARED_RING_INIT(sring);
        FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

        ret = xenbus_alloc_evtchn(dev, &evtchn);
        if (ret)
                goto error;

        bedata->irq = bind_evtchn_to_irqhandler(evtchn,
                                                pvcalls_front_event_handler,
                                                0, "pvcalls-frontend", dev);
        if (bedata->irq < 0) {
                ret = bedata->irq;
                goto error;
        }

        ret = gnttab_alloc_grant_references(1, &gref_head);
        if (ret < 0)
                goto error;
        ret = gnttab_claim_grant_reference(&gref_head);
        if (ret < 0)
                goto error;
        bedata->ref = ret;
        gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
                                        virt_to_gfn((void *)sring), 0);

 again:
        ret = xenbus_transaction_start(&xbt);
        if (ret) {
                xenbus_dev_fatal(dev, ret, "starting transaction");
                goto error;
        }
        ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
        if (ret)
                goto error_xenbus;
        ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
        if (ret)
                goto error_xenbus;
        ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
                            evtchn);
        if (ret)
                goto error_xenbus;
        ret = xenbus_transaction_end(xbt, 0);
        if (ret) {
                if (ret == -EAGAIN)
                        goto again;
                xenbus_dev_fatal(dev, ret, "completing transaction");
                goto error;
        }
        xenbus_switch_state(dev, XenbusStateInitialised);

        return 0;

 error_xenbus:
        xenbus_transaction_end(xbt, 1);
        xenbus_dev_fatal(dev, ret, "writing xenstore");
 error:
        pvcalls_front_remove(dev);
        return ret;
}
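
/*
 * Frontend xenstore layout resulting from the transaction above:
 *
 *      <nodename>/version  = "1"
 *      <nodename>/ring-ref = <grant ref of the command ring>
 *      <nodename>/port     = <event channel port>
 *
 * after which the device switches to XenbusStateInitialised and waits
 * for the backend to connect.
 */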

static void pvcalls_front_changed(struct xenbus_device *dev,
                            enum xenbus_state backend_state)
{
        switch (backend_state) {
        case XenbusStateReconfiguring:
        case XenbusStateReconfigured:
        case XenbusStateInitialising:
        case XenbusStateInitialised:
        case XenbusStateUnknown:
                break;

        case XenbusStateInitWait:
                break;

        case XenbusStateConnected:
                xenbus_switch_state(dev, XenbusStateConnected);
                break;

        case XenbusStateClosed:
                if (dev->state == XenbusStateClosed)
                        break;
                /* Missed the backend's CLOSING state */
                /* fall through */
        case XenbusStateClosing:
                xenbus_frontend_closed(dev);
                break;
        }
}

static struct xenbus_driver pvcalls_front_driver = {
        .ids = pvcalls_front_ids,
        .probe = pvcalls_front_probe,
        .remove = pvcalls_front_remove,
        .otherend_changed = pvcalls_front_changed,
};

static int __init pvcalls_frontend_init(void)
{
        if (!xen_domain())
                return -ENODEV;

        pr_info("Initialising Xen pvcalls frontend driver\n");

        return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");