linux/drivers/xen/pvcalls-front.c
<<
>>
Prefs
   1/*
   2 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
   3 *
   4 * This program is free software; you can redistribute it and/or modify
   5 * it under the terms of the GNU General Public License as published by
   6 * the Free Software Foundation; either version 2 of the License, or
   7 * (at your option) any later version.
   8 *
   9 * This program is distributed in the hope that it will be useful,
  10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  12 * GNU General Public License for more details.
  13 */
  14
  15#include <linux/module.h>
  16#include <linux/net.h>
  17#include <linux/socket.h>
  18
  19#include <net/sock.h>
  20
  21#include <xen/events.h>
  22#include <xen/grant_table.h>
  23#include <xen/xen.h>
  24#include <xen/xenbus.h>
  25#include <xen/interface/io/pvcalls.h>
  26
  27#include "pvcalls-front.h"
  28
/* req_id value marking a bedata->rsp[] slot that holds no unconsumed reply. */
#define PVCALLS_INVALID_ID UINT_MAX
/* Page order used when allocating each active socket's data ring. */
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
/* Number of entries in the bedata->rsp[] response array. */
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
/* Upper bound on the number of ring-write attempts made by one sendmsg call. */
#define PVCALLS_FRONT_MAX_SPIN 5000
  33
/*
 * State of the (single) frontend/backend connection. It is retrieved with
 * dev_get_drvdata() on the pvcalls xenbus device.
 */
struct pvcalls_bedata {
        /* Command ring on which requests are queued and responses arrive. */
        struct xen_pvcalls_front_ring ring;
        /* presumably the grant reference of the command ring page; unused in this part of the file */
        grant_ref_t ref;
        /* IRQ used to notify the backend after pushing command requests. */
        int irq;

        /* Every struct sock_mapping is linked here; modified under socket_lock. */
        struct list_head socket_mappings;
        /* Also held while a request is being placed on the command ring. */
        spinlock_t socket_lock;

        /* Woken by the event handler after it has consumed responses. */
        wait_queue_head_t inflight_req;
        /*
         * Local copies of the backend responses, indexed by request id.
         * A slot whose req_id is PVCALLS_INVALID_ID is free.
         */
        struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
  45/* Only one front/back connection supported. */
  46static struct xenbus_device *pvcalls_front_dev;
  47static atomic_t pvcalls_refcount;
  48
  49/* first increment refcount, then proceed */
  50#define pvcalls_enter() {               \
  51        atomic_inc(&pvcalls_refcount);      \
  52}
  53
  54/* first complete other operations, then decrement refcount */
  55#define pvcalls_exit() {                \
  56        atomic_dec(&pvcalls_refcount);      \
  57}
  58
/*
 * Per-socket frontend state. A pointer to it is stored in
 * sock->sk->sk_send_head and is also sent to the backend as the socket id.
 * The union holds the state of either an active (connected/accepted) socket
 * or a passive (bound/listening) one, selected by active_socket.
 */
struct sock_mapping {
        /* Set to true by create_active(), set to false by bind. */
        bool active_socket;
        /* Link in bedata->socket_mappings. */
        struct list_head list;
        struct socket *sock;
        /* Raised by pvcalls_enter_sock(), dropped by pvcalls_exit_sock(). */
        atomic_t refcount;
        union {
                struct {
                        /* IRQ bound to this socket's event channel. */
                        int irq;
                        /* Grant reference of the page pointed to by ring. */
                        grant_ref_t ref;
                        /* Shared indexes/error page of the data ring. */
                        struct pvcalls_data_intf *ring;
                        /* in/out pointers into the data ring buffer. */
                        struct pvcalls_data data;
                        /* Held for the duration of a recvmsg call. */
                        struct mutex in_mutex;
                        /* Held for the duration of a sendmsg call. */
                        struct mutex out_mutex;

