linux/drivers/xen/pvcalls-back.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/inet.h>
#include <linux/kthread.h>
#include <linux/list.h>
#include <linux/radix-tree.h>
#include <linux/module.h>
#include <linux/semaphore.h>
#include <linux/wait.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/inet_connection_sock.h>
#include <net/request_sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#define PVCALLS_VERSIONS "1"
#define MAX_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER

static struct pvcalls_back_global {
        struct list_head frontends;
        struct semaphore frontends_lock;
} pvcalls_back_global;

/*
 * Per-frontend data structure. It contains pointers to the command
 * ring, its event channel, a list of active sockets and a tree of
 * passive sockets.
 */
struct pvcalls_fedata {
        struct list_head list;
        struct xenbus_device *dev;
        struct xen_pvcalls_sring *sring;
        struct xen_pvcalls_back_ring ring;
        int irq;
        struct list_head socket_mappings;
        struct radix_tree_root socketpass_mappings;
        struct semaphore socket_lock;
};

struct pvcalls_ioworker {
        struct work_struct register_work;
        struct workqueue_struct *wq;
};

struct sock_mapping {
        struct list_head list;
        struct pvcalls_fedata *fedata;
        struct sockpass_mapping *sockpass;
        struct socket *sock;
        uint64_t id;
        grant_ref_t ref;
        struct pvcalls_data_intf *ring;
        void *bytes;
        struct pvcalls_data data;
        uint32_t ring_order;
        int irq;
        atomic_t read;
        atomic_t write;
        atomic_t io;
        atomic_t release;
        void (*saved_data_ready)(struct sock *sk);
        struct pvcalls_ioworker ioworker;
};

struct sockpass_mapping {
        struct list_head list;
        struct pvcalls_fedata *fedata;
        struct socket *sock;
        uint64_t id;
        struct xen_pvcalls_request reqcopy;
        spinlock_t copy_lock;
        struct workqueue_struct *wq;
        struct work_struct register_work;
        void (*saved_data_ready)(struct sock *sk);
};

static irqreturn_t pvcalls_back_conn_event(int irq, void *sock_map);
static int pvcalls_back_release_active(struct xenbus_device *dev,
                                       struct pvcalls_fedata *fedata,
                                       struct sock_mapping *map);

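/*
 * Move data from the socket receive queue into the shared "in" data
 * ring: fill whatever free space the ring indexes advertise, then
 * advance in_prod (or set in_error on failure) and kick the event
 * channel.
 */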
static void pvcalls_conn_back_read(void *opaque)
{
        struct sock_mapping *map = (struct sock_mapping *)opaque;
        struct msghdr msg;
        struct kvec vec[2];
        RING_IDX cons, prod, size, wanted, array_size, masked_prod, masked_cons;
        int32_t error;
        struct pvcalls_data_intf *intf = map->ring;
        struct pvcalls_data *data = &map->data;
        unsigned long flags;
        int ret;

        array_size = XEN_FLEX_RING_SIZE(map->ring_order);
        cons = intf->in_cons;
        prod = intf->in_prod;
        error = intf->in_error;
        /* read the indexes first, then deal with the data */
        virt_mb();

        if (error)
                return;

        size = pvcalls_queued(prod, cons, array_size);
        if (size >= array_size)
                return;
        spin_lock_irqsave(&map->sock->sk->sk_receive_queue.lock, flags);
        if (skb_queue_empty(&map->sock->sk->sk_receive_queue)) {
                atomic_set(&map->read, 0);
                spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock,
                                flags);
                return;
        }
        spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock, flags);
        wanted = array_size - size;
        masked_prod = pvcalls_mask(prod, array_size);
        masked_cons = pvcalls_mask(cons, array_size);

        memset(&msg, 0, sizeof(msg));
        /* recvmsg fills the ring, so the kvec iterator is a data destination */
        if (masked_prod < masked_cons) {
                vec[0].iov_base = data->in + masked_prod;
                vec[0].iov_len = wanted;
                iov_iter_kvec(&msg.msg_iter, READ, vec, 1, wanted);
        } else {
                vec[0].iov_base = data->in + masked_prod;
                vec[0].iov_len = array_size - masked_prod;
                vec[1].iov_base = data->in;
                vec[1].iov_len = wanted - vec[0].iov_len;
                iov_iter_kvec(&msg.msg_iter, READ, vec, 2, wanted);
        }

        atomic_set(&map->read, 0);
        ret = inet_recvmsg(map->sock, &msg, wanted, MSG_DONTWAIT);
        WARN_ON(ret > wanted);
        if (ret == -EAGAIN) /* shouldn't happen */
                return;
        if (!ret)
                ret = -ENOTCONN;
        spin_lock_irqsave(&map->sock->sk->sk_receive_queue.lock, flags);
        if (ret > 0 && !skb_queue_empty(&map->sock->sk->sk_receive_queue))
                atomic_inc(&map->read);
        spin_unlock_irqrestore(&map->sock->sk->sk_receive_queue.lock, flags);

        /* write the data, then modify the indexes */
        virt_wmb();
        if (ret < 0) {
                atomic_set(&map->read, 0);
                intf->in_error = ret;
        } else
                intf->in_prod = prod + ret;
        /* update the indexes, then notify the other end */
        virt_wmb();
        notify_remote_via_irq(map->irq);
}

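/*
 * Drain the shared "out" data ring into the socket with a non-blocking
 * sendmsg. On a short or failed write, the write and io counters are
 * bumped again so the ioworker retries the remaining bytes.
 */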
static void pvcalls_conn_back_write(struct sock_mapping *map)
{
        struct pvcalls_data_intf *intf = map->ring;
        struct pvcalls_data *data = &map->data;
        struct msghdr msg;
        struct kvec vec[2];
        RING_IDX cons, prod, size, array_size;
        int ret;

        cons = intf->out_cons;
        prod = intf->out_prod;
        /* read the indexes before dealing with the data */
        virt_mb();

        array_size = XEN_FLEX_RING_SIZE(map->ring_order);
        size = pvcalls_queued(prod, cons, array_size);
        if (size == 0)
                return;

        memset(&msg, 0, sizeof(msg));
        msg.msg_flags |= MSG_DONTWAIT;
        /* sendmsg consumes the ring, so the kvec iterator is a data source */
        if (pvcalls_mask(prod, array_size) > pvcalls_mask(cons, array_size)) {
                vec[0].iov_base = data->out + pvcalls_mask(cons, array_size);
                vec[0].iov_len = size;
                iov_iter_kvec(&msg.msg_iter, WRITE, vec, 1, size);
        } else {
                vec[0].iov_base = data->out + pvcalls_mask(cons, array_size);
                vec[0].iov_len = array_size - pvcalls_mask(cons, array_size);
                vec[1].iov_base = data->out;
                vec[1].iov_len = size - vec[0].iov_len;
                iov_iter_kvec(&msg.msg_iter, WRITE, vec, 2, size);
        }

        atomic_set(&map->write, 0);
        ret = inet_sendmsg(map->sock, &msg, size);
        if (ret == -EAGAIN || (ret >= 0 && ret < size)) {
                atomic_inc(&map->write);
                atomic_inc(&map->io);
        }
        if (ret == -EAGAIN)
                return;

        /* write the data, then update the indexes */
        virt_wmb();
        if (ret < 0) {
                intf->out_error = ret;
        } else {
                intf->out_error = 0;
                intf->out_cons = cons + ret;
                prod = intf->out_prod;
        }
        /* update the indexes, then notify the other end */
        virt_wmb();
        if (prod != cons + ret)
                atomic_inc(&map->write);
        notify_remote_via_irq(map->irq);
}

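/*
 * Per-connection worker: loop while I/O is pending, servicing the
 * reads and writes flagged by the socket callbacks and the event
 * channel handler, and bail out as soon as a release is requested.
 */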
static void pvcalls_back_ioworker(struct work_struct *work)
{
        struct pvcalls_ioworker *ioworker = container_of(work,
                struct pvcalls_ioworker, register_work);
        struct sock_mapping *map = container_of(ioworker, struct sock_mapping,
                ioworker);

        while (atomic_read(&map->io) > 0) {
                if (atomic_read(&map->release) > 0) {
                        atomic_set(&map->release, 0);
                        return;
                }

                if (atomic_read(&map->read) > 0)
                        pvcalls_conn_back_read(map);
                if (atomic_read(&map->write) > 0)
                        pvcalls_conn_back_write(map);

                atomic_dec(&map->io);
        }
}

static int pvcalls_back_socket(struct xenbus_device *dev,
                struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        int ret;
        struct xen_pvcalls_response *rsp;

        fedata = dev_get_drvdata(&dev->dev);

        if (req->u.socket.domain != AF_INET ||
            req->u.socket.type != SOCK_STREAM ||
            (req->u.socket.protocol != IPPROTO_IP &&
             req->u.socket.protocol != AF_INET))
                ret = -EAFNOSUPPORT;
        else
                ret = 0;

        /* leave the actual socket allocation for later */

        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.socket.id = req->u.socket.id;
        rsp->ret = ret;

        return 0;
}

static void pvcalls_sk_state_change(struct sock *sock)
{
        struct sock_mapping *map = sock->sk_user_data;

        if (map == NULL)
                return;

        atomic_inc(&map->read);
        notify_remote_via_irq(map->irq);
}

static void pvcalls_sk_data_ready(struct sock *sock)
{
        struct sock_mapping *map = sock->sk_user_data;
        struct pvcalls_ioworker *iow;

        if (map == NULL)
                return;

        iow = &map->ioworker;
        atomic_inc(&map->read);
        atomic_inc(&map->io);
        queue_work(iow->wq, &iow->register_work);
}

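/*
 * Set up everything a connected socket needs to talk to the frontend:
 * map the indexes page and the data ring, bind the per-connection
 * event channel, start the ioworker and install the sk callbacks.
 */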
static struct sock_mapping *pvcalls_new_active_socket(
                struct pvcalls_fedata *fedata,
                uint64_t id,
                grant_ref_t ref,
                evtchn_port_t evtchn,
                struct socket *sock)
{
        int ret;
        struct sock_mapping *map;
        void *page;

        map = kzalloc(sizeof(*map), GFP_KERNEL);
        if (map == NULL)
                return NULL;

        map->fedata = fedata;
        map->sock = sock;
        map->id = id;
        map->ref = ref;

        ret = xenbus_map_ring_valloc(fedata->dev, &ref, 1, &page);
        if (ret < 0)
                goto out;
        map->ring = page;
        map->ring_order = map->ring->ring_order;
        /* first read the order, then map the data ring */
        virt_rmb();
        if (map->ring_order > MAX_RING_ORDER) {
                pr_warn("%s frontend requested ring_order %u, which is > MAX (%u)\n",
                                __func__, map->ring_order, MAX_RING_ORDER);
                goto out;
        }
        ret = xenbus_map_ring_valloc(fedata->dev, map->ring->ref,
                                     (1 << map->ring_order), &page);
        if (ret < 0)
                goto out;
        map->bytes = page;

        ret = bind_interdomain_evtchn_to_irqhandler(fedata->dev->otherend_id,
                                                    evtchn,
                                                    pvcalls_back_conn_event,
                                                    0,
                                                    "pvcalls-backend",
                                                    map);
        if (ret < 0)
                goto out;
        map->irq = ret;

        map->data.in = map->bytes;
        map->data.out = map->bytes + XEN_FLEX_RING_SIZE(map->ring_order);

        map->ioworker.wq = alloc_workqueue("pvcalls_io", WQ_UNBOUND, 1);
        if (!map->ioworker.wq)
                goto out;
        atomic_set(&map->io, 1);
        INIT_WORK(&map->ioworker.register_work, pvcalls_back_ioworker);

        down(&fedata->socket_lock);
        list_add_tail(&map->list, &fedata->socket_mappings);
        up(&fedata->socket_lock);

        write_lock_bh(&map->sock->sk->sk_callback_lock);
        map->saved_data_ready = map->sock->sk->sk_data_ready;
        map->sock->sk->sk_user_data = map;
        map->sock->sk->sk_data_ready = pvcalls_sk_data_ready;
        map->sock->sk->sk_state_change = pvcalls_sk_state_change;
        write_unlock_bh(&map->sock->sk->sk_callback_lock);

        return map;
out:
        /*
         * The map was never added to socket_mappings, so undo only what
         * was set up above; the caller releases the socket itself.
         */
        if (map->irq > 0)
                unbind_from_irqhandler(map->irq, map);
        if (map->bytes)
                xenbus_unmap_ring_vfree(fedata->dev, map->bytes);
        if (map->ring)
                xenbus_unmap_ring_vfree(fedata->dev, (void *)map->ring);
        kfree(map);
        return NULL;
}

static int pvcalls_back_connect(struct xenbus_device *dev,
                                struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        int ret = -EINVAL;
        struct socket *sock;
        struct sock_mapping *map;
        struct xen_pvcalls_response *rsp;
        struct sockaddr *sa = (struct sockaddr *)&req->u.connect.addr;

        fedata = dev_get_drvdata(&dev->dev);

        if (req->u.connect.len < sizeof(sa->sa_family) ||
            req->u.connect.len > sizeof(req->u.connect.addr) ||
            sa->sa_family != AF_INET)
                goto out;

        ret = sock_create(AF_INET, SOCK_STREAM, 0, &sock);
        if (ret < 0)
                goto out;
        ret = inet_stream_connect(sock, sa, req->u.connect.len, 0);
        if (ret < 0) {
                sock_release(sock);
                goto out;
        }

        map = pvcalls_new_active_socket(fedata,
                                        req->u.connect.id,
                                        req->u.connect.ref,
                                        req->u.connect.evtchn,
                                        sock);
        if (!map) {
                ret = -EFAULT;
                sock_release(sock);
        }

out:
        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.connect.id = req->u.connect.id;
        rsp->ret = ret;

        return 0;
}

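/*
 * Tear down an active socket: restore the saved sk callbacks, flush
 * the ioworker, unmap the rings, unbind the event channel and release
 * the socket. The map must already be off the socket_mappings list.
 */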
static int pvcalls_back_release_active(struct xenbus_device *dev,
                                       struct pvcalls_fedata *fedata,
                                       struct sock_mapping *map)
{
        disable_irq(map->irq);
        if (map->sock->sk != NULL) {
                write_lock_bh(&map->sock->sk->sk_callback_lock);
                map->sock->sk->sk_user_data = NULL;
                map->sock->sk->sk_data_ready = map->saved_data_ready;
                write_unlock_bh(&map->sock->sk->sk_callback_lock);
        }

        atomic_set(&map->release, 1);
        flush_work(&map->ioworker.register_work);

        xenbus_unmap_ring_vfree(dev, map->bytes);
        xenbus_unmap_ring_vfree(dev, (void *)map->ring);
        unbind_from_irqhandler(map->irq, map);

        sock_release(map->sock);
        kfree(map);

        return 0;
}

static int pvcalls_back_release_passive(struct xenbus_device *dev,
                                        struct pvcalls_fedata *fedata,
                                        struct sockpass_mapping *mappass)
{
        if (mappass->sock->sk != NULL) {
                write_lock_bh(&mappass->sock->sk->sk_callback_lock);
                mappass->sock->sk->sk_user_data = NULL;
                mappass->sock->sk->sk_data_ready = mappass->saved_data_ready;
                write_unlock_bh(&mappass->sock->sk->sk_callback_lock);
        }
        sock_release(mappass->sock);
        flush_workqueue(mappass->wq);
        destroy_workqueue(mappass->wq);
        kfree(mappass);

        return 0;
}

static int pvcalls_back_release(struct xenbus_device *dev,
                                struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        struct sock_mapping *map, *n;
        struct sockpass_mapping *mappass;
        int ret = 0;
        struct xen_pvcalls_response *rsp;

        fedata = dev_get_drvdata(&dev->dev);

        down(&fedata->socket_lock);
        list_for_each_entry_safe(map, n, &fedata->socket_mappings, list) {
                if (map->id == req->u.release.id) {
                        list_del(&map->list);
                        up(&fedata->socket_lock);
                        ret = pvcalls_back_release_active(dev, fedata, map);
                        goto out;
                }
        }
        mappass = radix_tree_lookup(&fedata->socketpass_mappings,
                                    req->u.release.id);
        if (mappass != NULL) {
                radix_tree_delete(&fedata->socketpass_mappings, mappass->id);
                up(&fedata->socket_lock);
                ret = pvcalls_back_release_passive(dev, fedata, mappass);
        } else
                up(&fedata->socket_lock);

out:
        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->u.release.id = req->u.release.id;
        rsp->cmd = req->cmd;
        rsp->ret = ret;
        return 0;
}

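/*
 * Deferred accept, run from the sockpass workqueue: pick up the
 * request stashed in reqcopy, accept a pending connection without
 * blocking, wrap it in a new active socket and push the response.
 */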
static void __pvcalls_back_accept(struct work_struct *work)
{
        struct sockpass_mapping *mappass = container_of(
                work, struct sockpass_mapping, register_work);
        struct sock_mapping *map;
        struct pvcalls_ioworker *iow;
        struct pvcalls_fedata *fedata;
        struct socket *sock;
        struct xen_pvcalls_response *rsp;
        struct xen_pvcalls_request *req;
        int notify;
        int ret = -EINVAL;
        unsigned long flags;

        fedata = mappass->fedata;
        /*
         * __pvcalls_back_accept can race against pvcalls_back_accept.
         * We only need to check the value of "cmd" on read. It could be
         * done atomically, but to simplify the code on the write side, we
         * use a spinlock.
         */
        spin_lock_irqsave(&mappass->copy_lock, flags);
        req = &mappass->reqcopy;
        if (req->cmd != PVCALLS_ACCEPT) {
                spin_unlock_irqrestore(&mappass->copy_lock, flags);
                return;
        }
        spin_unlock_irqrestore(&mappass->copy_lock, flags);

        sock = sock_alloc();
        if (sock == NULL)
                goto out_error;
        sock->type = mappass->sock->type;
        sock->ops = mappass->sock->ops;

        ret = inet_accept(mappass->sock, sock, O_NONBLOCK, true);
        if (ret == -EAGAIN) {
                sock_release(sock);
                return;
        }

        map = pvcalls_new_active_socket(fedata,
                                        req->u.accept.id_new,
                                        req->u.accept.ref,
                                        req->u.accept.evtchn,
                                        sock);
        if (!map) {
                ret = -EFAULT;
                sock_release(sock);
                goto out_error;
        }

        map->sockpass = mappass;
        iow = &map->ioworker;
        atomic_inc(&map->read);
        atomic_inc(&map->io);
        queue_work(iow->wq, &iow->register_work);

out_error:
        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.accept.id = req->u.accept.id;
        rsp->ret = ret;
        RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(&fedata->ring, notify);
        if (notify)
                notify_remote_via_irq(fedata->irq);

        mappass->reqcopy.cmd = 0;
}

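/*
 * data_ready callback of a listening socket: answer an outstanding
 * POLL request directly, otherwise kick the accept worker.
 */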
static void pvcalls_pass_sk_data_ready(struct sock *sock)
{
        struct sockpass_mapping *mappass = sock->sk_user_data;
        struct pvcalls_fedata *fedata;
        struct xen_pvcalls_response *rsp;
        unsigned long flags;
        int notify;

        if (mappass == NULL)
                return;

        fedata = mappass->fedata;
        spin_lock_irqsave(&mappass->copy_lock, flags);
        if (mappass->reqcopy.cmd == PVCALLS_POLL) {
                rsp = RING_GET_RESPONSE(&fedata->ring,
                                        fedata->ring.rsp_prod_pvt++);
                rsp->req_id = mappass->reqcopy.req_id;
                rsp->u.poll.id = mappass->reqcopy.u.poll.id;
                rsp->cmd = mappass->reqcopy.cmd;
                rsp->ret = 0;

                mappass->reqcopy.cmd = 0;
                spin_unlock_irqrestore(&mappass->copy_lock, flags);

                RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(&fedata->ring, notify);
                if (notify)
                        notify_remote_via_irq(mappass->fedata->irq);
        } else {
                spin_unlock_irqrestore(&mappass->copy_lock, flags);
                queue_work(mappass->wq, &mappass->register_work);
        }
}

static int pvcalls_back_bind(struct xenbus_device *dev,
                             struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        int ret;
        struct sockpass_mapping *map;
        struct xen_pvcalls_response *rsp;

        fedata = dev_get_drvdata(&dev->dev);

        map = kzalloc(sizeof(*map), GFP_KERNEL);
        if (map == NULL) {
                ret = -ENOMEM;
                goto out;
        }

        INIT_WORK(&map->register_work, __pvcalls_back_accept);
        spin_lock_init(&map->copy_lock);
        map->wq = alloc_workqueue("pvcalls_wq", WQ_UNBOUND, 1);
        if (!map->wq) {
                ret = -ENOMEM;
                goto out;
        }

        ret = sock_create(AF_INET, SOCK_STREAM, 0, &map->sock);
        if (ret < 0)
                goto out;

        ret = inet_bind(map->sock, (struct sockaddr *)&req->u.bind.addr,
                        req->u.bind.len);
        if (ret < 0)
                goto out;

        map->fedata = fedata;
        map->id = req->u.bind.id;

        down(&fedata->socket_lock);
        ret = radix_tree_insert(&fedata->socketpass_mappings, map->id,
                                map);
        up(&fedata->socket_lock);
        if (ret)
                goto out;

        write_lock_bh(&map->sock->sk->sk_callback_lock);
        map->saved_data_ready = map->sock->sk->sk_data_ready;
        map->sock->sk->sk_user_data = map;
        map->sock->sk->sk_data_ready = pvcalls_pass_sk_data_ready;
        write_unlock_bh(&map->sock->sk->sk_callback_lock);

out:
        if (ret) {
                if (map && map->sock)
                        sock_release(map->sock);
                if (map && map->wq)
                        destroy_workqueue(map->wq);
                kfree(map);
        }
        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.bind.id = req->u.bind.id;
        rsp->ret = ret;
        return 0;
}

static int pvcalls_back_listen(struct xenbus_device *dev,
                               struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        int ret = -EINVAL;
        struct sockpass_mapping *map;
        struct xen_pvcalls_response *rsp;

        fedata = dev_get_drvdata(&dev->dev);

        down(&fedata->socket_lock);
        map = radix_tree_lookup(&fedata->socketpass_mappings, req->u.listen.id);
        up(&fedata->socket_lock);
        if (map == NULL)
                goto out;

        ret = inet_listen(map->sock, req->u.listen.backlog);

out:
        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.listen.id = req->u.listen.id;
        rsp->ret = ret;
        return 0;
}

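/*
 * Stash the accept request in reqcopy and hand the blocking part to
 * the workqueue; __pvcalls_back_accept sends the response, so -1 tells
 * the caller not to push one now.
 */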
static int pvcalls_back_accept(struct xenbus_device *dev,
                               struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        struct sockpass_mapping *mappass;
        int ret = -EINVAL;
        struct xen_pvcalls_response *rsp;
        unsigned long flags;

        fedata = dev_get_drvdata(&dev->dev);

        down(&fedata->socket_lock);
        mappass = radix_tree_lookup(&fedata->socketpass_mappings,
                req->u.accept.id);
        up(&fedata->socket_lock);
        if (mappass == NULL)
                goto out_error;

        /*
         * Limitation of the current implementation: only support one
         * concurrent accept or poll call on one socket.
         */
        spin_lock_irqsave(&mappass->copy_lock, flags);
        if (mappass->reqcopy.cmd != 0) {
                spin_unlock_irqrestore(&mappass->copy_lock, flags);
                ret = -EINTR;
                goto out_error;
        }

        mappass->reqcopy = *req;
        spin_unlock_irqrestore(&mappass->copy_lock, flags);
        queue_work(mappass->wq, &mappass->register_work);

        /* Tell the caller we don't need to send back a notification yet */
        return -1;

out_error:
        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.accept.id = req->u.accept.id;
        rsp->ret = ret;
        return 0;
}

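/*
 * Answer immediately if a connection is already queued on the
 * listening socket; otherwise park the request in reqcopy and let
 * pvcalls_pass_sk_data_ready complete it later.
 */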
static int pvcalls_back_poll(struct xenbus_device *dev,
                             struct xen_pvcalls_request *req)
{
        struct pvcalls_fedata *fedata;
        struct sockpass_mapping *mappass;
        struct xen_pvcalls_response *rsp;
        struct inet_connection_sock *icsk;
        struct request_sock_queue *queue;
        unsigned long flags;
        int ret;
        bool data;

        fedata = dev_get_drvdata(&dev->dev);

        down(&fedata->socket_lock);
        mappass = radix_tree_lookup(&fedata->socketpass_mappings,
                                    req->u.poll.id);
        up(&fedata->socket_lock);
        if (mappass == NULL)
                return -EINVAL;

        /*
         * Limitation of the current implementation: only support one
         * concurrent accept or poll call on one socket.
         */
        spin_lock_irqsave(&mappass->copy_lock, flags);
        if (mappass->reqcopy.cmd != 0) {
                ret = -EINTR;
                goto out;
        }

        mappass->reqcopy = *req;
        icsk = inet_csk(mappass->sock->sk);
        queue = &icsk->icsk_accept_queue;
        data = READ_ONCE(queue->rskq_accept_head) != NULL;
        if (data) {
                mappass->reqcopy.cmd = 0;
                ret = 0;
                goto out;
        }
        spin_unlock_irqrestore(&mappass->copy_lock, flags);

        /* Tell the caller we don't need to send back a notification yet */
        return -1;

out:
        spin_unlock_irqrestore(&mappass->copy_lock, flags);

        rsp = RING_GET_RESPONSE(&fedata->ring, fedata->ring.rsp_prod_pvt++);
        rsp->req_id = req->req_id;
        rsp->cmd = req->cmd;
        rsp->u.poll.id = req->u.poll.id;
        rsp->ret = ret;
        return 0;
}

static int pvcalls_back_handle_cmd(struct xenbus_device *dev,
                                   struct xen_pvcalls_request *req)
{
        int ret = 0;

        switch (req->cmd) {
        case PVCALLS_SOCKET:
                ret = pvcalls_back_socket(dev, req);
                break;
        case PVCALLS_CONNECT:
                ret = pvcalls_back_connect(dev, req);
                break;
        case PVCALLS_RELEASE:
                ret = pvcalls_back_release(dev, req);
                break;
        case PVCALLS_BIND:
                ret = pvcalls_back_bind(dev, req);
                break;
        case PVCALLS_LISTEN:
                ret = pvcalls_back_listen(dev, req);
                break;
        case PVCALLS_ACCEPT:
                ret = pvcalls_back_accept(dev, req);
                break;
        case PVCALLS_POLL:
                ret = pvcalls_back_poll(dev, req);
                break;
        default:
        {
                struct pvcalls_fedata *fedata;
                struct xen_pvcalls_response *rsp;

                fedata = dev_get_drvdata(&dev->dev);
                rsp = RING_GET_RESPONSE(
                                &fedata->ring, fedata->ring.rsp_prod_pvt++);
                rsp->req_id = req->req_id;
                rsp->cmd = req->cmd;
                rsp->ret = -ENOTSUPP;
                break;
        }
        }
        return ret;
}

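/*
 * Command ring service loop: copy each request out of the shared ring
 * before acting on it (the frontend could change it under our feet),
 * dispatch it, and batch the notifications for the pushed responses.
 */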
static void pvcalls_back_work(struct pvcalls_fedata *fedata)
{
        int notify, notify_all = 0, more = 1;
        struct xen_pvcalls_request req;
        struct xenbus_device *dev = fedata->dev;

        while (more) {
                while (RING_HAS_UNCONSUMED_REQUESTS(&fedata->ring)) {
                        RING_COPY_REQUEST(&fedata->ring,
                                          fedata->ring.req_cons++,
                                          &req);

                        if (!pvcalls_back_handle_cmd(dev, &req)) {
                                RING_PUSH_RESPONSES_AND_CHECK_NOTIFY(
                                        &fedata->ring, notify);
                                notify_all += notify;
                        }
                }

                if (notify_all) {
                        notify_remote_via_irq(fedata->irq);
                        notify_all = 0;
                }

                RING_FINAL_CHECK_FOR_REQUESTS(&fedata->ring, more);
        }
}

static irqreturn_t pvcalls_back_event(int irq, void *dev_id)
{
        struct xenbus_device *dev = dev_id;
        struct pvcalls_fedata *fedata = NULL;

        if (dev == NULL)
                return IRQ_HANDLED;

        fedata = dev_get_drvdata(&dev->dev);
        if (fedata == NULL)
                return IRQ_HANDLED;

        pvcalls_back_work(fedata);
        return IRQ_HANDLED;
}

static irqreturn_t pvcalls_back_conn_event(int irq, void *sock_map)
{
        struct sock_mapping *map = sock_map;
        struct pvcalls_ioworker *iow;

        if (map == NULL || map->sock == NULL || map->sock->sk == NULL ||
                map->sock->sk->sk_user_data != map)
                return IRQ_HANDLED;

        iow = &map->ioworker;

        atomic_inc(&map->write);
        atomic_inc(&map->io);
        queue_work(iow->wq, &iow->register_work);

        return IRQ_HANDLED;
}

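/*
 * Connect to a new frontend: read the command ring grant reference and
 * event channel from xenstore, map and initialise the ring, then
 * register the frontend on the global list.
 */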
static int backend_connect(struct xenbus_device *dev)
{
        int err;
        evtchn_port_t evtchn;
        grant_ref_t ring_ref;
        struct pvcalls_fedata *fedata = NULL;

        fedata = kzalloc(sizeof(struct pvcalls_fedata), GFP_KERNEL);
        if (!fedata)
                return -ENOMEM;

        fedata->irq = -1;
        err = xenbus_scanf(XBT_NIL, dev->otherend, "port", "%u",
                           &evtchn);
        if (err != 1) {
                err = -EINVAL;
                xenbus_dev_fatal(dev, err, "reading %s/port",
                                 dev->otherend);
                goto error;
        }

        err = xenbus_scanf(XBT_NIL, dev->otherend, "ring-ref", "%u", &ring_ref);
        if (err != 1) {
                err = -EINVAL;
                xenbus_dev_fatal(dev, err, "reading %s/ring-ref",
                                 dev->otherend);
                goto error;
        }

        err = bind_interdomain_evtchn_to_irq(dev->otherend_id, evtchn);
        if (err < 0)
                goto error;
        fedata->irq = err;

        err = request_threaded_irq(fedata->irq, NULL, pvcalls_back_event,
                                   IRQF_ONESHOT, "pvcalls-back", dev);
        if (err < 0)
                goto error;

        err = xenbus_map_ring_valloc(dev, &ring_ref, 1,
                                     (void **)&fedata->sring);
        if (err < 0)
                goto error;

        BACK_RING_INIT(&fedata->ring, fedata->sring, XEN_PAGE_SIZE * 1);
        fedata->dev = dev;

        INIT_LIST_HEAD(&fedata->socket_mappings);
        INIT_RADIX_TREE(&fedata->socketpass_mappings, GFP_KERNEL);
        sema_init(&fedata->socket_lock, 1);
        dev_set_drvdata(&dev->dev, fedata);

        down(&pvcalls_back_global.frontends_lock);
        list_add_tail(&fedata->list, &pvcalls_back_global.frontends);
        up(&pvcalls_back_global.frontends_lock);

        return 0;

 error:
        if (fedata->irq >= 0)
                unbind_from_irqhandler(fedata->irq, dev);
        if (fedata->sring != NULL)
                xenbus_unmap_ring_vfree(dev, fedata->sring);
        kfree(fedata);
        return err;
}

static int backend_disconnect(struct xenbus_device *dev)
{
        struct pvcalls_fedata *fedata;
        struct sock_mapping *map, *n;
        struct sockpass_mapping *mappass;
        struct radix_tree_iter iter;
        void **slot;

        fedata = dev_get_drvdata(&dev->dev);

        down(&fedata->socket_lock);
        list_for_each_entry_safe(map, n, &fedata->socket_mappings, list) {
                list_del(&map->list);
                pvcalls_back_release_active(dev, fedata, map);
        }

        radix_tree_for_each_slot(slot, &fedata->socketpass_mappings, &iter, 0) {
                mappass = radix_tree_deref_slot(slot);
                if (!mappass)
                        continue;
                if (radix_tree_exception(mappass)) {
                        if (radix_tree_deref_retry(mappass))
                                slot = radix_tree_iter_retry(&iter);
                } else {
                        radix_tree_delete(&fedata->socketpass_mappings,
                                          mappass->id);
                        pvcalls_back_release_passive(dev, fedata, mappass);
                }
        }
        up(&fedata->socket_lock);

        unbind_from_irqhandler(fedata->irq, dev);
        xenbus_unmap_ring_vfree(dev, fedata->sring);

        list_del(&fedata->list);
        kfree(fedata);
        dev_set_drvdata(&dev->dev, NULL);

        return 0;
}

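/*
 * Advertise the backend features in xenstore ("versions",
 * "max-page-order", "function-calls") in one transaction, retrying on
 * -EAGAIN, then switch to XenbusStateInitWait.
 */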
static int pvcalls_back_probe(struct xenbus_device *dev,
                              const struct xenbus_device_id *id)
{
        int err, abort;
        struct xenbus_transaction xbt;

again:
        abort = 1;

        err = xenbus_transaction_start(&xbt);
        if (err) {
                pr_warn("%s cannot create xenstore transaction\n", __func__);
                return err;
        }

        err = xenbus_printf(xbt, dev->nodename, "versions", "%s",
                            PVCALLS_VERSIONS);
        if (err) {
                pr_warn("%s write out 'versions' failed\n", __func__);
                goto abort;
        }

        err = xenbus_printf(xbt, dev->nodename, "max-page-order", "%u",
                            MAX_RING_ORDER);
        if (err) {
                pr_warn("%s write out 'max-page-order' failed\n", __func__);
                goto abort;
        }

        err = xenbus_printf(xbt, dev->nodename, "function-calls",
                            XENBUS_FUNCTIONS_CALLS);
        if (err) {
                pr_warn("%s write out 'function-calls' failed\n", __func__);
                goto abort;
        }

        abort = 0;
abort:
        err = xenbus_transaction_end(xbt, abort);
        if (err) {
                if (err == -EAGAIN && !abort)
                        goto again;
                pr_warn("%s cannot complete xenstore transaction\n", __func__);
                return err;
        }

        if (abort)
                return -EFAULT;

        xenbus_switch_state(dev, XenbusStateInitWait);

        return 0;
}

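/*
 * Walk the xenbus state machine one legal step at a time until the
 * requested state is reached, connecting or disconnecting the backend
 * along the way.
 */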
static void set_backend_state(struct xenbus_device *dev,
                              enum xenbus_state state)
{
        while (dev->state != state) {
                switch (dev->state) {
                case XenbusStateClosed:
                        switch (state) {
                        case XenbusStateInitWait:
                        case XenbusStateConnected:
                                xenbus_switch_state(dev, XenbusStateInitWait);
                                break;
                        case XenbusStateClosing:
                                xenbus_switch_state(dev, XenbusStateClosing);
                                break;
                        default:
                                WARN_ON(1);
                        }
                        break;
                case XenbusStateInitWait:
                case XenbusStateInitialised:
                        switch (state) {
                        case XenbusStateConnected:
                                if (backend_connect(dev))
                                        return;
                                xenbus_switch_state(dev, XenbusStateConnected);
                                break;
                        case XenbusStateClosing:
                        case XenbusStateClosed:
                                xenbus_switch_state(dev, XenbusStateClosing);
                                break;
                        default:
                                WARN_ON(1);
                        }
                        break;
                case XenbusStateConnected:
                        switch (state) {
                        case XenbusStateInitWait:
                        case XenbusStateClosing:
                        case XenbusStateClosed:
                                down(&pvcalls_back_global.frontends_lock);
                                backend_disconnect(dev);
                                up(&pvcalls_back_global.frontends_lock);
                                xenbus_switch_state(dev, XenbusStateClosing);
                                break;
                        default:
                                WARN_ON(1);
                        }
                        break;
                case XenbusStateClosing:
                        switch (state) {
                        case XenbusStateInitWait:
                        case XenbusStateConnected:
                        case XenbusStateClosed:
                                xenbus_switch_state(dev, XenbusStateClosed);
                                break;
                        default:
                                WARN_ON(1);
                        }
                        break;
                default:
                        WARN_ON(1);
                }
        }
}

static void pvcalls_back_changed(struct xenbus_device *dev,
                                 enum xenbus_state frontend_state)
{
        switch (frontend_state) {
        case XenbusStateInitialising:
                set_backend_state(dev, XenbusStateInitWait);
                break;

        case XenbusStateInitialised:
        case XenbusStateConnected:
                set_backend_state(dev, XenbusStateConnected);
                break;

        case XenbusStateClosing:
                set_backend_state(dev, XenbusStateClosing);
                break;

        case XenbusStateClosed:
                set_backend_state(dev, XenbusStateClosed);
                if (xenbus_dev_is_online(dev))
                        break;
                device_unregister(&dev->dev);
                break;
        case XenbusStateUnknown:
                set_backend_state(dev, XenbusStateClosed);
                device_unregister(&dev->dev);
                break;

        default:
                xenbus_dev_fatal(dev, -EINVAL, "saw state %d at frontend",
                                 frontend_state);
                break;
        }
}

static int pvcalls_back_remove(struct xenbus_device *dev)
{
        return 0;
}

static int pvcalls_back_uevent(struct xenbus_device *xdev,
                               struct kobj_uevent_env *env)
{
        return 0;
}

static const struct xenbus_device_id pvcalls_back_ids[] = {
        { "pvcalls" },
        { "" }
};

static struct xenbus_driver pvcalls_back_driver = {
        .ids = pvcalls_back_ids,
        .probe = pvcalls_back_probe,
        .remove = pvcalls_back_remove,
        .uevent = pvcalls_back_uevent,
        .otherend_changed = pvcalls_back_changed,
};

static int __init pvcalls_back_init(void)
{
        int ret;

        if (!xen_domain())
                return -ENODEV;

        ret = xenbus_register_backend(&pvcalls_back_driver);
        if (ret < 0)
                return ret;

        sema_init(&pvcalls_back_global.frontends_lock, 1);
        INIT_LIST_HEAD(&pvcalls_back_global.frontends);
        return 0;
}
module_init(pvcalls_back_init);

static void __exit pvcalls_back_fin(void)
{
        struct pvcalls_fedata *fedata, *nfedata;

        down(&pvcalls_back_global.frontends_lock);
        list_for_each_entry_safe(fedata, nfedata,
                                 &pvcalls_back_global.frontends, list) {
                backend_disconnect(fedata->dev);
        }
        up(&pvcalls_back_global.frontends_lock);

        xenbus_unregister_driver(&pvcalls_back_driver);
}

module_exit(pvcalls_back_fin);

MODULE_DESCRIPTION("Xen PV Calls backend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");