                        /* Woken by pvcalls_front_conn_handler(). */
                        wait_queue_head_t inflight_conn_req;
                } active;
                struct {
                /*
                 * Socket status, needs to be 64-bit aligned due to the
                 * test_and_* functions which have this requirement on arm64.
                 */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
                        uint8_t status __attribute__((aligned(8)));
                /*
                 * Internal state-machine flags.
                 * Only one accept operation can be inflight for a socket.
                 * Only one poll operation can be inflight for a given socket.
                 * flags needs to be 64-bit aligned due to the test_and_*
                 * functions which have this requirement on arm64.
                 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
                        uint8_t flags __attribute__((aligned(8)));
                        /*
                         * Request id of a non-blocking accept whose response
                         * has not been consumed yet, PVCALLS_INVALID_ID after
                         * it has been consumed.
                         */
                        uint32_t inflight_req_id;
                        /* Mapping allocated for the accept currently inflight. */
                        struct sock_mapping *accept_map;
                        /* Woken when an inflight accept has been completed. */
                        wait_queue_head_t inflight_accept_req;
                } passive;
        };
};
 101
 102static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
 103{
 104        struct sock_mapping *map;
 105
 106        if (!pvcalls_front_dev ||
 107                dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
 108                return ERR_PTR(-ENOTCONN);
 109
 110        map = (struct sock_mapping *)sock->sk->sk_send_head;
 111        if (map == NULL)
 112                return ERR_PTR(-ENOTSOCK);
 113
 114        pvcalls_enter();
 115        atomic_inc(&map->refcount);
 116        return map;
 117}
 118
 119static inline void pvcalls_exit_sock(struct socket *sock)
 120{
 121        struct sock_mapping *map;
 122
 123        map = (struct sock_mapping *)sock->sk->sk_send_head;
 124        atomic_dec(&map->refcount);
 125        pvcalls_exit();
 126}
 127
 128static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
 129{
 130        *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
 131        if (RING_FULL(&bedata->ring) ||
 132            bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
 133                return -EAGAIN;
 134        return 0;
 135}
 136
 137static bool pvcalls_front_write_todo(struct sock_mapping *map)
 138{
 139        struct pvcalls_data_intf *intf = map->active.ring;
 140        RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
 141        int32_t error;
 142
 143        error = intf->out_error;
 144        if (error == -ENOTCONN)
 145                return false;
 146        if (error != 0)
 147                return true;
 148
 149        cons = intf->out_cons;
 150        prod = intf->out_prod;
 151        return !!(size - pvcalls_queued(prod, cons, size));
 152}
 153
 154static bool pvcalls_front_read_todo(struct sock_mapping *map)
 155{
 156        struct pvcalls_data_intf *intf = map->active.ring;
 157        RING_IDX cons, prod;
 158        int32_t error;
 159
 160        cons = intf->in_cons;
 161        prod = intf->in_prod;
 162        error = intf->in_error;
 163        return (error != 0 ||
 164                pvcalls_queued(prod, cons,
 165                               XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
 166}
 167
/*
 * Interrupt handler that drains the command ring's responses.
 *
 * @dev_id is the xenbus device whose drvdata is the struct pvcalls_bedata.
 * POLL responses only flip the poll flags of the mapping they refer to; every
 * other response is copied into bedata->rsp[req_id] for the thread sleeping
 * on bedata->inflight_req. Always returns IRQ_HANDLED.
 *
 * NOTE(review): req_id is taken from the shared ring and indexes
 * bedata->rsp[] without a range check.
 */
static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
        struct xenbus_device *dev = dev_id;
        struct pvcalls_bedata *bedata;
        struct xen_pvcalls_response *rsp;
        uint8_t *src, *dst;
        int req_id = 0, more = 0, done = 0;

        if (dev == NULL)
                return IRQ_HANDLED;

        pvcalls_enter();
        bedata = dev_get_drvdata(&dev->dev);
        if (bedata == NULL) {
                pvcalls_exit();
                return IRQ_HANDLED;
        }

again:
        while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
                rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

                req_id = rsp->req_id;
                if (rsp->cmd == PVCALLS_POLL) {
                        /* The id in the response is turned back into a mapping pointer. */
                        struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
                                                   rsp->u.poll.id;

                        clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
                                  (void *)&map->passive.flags);
                        /*
                         * clear INFLIGHT, then set RET. It pairs with
                         * the checks at the beginning of
                         * pvcalls_front_poll_passive.
                         */
                        smp_wmb();
                        set_bit(PVCALLS_FLAG_POLL_RET,
                                (void *)&map->passive.flags);
                } else {
                        /* Copy everything that follows the leading req_id field. */
                        dst = (uint8_t *)&bedata->rsp[req_id] +
                              sizeof(rsp->req_id);
                        src = (uint8_t *)rsp + sizeof(rsp->req_id);
                        memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
                        /*
                         * First copy the rest of the data, then req_id. It is
                         * paired with the barrier when accessing bedata->rsp.
                         */
                        smp_wmb();
                        /* Writing req_id is what makes the reply visible to waiters. */
                        bedata->rsp[req_id].req_id = req_id;
                }

                done = 1;
                bedata->ring.rsp_cons++;
        }

        /* Re-check so that a response queued while draining is not missed. */
        RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
        if (more)
                goto again;
        if (done)
                wake_up(&bedata->inflight_req);
        pvcalls_exit();
        return IRQ_HANDLED;
}
 230
 231static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
 232                                   struct sock_mapping *map)
 233{
 234        int i;
 235
 236        unbind_from_irqhandler(map->active.irq, map);
 237
 238        spin_lock(&bedata->socket_lock);
 239        if (!list_empty(&map->list))
 240                list_del_init(&map->list);
 241        spin_unlock(&bedata->socket_lock);
 242
 243        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
 244                gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
 245        gnttab_end_foreign_access(map->active.ref, 0, 0);
 246        free_page((unsigned long)map->active.ring);
 247
 248        kfree(map);
 249}
 250
 251static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
 252{
 253        struct sock_mapping *map = sock_map;
 254
 255        if (map == NULL)
 256                return IRQ_HANDLED;
 257
 258        wake_up_interruptible(&map->active.inflight_conn_req);
 259
 260        return IRQ_HANDLED;
 261}
 262
 263int pvcalls_front_socket(struct socket *sock)
 264{
 265        struct pvcalls_bedata *bedata;
 266        struct sock_mapping *map = NULL;
 267        struct xen_pvcalls_request *req;
 268        int notify, req_id, ret;
 269
 270        /*
 271         * PVCalls only supports domain AF_INET,
 272         * type SOCK_STREAM and protocol 0 sockets for now.
 273         *
 274         * Check socket type here, AF_INET and protocol checks are done
 275         * by the caller.
 276         */
 277        if (sock->type != SOCK_STREAM)
 278                return -EOPNOTSUPP;
 279
 280        pvcalls_enter();
 281        if (!pvcalls_front_dev) {
 282                pvcalls_exit();
 283                return -EACCES;
 284        }
 285        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 286
 287        map = kzalloc(sizeof(*map), GFP_KERNEL);
 288        if (map == NULL) {
 289                pvcalls_exit();
 290                return -ENOMEM;
 291        }
 292
 293        spin_lock(&bedata->socket_lock);
 294
 295        ret = get_request(bedata, &req_id);
 296        if (ret < 0) {
 297                kfree(map);
 298                spin_unlock(&bedata->socket_lock);
 299                pvcalls_exit();
 300                return ret;
 301        }
 302
 303        /*
 304         * sock->sk->sk_send_head is not used for ip sockets: reuse the
 305         * field to store a pointer to the struct sock_mapping
 306         * corresponding to the socket. This way, we can easily get the
 307         * struct sock_mapping from the struct socket.
 308         */
 309        sock->sk->sk_send_head = (void *)map;
 310        list_add_tail(&map->list, &bedata->socket_mappings);
 311
 312        req = RING_GET_REQUEST(&bedata->ring, req_id);
 313        req->req_id = req_id;
 314        req->cmd = PVCALLS_SOCKET;
 315        req->u.socket.id = (uintptr_t) map;
 316        req->u.socket.domain = AF_INET;
 317        req->u.socket.type = SOCK_STREAM;
 318        req->u.socket.protocol = IPPROTO_IP;
 319
 320        bedata->ring.req_prod_pvt++;
 321        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
 322        spin_unlock(&bedata->socket_lock);
 323        if (notify)
 324                notify_remote_via_irq(bedata->irq);
 325
 326        wait_event(bedata->inflight_req,
 327                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
 328
 329        /* read req_id, then the content */
 330        smp_rmb();
 331        ret = bedata->rsp[req_id].ret;
 332        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
 333
 334        pvcalls_exit();
 335        return ret;
 336}
 337
 338static int create_active(struct sock_mapping *map, int *evtchn)
 339{
 340        void *bytes;
 341        int ret = -ENOMEM, irq = -1, i;
 342
 343        *evtchn = -1;
 344        init_waitqueue_head(&map->active.inflight_conn_req);
 345
 346        map->active.ring = (struct pvcalls_data_intf *)
 347                __get_free_page(GFP_KERNEL | __GFP_ZERO);
 348        if (map->active.ring == NULL)
 349                goto out_error;
 350        map->active.ring->ring_order = PVCALLS_RING_ORDER;
 351        bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
 352                                        PVCALLS_RING_ORDER);
 353        if (bytes == NULL)
 354                goto out_error;
 355        for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
 356                map->active.ring->ref[i] = gnttab_grant_foreign_access(
 357                        pvcalls_front_dev->otherend_id,
 358                        pfn_to_gfn(virt_to_pfn(bytes) + i), 0);
 359
 360        map->active.ref = gnttab_grant_foreign_access(
 361                pvcalls_front_dev->otherend_id,
 362                pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);
 363
 364        map->active.data.in = bytes;
 365        map->active.data.out = bytes +
 366                XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
 367
 368        ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
 369        if (ret)
 370                goto out_error;
 371        irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
 372                                        0, "pvcalls-frontend", map);
 373        if (irq < 0) {
 374                ret = irq;
 375                goto out_error;
 376        }
 377
 378        map->active.irq = irq;
 379        map->active_socket = true;
 380        mutex_init(&map->active.in_mutex);
 381        mutex_init(&map->active.out_mutex);
 382
 383        return 0;
 384
 385out_error:
 386        if (*evtchn >= 0)
 387                xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
 388        kfree(map->active.data.in);
 389        kfree(map->active.ring);
 390        return ret;
 391}
 392
 393int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
 394                                int addr_len, int flags)
 395{
 396        struct pvcalls_bedata *bedata;
 397        struct sock_mapping *map = NULL;
 398        struct xen_pvcalls_request *req;
 399        int notify, req_id, ret, evtchn;
 400
 401        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
 402                return -EOPNOTSUPP;
 403
 404        map = pvcalls_enter_sock(sock);
 405        if (IS_ERR(map))
 406                return PTR_ERR(map);
 407
 408        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 409
 410        spin_lock(&bedata->socket_lock);
 411        ret = get_request(bedata, &req_id);
 412        if (ret < 0) {
 413                spin_unlock(&bedata->socket_lock);
 414                pvcalls_exit_sock(sock);
 415                return ret;
 416        }
 417        ret = create_active(map, &evtchn);
 418        if (ret < 0) {
 419                spin_unlock(&bedata->socket_lock);
 420                pvcalls_exit_sock(sock);
 421                return ret;
 422        }
 423
 424        req = RING_GET_REQUEST(&bedata->ring, req_id);
 425        req->req_id = req_id;
 426        req->cmd = PVCALLS_CONNECT;
 427        req->u.connect.id = (uintptr_t)map;
 428        req->u.connect.len = addr_len;
 429        req->u.connect.flags = flags;
 430        req->u.connect.ref = map->active.ref;
 431        req->u.connect.evtchn = evtchn;
 432        memcpy(req->u.connect.addr, addr, sizeof(*addr));
 433
 434        map->sock = sock;
 435
 436        bedata->ring.req_prod_pvt++;
 437        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
 438        spin_unlock(&bedata->socket_lock);
 439
 440        if (notify)
 441                notify_remote_via_irq(bedata->irq);
 442
 443        wait_event(bedata->inflight_req,
 444                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
 445
 446        /* read req_id, then the content */
 447        smp_rmb();
 448        ret = bedata->rsp[req_id].ret;
 449        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
 450        pvcalls_exit_sock(sock);
 451        return ret;
 452}
 453
 454static int __write_ring(struct pvcalls_data_intf *intf,
 455                        struct pvcalls_data *data,
 456                        struct iov_iter *msg_iter,
 457                        int len)
 458{
 459        RING_IDX cons, prod, size, masked_prod, masked_cons;
 460        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
 461        int32_t error;
 462
 463        error = intf->out_error;
 464        if (error < 0)
 465                return error;
 466        cons = intf->out_cons;
 467        prod = intf->out_prod;
 468        /* read indexes before continuing */
 469        virt_mb();
 470
 471        size = pvcalls_queued(prod, cons, array_size);
 472        if (size >= array_size)
 473                return -EINVAL;
 474        if (len > array_size - size)
 475                len = array_size - size;
 476
 477        masked_prod = pvcalls_mask(prod, array_size);
 478        masked_cons = pvcalls_mask(cons, array_size);
 479
 480        if (masked_prod < masked_cons) {
 481                len = copy_from_iter(data->out + masked_prod, len, msg_iter);
 482        } else {
 483                if (len > array_size - masked_prod) {
 484                        int ret = copy_from_iter(data->out + masked_prod,
 485                                       array_size - masked_prod, msg_iter);
 486                        if (ret != array_size - masked_prod) {
 487                                len = ret;
 488                                goto out;
 489                        }
 490                        len = ret + copy_from_iter(data->out, len - ret, msg_iter);
 491                } else {
 492                        len = copy_from_iter(data->out + masked_prod, len, msg_iter);
 493                }
 494        }
 495out:
 496        /* write to ring before updating pointer */
 497        virt_wmb();
 498        intf->out_prod += len;
 499
 500        return len;
 501}
 502
 503int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
 504                          size_t len)
 505{
 506        struct pvcalls_bedata *bedata;
 507        struct sock_mapping *map;
 508        int sent, tot_sent = 0;
 509        int count = 0, flags;
 510
 511        flags = msg->msg_flags;
 512        if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
 513                return -EOPNOTSUPP;
 514
 515        map = pvcalls_enter_sock(sock);
 516        if (IS_ERR(map))
 517                return PTR_ERR(map);
 518        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 519
 520        mutex_lock(&map->active.out_mutex);
 521        if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
 522                mutex_unlock(&map->active.out_mutex);
 523                pvcalls_exit_sock(sock);
 524                return -EAGAIN;
 525        }
 526        if (len > INT_MAX)
 527                len = INT_MAX;
 528
 529again:
 530        count++;
 531        sent = __write_ring(map->active.ring,
 532                            &map->active.data, &msg->msg_iter,
 533                            len);
 534        if (sent > 0) {
 535                len -= sent;
 536                tot_sent += sent;
 537                notify_remote_via_irq(map->active.irq);
 538        }
 539        if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
 540                goto again;
 541        if (sent < 0)
 542                tot_sent = sent;
 543
 544        mutex_unlock(&map->active.out_mutex);
 545        pvcalls_exit_sock(sock);
 546        return tot_sent;
 547}
 548
 549static int __read_ring(struct pvcalls_data_intf *intf,
 550                       struct pvcalls_data *data,
 551                       struct iov_iter *msg_iter,
 552                       size_t len, int flags)
 553{
 554        RING_IDX cons, prod, size, masked_prod, masked_cons;
 555        RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
 556        int32_t error;
 557
 558        cons = intf->in_cons;
 559        prod = intf->in_prod;
 560        error = intf->in_error;
 561        /* get pointers before reading from the ring */
 562        virt_rmb();
 563        if (error < 0)
 564                return error;
 565
 566        size = pvcalls_queued(prod, cons, array_size);
 567        masked_prod = pvcalls_mask(prod, array_size);
 568        masked_cons = pvcalls_mask(cons, array_size);
 569
 570        if (size == 0)
 571                return 0;
 572
 573        if (len > size)
 574                len = size;
 575
 576        if (masked_prod > masked_cons) {
 577                len = copy_to_iter(data->in + masked_cons, len, msg_iter);
 578        } else {
 579                if (len > (array_size - masked_cons)) {
 580                        int ret = copy_to_iter(data->in + masked_cons,
 581                                     array_size - masked_cons, msg_iter);
 582                        if (ret != array_size - masked_cons) {
 583                                len = ret;
 584                                goto out;
 585                        }
 586                        len = ret + copy_to_iter(data->in, len - ret, msg_iter);
 587                } else {
 588                        len = copy_to_iter(data->in + masked_cons, len, msg_iter);
 589                }
 590        }
 591out:
 592        /* read data from the ring before increasing the index */
 593        virt_mb();
 594        if (!(flags & MSG_PEEK))
 595                intf->in_cons += len;
 596
 597        return len;
 598}
 599
 600int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 601                     int flags)
 602{
 603        struct pvcalls_bedata *bedata;
 604        int ret;
 605        struct sock_mapping *map;
 606
 607        if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
 608                return -EOPNOTSUPP;
 609
 610        map = pvcalls_enter_sock(sock);
 611        if (IS_ERR(map))
 612                return PTR_ERR(map);
 613        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 614
 615        mutex_lock(&map->active.in_mutex);
 616        if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
 617                len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
 618
 619        while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
 620                wait_event_interruptible(map->active.inflight_conn_req,
 621                                         pvcalls_front_read_todo(map));
 622        }
 623        ret = __read_ring(map->active.ring, &map->active.data,
 624                          &msg->msg_iter, len, flags);
 625
 626        if (ret > 0)
 627                notify_remote_via_irq(map->active.irq);
 628        if (ret == 0)
 629                ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
 630        if (ret == -ENOTCONN)
 631                ret = 0;
 632
 633        mutex_unlock(&map->active.in_mutex);
 634        pvcalls_exit_sock(sock);
 635        return ret;
 636}
 637
 638int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
 639{
 640        struct pvcalls_bedata *bedata;
 641        struct sock_mapping *map = NULL;
 642        struct xen_pvcalls_request *req;
 643        int notify, req_id, ret;
 644
 645        if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
 646                return -EOPNOTSUPP;
 647
 648        map = pvcalls_enter_sock(sock);
 649        if (IS_ERR(map))
 650                return PTR_ERR(map);
 651        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 652
 653        spin_lock(&bedata->socket_lock);
 654        ret = get_request(bedata, &req_id);
 655        if (ret < 0) {
 656                spin_unlock(&bedata->socket_lock);
 657                pvcalls_exit_sock(sock);
 658                return ret;
 659        }
 660        req = RING_GET_REQUEST(&bedata->ring, req_id);
 661        req->req_id = req_id;
 662        map->sock = sock;
 663        req->cmd = PVCALLS_BIND;
 664        req->u.bind.id = (uintptr_t)map;
 665        memcpy(req->u.bind.addr, addr, sizeof(*addr));
 666        req->u.bind.len = addr_len;
 667
 668        init_waitqueue_head(&map->passive.inflight_accept_req);
 669
 670        map->active_socket = false;
 671
 672        bedata->ring.req_prod_pvt++;
 673        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
 674        spin_unlock(&bedata->socket_lock);
 675        if (notify)
 676                notify_remote_via_irq(bedata->irq);
 677
 678        wait_event(bedata->inflight_req,
 679                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
 680
 681        /* read req_id, then the content */
 682        smp_rmb();
 683        ret = bedata->rsp[req_id].ret;
 684        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
 685
 686        map->passive.status = PVCALLS_STATUS_BIND;
 687        pvcalls_exit_sock(sock);
 688        return 0;
 689}
 690
 691int pvcalls_front_listen(struct socket *sock, int backlog)
 692{
 693        struct pvcalls_bedata *bedata;
 694        struct sock_mapping *map;
 695        struct xen_pvcalls_request *req;
 696        int notify, req_id, ret;
 697
 698        map = pvcalls_enter_sock(sock);
 699        if (IS_ERR(map))
 700                return PTR_ERR(map);
 701        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 702
 703        if (map->passive.status != PVCALLS_STATUS_BIND) {
 704                pvcalls_exit_sock(sock);
 705                return -EOPNOTSUPP;
 706        }
 707
 708        spin_lock(&bedata->socket_lock);
 709        ret = get_request(bedata, &req_id);
 710        if (ret < 0) {
 711                spin_unlock(&bedata->socket_lock);
 712                pvcalls_exit_sock(sock);
 713                return ret;
 714        }
 715        req = RING_GET_REQUEST(&bedata->ring, req_id);
 716        req->req_id = req_id;
 717        req->cmd = PVCALLS_LISTEN;
 718        req->u.listen.id = (uintptr_t) map;
 719        req->u.listen.backlog = backlog;
 720
 721        bedata->ring.req_prod_pvt++;
 722        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
 723        spin_unlock(&bedata->socket_lock);
 724        if (notify)
 725                notify_remote_via_irq(bedata->irq);
 726
 727        wait_event(bedata->inflight_req,
 728                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
 729
 730        /* read req_id, then the content */
 731        smp_rmb();
 732        ret = bedata->rsp[req_id].ret;
 733        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
 734
 735        map->passive.status = PVCALLS_STATUS_LISTEN;
 736        pvcalls_exit_sock(sock);
 737        return ret;
 738}
 739
/*
 * Accept a connection on the listening socket @sock into @newsock.
 *
 * Only one accept request per listening socket may be outstanding, tracked
 * with PVCALLS_FLAG_ACCEPT_INFLIGHT. A non-blocking call (SOCK_NONBLOCK in
 * @flags) queues the request, records its id in passive.inflight_req_id and
 * returns -EAGAIN; a later call picks up the response through the "received"
 * path below.
 *
 * Returns the backend's return value, -EINVAL when the socket is not
 * listening, -EAGAIN for a non-blocking call that has no response yet,
 * -EINTR when a wait is interrupted, -ENOMEM on allocation failure, or the
 * error from pvcalls_enter_sock(), get_request() or create_active().
 *
 * NOTE(review): the GFP_ATOMIC allocation of map2 is followed by
 * create_active(), which allocates with GFP_KERNEL, all while
 * bedata->socket_lock (a spinlock) is held.
 * NOTE(review): when the wait on bedata->inflight_req is interrupted,
 * PVCALLS_FLAG_ACCEPT_INFLIGHT stays set and inflight_req_id is not
 * recorded - confirm this is intended.
 */
int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
        struct pvcalls_bedata *bedata;
        struct sock_mapping *map;
        struct sock_mapping *map2 = NULL;
        struct xen_pvcalls_request *req;
        int notify, req_id, ret, evtchn, nonblock;

        map = pvcalls_enter_sock(sock);
        if (IS_ERR(map))
                return PTR_ERR(map);
        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

        if (map->passive.status != PVCALLS_STATUS_LISTEN) {
                pvcalls_exit_sock(sock);
                return -EINVAL;
        }

        nonblock = flags & SOCK_NONBLOCK;
        /*
         * Backend only supports 1 inflight accept request, will return
         * errors for the others
         */
        if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                             (void *)&map->passive.flags)) {
                /* An accept is already inflight: see if its response is in. */
                req_id = READ_ONCE(map->passive.inflight_req_id);
                if (req_id != PVCALLS_INVALID_ID &&
                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
                        map2 = map->passive.accept_map;
                        goto received;
                }
                if (nonblock) {
                        pvcalls_exit_sock(sock);
                        return -EAGAIN;
                }
                /* Sleep until the flag is released and we managed to take it. */
                if (wait_event_interruptible(map->passive.inflight_accept_req,
                        !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                                          (void *)&map->passive.flags))) {
                        pvcalls_exit_sock(sock);
                        return -EINTR;
                }
        }

        /* From here on this call owns PVCALLS_FLAG_ACCEPT_INFLIGHT. */
        spin_lock(&bedata->socket_lock);
        ret = get_request(bedata, &req_id);
        if (ret < 0) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        /* Mapping for the socket that the accept will create. */
        map2 = kzalloc(sizeof(*map2), GFP_ATOMIC);
        if (map2 == NULL) {
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        ret = create_active(map2, &evtchn);
        if (ret < 0) {
                kfree(map2);
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                spin_unlock(&bedata->socket_lock);
                pvcalls_exit_sock(sock);
                return ret;
        }
        list_add_tail(&map2->list, &bedata->socket_mappings);

        req = RING_GET_REQUEST(&bedata->ring, req_id);
        req->req_id = req_id;
        req->cmd = PVCALLS_ACCEPT;
        req->u.accept.id = (uintptr_t) map;
        req->u.accept.ref = map2->active.ref;
        req->u.accept.id_new = (uintptr_t) map2;
        req->u.accept.evtchn = evtchn;
        /* Remembered so that a later call can finish this accept. */
        map->passive.accept_map = map2;

        bedata->ring.req_prod_pvt++;
        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
        spin_unlock(&bedata->socket_lock);
        if (notify)
                notify_remote_via_irq(bedata->irq);
        /* We could check if we have received a response before returning. */
        if (nonblock) {
                WRITE_ONCE(map->passive.inflight_req_id, req_id);
                pvcalls_exit_sock(sock);
                return -EAGAIN;
        }

        if (wait_event_interruptible(bedata->inflight_req,
                READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
                pvcalls_exit_sock(sock);
                return -EINTR;
        }
        /* read req_id, then the content */
        smp_rmb();

received:
        /* The response for req_id is available; map2 is the new socket's mapping. */
        map2->sock = newsock;
        newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL);
        if (!newsock->sk) {
                /* Release the response slot and the accept state, drop map2. */
                bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
                map->passive.inflight_req_id = PVCALLS_INVALID_ID;
                clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
                          (void *)&map->passive.flags);
                pvcalls_front_free_map(bedata, map2);
                pvcalls_exit_sock(sock);
                return -ENOMEM;
        }
        newsock->sk->sk_send_head = (void *)map2;

        ret = bedata->rsp[req_id].ret;
        bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
        map->passive.inflight_req_id = PVCALLS_INVALID_ID;

        /* Release the accept slot and let the next waiter try to take it. */
        clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
        wake_up(&map->passive.inflight_accept_req);

        pvcalls_exit_sock(sock);
        return ret;
}
 864
 865static __poll_t pvcalls_front_poll_passive(struct file *file,
 866                                               struct pvcalls_bedata *bedata,
 867                                               struct sock_mapping *map,
 868                                               poll_table *wait)
 869{
 870        int notify, req_id, ret;
 871        struct xen_pvcalls_request *req;
 872
 873        if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
 874                     (void *)&map->passive.flags)) {
 875                uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);
 876
 877                if (req_id != PVCALLS_INVALID_ID &&
 878                    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
 879                        return EPOLLIN | EPOLLRDNORM;
 880
 881                poll_wait(file, &map->passive.inflight_accept_req, wait);
 882                return 0;
 883        }
 884
 885        if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
 886                               (void *)&map->passive.flags))
 887                return EPOLLIN | EPOLLRDNORM;
 888
 889        /*
 890         * First check RET, then INFLIGHT. No barriers necessary to
 891         * ensure execution ordering because of the conditional
 892         * instructions creating control dependencies.
 893         */
 894
 895        if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
 896                             (void *)&map->passive.flags)) {
 897                poll_wait(file, &bedata->inflight_req, wait);
 898                return 0;
 899        }
 900
 901        spin_lock(&bedata->socket_lock);
 902        ret = get_request(bedata, &req_id);
 903        if (ret < 0) {
 904                spin_unlock(&bedata->socket_lock);
 905                return ret;
 906        }
 907        req = RING_GET_REQUEST(&bedata->ring, req_id);
 908        req->req_id = req_id;
 909        req->cmd = PVCALLS_POLL;
 910        req->u.poll.id = (uintptr_t) map;
 911
 912        bedata->ring.req_prod_pvt++;
 913        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
 914        spin_unlock(&bedata->socket_lock);
 915        if (notify)
 916                notify_remote_via_irq(bedata->irq);
 917
 918        poll_wait(file, &bedata->inflight_req, wait);
 919        return 0;
 920}
 921
 922static __poll_t pvcalls_front_poll_active(struct file *file,
 923                                              struct pvcalls_bedata *bedata,
 924                                              struct sock_mapping *map,
 925                                              poll_table *wait)
 926{
 927        __poll_t mask = 0;
 928        int32_t in_error, out_error;
 929        struct pvcalls_data_intf *intf = map->active.ring;
 930
 931        out_error = intf->out_error;
 932        in_error = intf->in_error;
 933
 934        poll_wait(file, &map->active.inflight_conn_req, wait);
 935        if (pvcalls_front_write_todo(map))
 936                mask |= EPOLLOUT | EPOLLWRNORM;
 937        if (pvcalls_front_read_todo(map))
 938                mask |= EPOLLIN | EPOLLRDNORM;
 939        if (in_error != 0 || out_error != 0)
 940                mask |= EPOLLERR;
 941
 942        return mask;
 943}
 944
 945__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
 946                               poll_table *wait)
 947{
 948        struct pvcalls_bedata *bedata;
 949        struct sock_mapping *map;
 950        __poll_t ret;
 951
 952        map = pvcalls_enter_sock(sock);
 953        if (IS_ERR(map))
 954                return EPOLLNVAL;
 955        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 956
 957        if (map->active_socket)
 958                ret = pvcalls_front_poll_active(file, bedata, map, wait);
 959        else
 960                ret = pvcalls_front_poll_passive(file, bedata, map, wait);
 961        pvcalls_exit_sock(sock);
 962        return ret;
 963}
 964
 965int pvcalls_front_release(struct socket *sock)
 966{
 967        struct pvcalls_bedata *bedata;
 968        struct sock_mapping *map;
 969        int req_id, notify, ret;
 970        struct xen_pvcalls_request *req;
 971
 972        if (sock->sk == NULL)
 973                return 0;
 974
 975        map = pvcalls_enter_sock(sock);
 976        if (IS_ERR(map)) {
 977                if (PTR_ERR(map) == -ENOTCONN)
 978                        return -EIO;
 979                else
 980                        return 0;
 981        }
 982        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 983
 984        spin_lock(&bedata->socket_lock);
 985        ret = get_request(bedata, &req_id);
 986        if (ret < 0) {
 987                spin_unlock(&bedata->socket_lock);
 988                pvcalls_exit_sock(sock);
 989                return ret;
 990        }
 991        sock->sk->sk_send_head = NULL;
 992
 993        req = RING_GET_REQUEST(&bedata->ring, req_id);
 994        req->req_id = req_id;
 995        req->cmd = PVCALLS_RELEASE;
 996        req->u.release.id = (uintptr_t)map;
 997
 998        bedata->ring.req_prod_pvt++;
 999        RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
1000        spin_unlock(&bedata->socket_lock);
1001        if (notify)
1002                notify_remote_via_irq(bedata->irq);
1003
1004        wait_event(bedata->inflight_req,
1005                   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
1006
1007        if (map->active_socket) {
1008                /*
1009                 * Set in_error and wake up inflight_conn_req to force
1010                 * recvmsg waiters to exit.
1011                 */
1012                map->active.ring->in_error = -EBADF;
1013                wake_up_interruptible(&map->active.inflight_conn_req);
1014
1015                /*
1016                 * We need to make sure that sendmsg/recvmsg on this socket have
1017                 * not started before we've cleared sk_send_head here. The
1018                 * easiest way to guarantee this is to see that no pvcalls
1019                 * (other than us) is in progress on this socket.
1020                 */
1021                while (atomic_read(&map->refcount) > 1)
1022                        cpu_relax();
1023
1024                pvcalls_front_free_map(bedata, map);
1025        } else {
1026                wake_up(&bedata->inflight_req);
1027                wake_up(&map->passive.inflight_accept_req);
1028
1029                while (atomic_read(&map->refcount) > 1)
1030                        cpu_relax();
1031
1032                spin_lock(&bedata->socket_lock);
1033                list_del(&map->list);
1034                spin_unlock(&bedata->socket_lock);
1035                if (READ_ONCE(map->passive.inflight_req_id) !=
1036                    PVCALLS_INVALID_ID) {
1037                        pvcalls_front_free_map(bedata,
1038                                               map->passive.accept_map);
1039                }
1040                kfree(map);
1041        }
1042        WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);
1043
1044        pvcalls_exit();
1045        return 0;
1046}
1047
1048static const struct xenbus_device_id pvcalls_front_ids[] = {
1049        { "pvcalls" },
1050        { "" }
1051};
1052
1053static int pvcalls_front_remove(struct xenbus_device *dev)
1054{
1055        struct pvcalls_bedata *bedata;
1056        struct sock_mapping *map = NULL, *n;
1057
1058        bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
1059        dev_set_drvdata(&dev->dev, NULL);
1060        pvcalls_front_dev = NULL;
1061        if (bedata->irq >= 0)
1062                unbind_from_irqhandler(bedata->irq, dev);
1063
1064        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
1065                map->sock->sk->sk_send_head = NULL;
1066                if (map->active_socket) {
1067                        map->active.ring->in_error = -EBADF;
1068                        wake_up_interruptible(&map->active.inflight_conn_req);
1069                }
1070        }
1071
1072        smp_mb();
1073        while (atomic_read(&pvcalls_refcount) > 0)
1074                cpu_relax();
1075        list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
1076                if (map->active_socket) {
1077                        /* No need to lock, refcount is 0 */
1078                        pvcalls_front_free_map(bedata, map);
1079                } else {
1080                        list_del(&map->list);
1081                        kfree(map);
1082                }
1083        }
1084        if (bedata->ref != -1)
1085                gnttab_end_foreign_access(bedata->ref, 0, 0);
1086        kfree(bedata->ring.sring);
1087        kfree(bedata);
1088        xenbus_switch_state(dev, XenbusStateClosed);
1089        return 0;
1090}
1091
1092static int pvcalls_front_probe(struct xenbus_device *dev,
1093                          const struct xenbus_device_id *id)
1094{
1095        int ret = -ENOMEM, evtchn, i;
1096        unsigned int max_page_order, function_calls, len;
1097        char *versions;
1098        grant_ref_t gref_head = 0;
1099        struct xenbus_transaction xbt;
1100        struct pvcalls_bedata *bedata = NULL;
1101        struct xen_pvcalls_sring *sring;
1102
1103        if (pvcalls_front_dev != NULL) {
1104                dev_err(&dev->dev, "only one PV Calls connection supported\n");
1105                return -EINVAL;
1106        }
1107
1108        versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
1109        if (IS_ERR(versions))
1110                return PTR_ERR(versions);
1111        if (!len)
1112                return -EINVAL;
1113        if (strcmp(versions, "1")) {
1114                kfree(versions);
1115                return -EINVAL;
1116        }
1117        kfree(versions);
1118        max_page_order = xenbus_read_unsigned(dev->otherend,
1119                                              "max-page-order", 0);
1120        if (max_page_order < PVCALLS_RING_ORDER)
1121                return -ENODEV;
1122        function_calls = xenbus_read_unsigned(dev->otherend,
1123                                              "function-calls", 0);
1124        /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
1125        if (function_calls != 1)
1126                return -ENODEV;
1127        pr_info("%s max-page-order is %u\n", __func__, max_page_order);
1128
1129        bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
1130        if (!bedata)
1131                return -ENOMEM;
1132
1133        dev_set_drvdata(&dev->dev, bedata);
1134        pvcalls_front_dev = dev;
1135        init_waitqueue_head(&bedata->inflight_req);
1136        INIT_LIST_HEAD(&bedata->socket_mappings);
1137        spin_lock_init(&bedata->socket_lock);
1138        bedata->irq = -1;
1139        bedata->ref = -1;
1140
1141        for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
1142                bedata->rsp[i].req_id = PVCALLS_INVALID_ID;
1143
1144        sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
1145                                                             __GFP_ZERO);
1146        if (!sring)
1147                goto error;
1148        SHARED_RING_INIT(sring);
1149        FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);
1150
1151        ret = xenbus_alloc_evtchn(dev, &evtchn);
1152        if (ret)
1153                goto error;
1154
1155        bedata->irq = bind_evtchn_to_irqhandler(evtchn,
1156                                                pvcalls_front_event_handler,
1157                                                0, "pvcalls-frontend", dev);
1158        if (bedata->irq < 0) {
1159                ret = bedata->irq;
1160                goto error;
1161        }
1162
1163        ret = gnttab_alloc_grant_references(1, &gref_head);
1164        if (ret < 0)
1165                goto error;
1166        ret = gnttab_claim_grant_reference(&gref_head);
1167        if (ret < 0)
1168                goto error;
1169        bedata->ref = ret;
1170        gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
1171                                        virt_to_gfn((void *)sring), 0);
1172
1173 again:
1174        ret = xenbus_transaction_start(&xbt);
1175        if (ret) {
1176                xenbus_dev_fatal(dev, ret, "starting transaction");
1177                goto error;
1178        }
1179        ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
1180        if (ret)
1181                goto error_xenbus;
1182        ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
1183        if (ret)
1184                goto error_xenbus;
1185        ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
1186                            evtchn);
1187        if (ret)
1188                goto error_xenbus;
1189        ret = xenbus_transaction_end(xbt, 0);
1190        if (ret) {
1191                if (ret == -EAGAIN)
1192                        goto again;
1193                xenbus_dev_fatal(dev, ret, "completing transaction");
1194                goto error;
1195        }
1196        xenbus_switch_state(dev, XenbusStateInitialised);
1197
1198        return 0;
1199
1200 error_xenbus:
1201        xenbus_transaction_end(xbt, 1);
1202        xenbus_dev_fatal(dev, ret, "writing xenstore");
1203 error:
1204        pvcalls_front_remove(dev);
1205        return ret;
1206}
1207
1208static void pvcalls_front_changed(struct xenbus_device *dev,
1209                            enum xenbus_state backend_state)
1210{
1211        switch (backend_state) {
1212        case XenbusStateReconfiguring:
1213        case XenbusStateReconfigured:
1214        case XenbusStateInitialising:
1215        case XenbusStateInitialised:
1216        case XenbusStateUnknown:
1217                break;
1218
1219        case XenbusStateInitWait:
1220                break;
1221
1222        case XenbusStateConnected:
1223                xenbus_switch_state(dev, XenbusStateConnected);
1224                break;
1225
1226        case XenbusStateClosed:
1227                if (dev->state == XenbusStateClosed)
1228                        break;
1229                /* Missed the backend's CLOSING state */
1230                /* fall through */
1231        case XenbusStateClosing:
1232                xenbus_frontend_closed(dev);
1233                break;
1234        }
1235}
1236
1237static struct xenbus_driver pvcalls_front_driver = {
1238        .ids = pvcalls_front_ids,
1239        .probe = pvcalls_front_probe,
1240        .remove = pvcalls_front_remove,
1241        .otherend_changed = pvcalls_front_changed,
1242};
1243
1244static int __init pvcalls_frontend_init(void)
1245{
1246        if (!xen_domain())
1247                return -ENODEV;
1248
1249        pr_info("Initialising Xen pvcalls frontend driver\n");
1250
1251        return xenbus_register_frontend(&pvcalls_front_driver);
1252}
1253
1254module_init(pvcalls_frontend_init);
1255
1256MODULE_DESCRIPTION("Xen PV Calls frontend driver");
1257MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
1258MODULE_LICENSE("GPL");
1259