linux/drivers/vhost/net.c
<<
>>
Prefs
   1/* Copyright (C) 2009 Red Hat, Inc.
   2 * Author: Michael S. Tsirkin <mst@redhat.com>
   3 *
   4 * This work is licensed under the terms of the GNU GPL, version 2.
   5 *
   6 * virtio-net server in host kernel.
   7 */
   8
   9#include <linux/compat.h>
  10#include <linux/eventfd.h>
  11#include <linux/vhost.h>
  12#include <linux/virtio_net.h>
  13#include <linux/miscdevice.h>
  14#include <linux/module.h>
  15#include <linux/moduleparam.h>
  16#include <linux/mutex.h>
  17#include <linux/workqueue.h>
  18#include <linux/file.h>
  19#include <linux/slab.h>
  20#include <linux/sched/clock.h>
  21#include <linux/sched/signal.h>
  22#include <linux/vmalloc.h>
  23
  24#include <linux/net.h>
  25#include <linux/if_packet.h>
  26#include <linux/if_arp.h>
  27#include <linux/if_tun.h>
  28#include <linux/if_macvlan.h>
  29#include <linux/if_tap.h>
  30#include <linux/if_vlan.h>
  31#include <linux/skb_array.h>
  32#include <linux/skbuff.h>
  33
  34#include <net/sock.h>
  35
  36#include "vhost.h"
  37
  38static int experimental_zcopytx = 1;
  39module_param(experimental_zcopytx, int, 0444);
  40MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
  41                                       " 1 -Enable; 0 - Disable");
  42
  43/* Max number of bytes transferred before requeueing the job.
  44 * Using this limit prevents one virtqueue from starving others. */
  45#define VHOST_NET_WEIGHT 0x80000
  46
  47/* MAX number of TX used buffers for outstanding zerocopy */
  48#define VHOST_MAX_PEND 128
  49#define VHOST_GOODCOPY_LEN 256
  50
  51/*
  52 * For transmit, used buffer len is unused; we override it to track buffer
  53 * status internally; used for zerocopy tx only.
  54 */
  55/* Lower device DMA failed */
  56#define VHOST_DMA_FAILED_LEN    ((__force __virtio32)3)
  57/* Lower device DMA done */
  58#define VHOST_DMA_DONE_LEN      ((__force __virtio32)2)
  59/* Lower device DMA in progress */
  60#define VHOST_DMA_IN_PROGRESS   ((__force __virtio32)1)
  61/* Buffer unused */
  62#define VHOST_DMA_CLEAR_LEN     ((__force __virtio32)0)
  63
  64#define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN)
  65
  66enum {
  67        VHOST_NET_FEATURES = VHOST_FEATURES |
  68                         (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
  69                         (1ULL << VIRTIO_NET_F_MRG_RXBUF) |
  70                         (1ULL << VIRTIO_F_IOMMU_PLATFORM)
  71};
  72
  73enum {
  74        VHOST_NET_VQ_RX = 0,
  75        VHOST_NET_VQ_TX = 1,
  76        VHOST_NET_VQ_MAX = 2,
  77};
  78
  79struct vhost_net_ubuf_ref {
  80        /* refcount follows semantics similar to kref:
  81         *  0: object is released
  82         *  1: no outstanding ubufs
  83         * >1: outstanding ubufs
  84         */
  85        atomic_t refcount;
  86        wait_queue_head_t wait;
  87        struct vhost_virtqueue *vq;
  88};
  89
  90#define VHOST_RX_BATCH 64
  91struct vhost_net_buf {
  92        struct sk_buff **queue;
  93        int tail;
  94        int head;
  95};
  96
  97struct vhost_net_virtqueue {
  98        struct vhost_virtqueue vq;
  99        size_t vhost_hlen;
 100        size_t sock_hlen;
 101        /* vhost zerocopy support fields below: */
 102        /* last used idx for outstanding DMA zerocopy buffers */
 103        int upend_idx;
 104        /* first used idx for DMA done zerocopy buffers */
 105        int done_idx;
 106        /* an array of userspace buffers info */
 107        struct ubuf_info *ubuf_info;
 108        /* Reference counting for outstanding ubufs.
 109         * Protected by vq mutex. Writers must also take device mutex. */
 110        struct vhost_net_ubuf_ref *ubufs;
 111        struct skb_array *rx_array;
 112        struct vhost_net_buf rxq;
 113};
 114
 115struct vhost_net {
 116        struct vhost_dev dev;
 117        struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
 118        struct vhost_poll poll[VHOST_NET_VQ_MAX];
 119        /* Number of TX recently submitted.
 120         * Protected by tx vq lock. */
 121        unsigned tx_packets;
 122        /* Number of times zerocopy TX recently failed.
 123         * Protected by tx vq lock. */
 124        unsigned tx_zcopy_err;
 125        /* Flush in progress. Protected by tx vq lock. */
 126        bool tx_flush;
 127};
 128
 129static unsigned vhost_net_zcopy_mask __read_mostly;
 130
 131static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
 132{
 133        if (rxq->tail != rxq->head)
 134                return rxq->queue[rxq->head];
 135        else
 136                return NULL;
 137}
 138
 139static int vhost_net_buf_get_size(struct vhost_net_buf *rxq)
 140{
 141        return rxq->tail - rxq->head;
 142}
 143
 144static int vhost_net_buf_is_empty(struct vhost_net_buf *rxq)
 145{
 146        return rxq->tail == rxq->head;
 147}
 148
 149static void *vhost_net_buf_consume(struct vhost_net_buf *rxq)
 150{
 151        void *ret = vhost_net_buf_get_ptr(rxq);
 152        ++rxq->head;
 153        return ret;
 154}
 155
 156static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
 157{
 158        struct vhost_net_buf *rxq = &nvq->rxq;
 159
 160        rxq->head = 0;
 161        rxq->tail = skb_array_consume_batched(nvq->rx_array, rxq->queue,
 162                                              VHOST_RX_BATCH);
 163        return rxq->tail;
 164}
 165
 166static void vhost_net_buf_unproduce(struct vhost_net_virtqueue *nvq)
 167{
 168        struct vhost_net_buf *rxq = &nvq->rxq;
 169
 170        if (nvq->rx_array && !vhost_net_buf_is_empty(rxq)) {
 171                skb_array_unconsume(nvq->rx_array, rxq->queue + rxq->head,
 172                                    vhost_net_buf_get_size(rxq));
 173                rxq->head = rxq->tail = 0;
 174        }
 175}
 176
 177static int vhost_net_buf_peek(struct vhost_net_virtqueue *nvq)
 178{
 179        struct vhost_net_buf *rxq = &nvq->rxq;
 180
 181        if (!vhost_net_buf_is_empty(rxq))
 182                goto out;
 183
 184        if (!vhost_net_buf_produce(nvq))
 185                return 0;
 186
 187out:
 188        return __skb_array_len_with_tag(vhost_net_buf_get_ptr(rxq));
 189}
 190
 191static void vhost_net_buf_init(struct vhost_net_buf *rxq)
 192{
 193        rxq->head = rxq->tail = 0;
 194}
 195
 196static void vhost_net_enable_zcopy(int vq)
 197{
 198        vhost_net_zcopy_mask |= 0x1 << vq;
 199}
 200
 201static struct vhost_net_ubuf_ref *
 202vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
 203{
 204        struct vhost_net_ubuf_ref *ubufs;
 205        /* No zero copy backend? Nothing to count. */
 206        if (!zcopy)
 207                return NULL;
 208        ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
 209        if (!ubufs)
 210                return ERR_PTR(-ENOMEM);
 211        atomic_set(&ubufs->refcount, 1);
 212        init_waitqueue_head(&ubufs->wait);
 213        ubufs->vq = vq;
 214        return ubufs;
 215}
 216
 217static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
 218{
 219        int r = atomic_sub_return(1, &ubufs->refcount);
 220        if (unlikely(!r))
 221                wake_up(&ubufs->wait);
 222        return r;
 223}
 224
 225static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
 226{
 227        vhost_net_ubuf_put(ubufs);
 228        wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
 229}
 230
 231static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
 232{
 233        vhost_net_ubuf_put_and_wait(ubufs);
 234        kfree(ubufs);
 235}
 236
 237static void vhost_net_clear_ubuf_info(struct vhost_net *n)
 238{
 239        int i;
 240
 241        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 242                kfree(n->vqs[i].ubuf_info);
 243                n->vqs[i].ubuf_info = NULL;
 244        }
 245}
 246
 247static int vhost_net_set_ubuf_info(struct vhost_net *n)
 248{
 249        bool zcopy;
 250        int i;
 251
 252        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
 253                zcopy = vhost_net_zcopy_mask & (0x1 << i);
 254                if (!zcopy)
 255                        continue;
 256                n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
 257                                              UIO_MAXIOV, GFP_KERNEL);
 258                if  (!n->vqs[i].ubuf_info)
 259                        goto err;
 260        }
 261        return 0;
 262
 263err:
 264        vhost_net_clear_ubuf_info(n);
 265        return -ENOMEM;
 266}
 267
 268static void vhost_net_vq_reset(struct vhost_net *n)
 269{
 270        int i;
 271
 272        vhost_net_clear_ubuf_info(n);
 273
 274        for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
 275                n->vqs[i].done_idx = 0;
 276                n->vqs[i].upend_idx = 0;
 277                n->vqs[i].ubufs = NULL;
 278                n->vqs[i].vhost_hlen = 0;
 279                n->vqs[i].sock_hlen = 0;
 280                vhost_net_buf_init(&n->vqs[i].rxq);
 281        }
 282
 283}
 284
 285static void vhost_net_tx_packet(struct vhost_net *net)
 286{
 287        ++net->tx_packets;
 288        if (net->tx_packets < 1024)
 289                return;
 290        net->tx_packets = 0;
 291        net->tx_zcopy_err = 0;
 292}
 293
 294static void vhost_net_tx_err(struct vhost_net *net)
 295{
 296        ++net->tx_zcopy_err;
 297}
 298
 299static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
 300{
 301        /* TX flush waits for outstanding DMAs to be done.
 302         * Don't start new DMAs.
 303         */
 304        return !net->tx_flush &&
 305                net->tx_packets / 64 >= net->tx_zcopy_err;
 306}
 307
 308static bool vhost_sock_zcopy(struct socket *sock)
 309{
 310        return unlikely(experimental_zcopytx) &&
 311                sock_flag(sock->sk, SOCK_ZEROCOPY);
 312}
 313
 314/* In case of DMA done not in order in lower device driver for some reason.
 315 * upend_idx is used to track end of used idx, done_idx is used to track head
 316 * of used idx. Once lower device DMA done contiguously, we will signal KVM
 317 * guest used idx.
 318 */
 319static void vhost_zerocopy_signal_used(struct vhost_net *net,
 320                                       struct vhost_virtqueue *vq)
 321{
 322        struct vhost_net_virtqueue *nvq =
 323                container_of(vq, struct vhost_net_virtqueue, vq);
 324        int i, add;
 325        int j = 0;
 326
 327        for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
 328                if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
 329                        vhost_net_tx_err(net);
 330                if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
 331                        vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
 332                        ++j;
 333                } else
 334                        break;
 335        }
 336        while (j) {
 337                add = min(UIO_MAXIOV - nvq->done_idx, j);
 338                vhost_add_used_and_signal_n(vq->dev, vq,
 339                                            &vq->heads[nvq->done_idx], add);
 340                nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
 341                j -= add;
 342        }
 343}
 344
 345static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
 346{
 347        struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
 348        struct vhost_virtqueue *vq = ubufs->vq;
 349        int cnt;
 350
 351        rcu_read_lock_bh();
 352
 353        /* set len to mark this desc buffers done DMA */
 354        vq->heads[ubuf->desc].len = success ?
 355                VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
 356        cnt = vhost_net_ubuf_put(ubufs);
 357
 358        /*
 359         * Trigger polling thread if guest stopped submitting new buffers:
 360         * in this case, the refcount after decrement will eventually reach 1.
 361         * We also trigger polling periodically after each 16 packets
 362         * (the value 16 here is more or less arbitrary, it's tuned to trigger
 363         * less than 10% of times).
 364         */
 365        if (cnt <= 1 || !(cnt % 16))
 366                vhost_poll_queue(&vq->poll);
 367
 368        rcu_read_unlock_bh();
 369}
 370
 371static inline unsigned long busy_clock(void)
 372{
 373        return local_clock() >> 10;
 374}
 375
 376static bool vhost_can_busy_poll(struct vhost_dev *dev,
 377                                unsigned long endtime)
 378{
 379        return likely(!need_resched()) &&
 380               likely(!time_after(busy_clock(), endtime)) &&
 381               likely(!signal_pending(current)) &&
 382               !vhost_has_work(dev);
 383}
 384
 385static void vhost_net_disable_vq(struct vhost_net *n,
 386                                 struct vhost_virtqueue *vq)
 387{
 388        struct vhost_net_virtqueue *nvq =
 389                container_of(vq, struct vhost_net_virtqueue, vq);
 390        struct vhost_poll *poll = n->poll + (nvq - n->vqs);
 391        if (!vq->private_data)
 392                return;
 393        vhost_poll_stop(poll);
 394}
 395
 396static int vhost_net_enable_vq(struct vhost_net *n,
 397                                struct vhost_virtqueue *vq)
 398{
 399        struct vhost_net_virtqueue *nvq =
 400                container_of(vq, struct vhost_net_virtqueue, vq);
 401        struct vhost_poll *poll = n->poll + (nvq - n->vqs);
 402        struct socket *sock;
 403
 404        sock = vq->private_data;
 405        if (!sock)
 406                return 0;
 407
 408        return vhost_poll_start(poll, sock->file);
 409}
 410
 411static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
 412                                    struct vhost_virtqueue *vq,
 413                                    struct iovec iov[], unsigned int iov_size,
 414                                    unsigned int *out_num, unsigned int *in_num)
 415{
 416        unsigned long uninitialized_var(endtime);
 417        int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 418                                  out_num, in_num, NULL, NULL);
 419
 420        if (r == vq->num && vq->busyloop_timeout) {
 421                preempt_disable();
 422                endtime = busy_clock() + vq->busyloop_timeout;
 423                while (vhost_can_busy_poll(vq->dev, endtime) &&
 424                       vhost_vq_avail_empty(vq->dev, vq))
 425                        cpu_relax();
 426                preempt_enable();
 427                r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 428                                      out_num, in_num, NULL, NULL);
 429        }
 430
 431        return r;
 432}
 433
 434static bool vhost_exceeds_maxpend(struct vhost_net *net)
 435{
 436        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
 437        struct vhost_virtqueue *vq = &nvq->vq;
 438
 439        return (nvq->upend_idx + vq->num - VHOST_MAX_PEND) % UIO_MAXIOV
 440                == nvq->done_idx;
 441}
 442
 443/* Expects to be always run from workqueue - which acts as
 444 * read-size critical section for our kind of RCU. */
 445static void handle_tx(struct vhost_net *net)
 446{
 447        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
 448        struct vhost_virtqueue *vq = &nvq->vq;
 449        unsigned out, in;
 450        int head;
 451        struct msghdr msg = {
 452                .msg_name = NULL,
 453                .msg_namelen = 0,
 454                .msg_control = NULL,
 455                .msg_controllen = 0,
 456                .msg_flags = MSG_DONTWAIT,
 457        };
 458        size_t len, total_len = 0;
 459        int err;
 460        size_t hdr_size;
 461        struct socket *sock;
 462        struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
 463        bool zcopy, zcopy_used;
 464
 465        mutex_lock(&vq->mutex);
 466        sock = vq->private_data;
 467        if (!sock)
 468                goto out;
 469
 470        if (!vq_iotlb_prefetch(vq))
 471                goto out;
 472
 473        vhost_disable_notify(&net->dev, vq);
 474
 475        hdr_size = nvq->vhost_hlen;
 476        zcopy = nvq->ubufs;
 477
 478        for (;;) {
 479                /* Release DMAs done buffers first */
 480                if (zcopy)
 481                        vhost_zerocopy_signal_used(net, vq);
 482
 483                /* If more outstanding DMAs, queue the work.
 484                 * Handle upend_idx wrap around
 485                 */
 486                if (unlikely(vhost_exceeds_maxpend(net)))
 487                        break;
 488
 489                head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
 490                                                ARRAY_SIZE(vq->iov),
 491                                                &out, &in);
 492                /* On error, stop handling until the next kick. */
 493                if (unlikely(head < 0))
 494                        break;
 495                /* Nothing new?  Wait for eventfd to tell us they refilled. */
 496                if (head == vq->num) {
 497                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 498                                vhost_disable_notify(&net->dev, vq);
 499                                continue;
 500                        }
 501                        break;
 502                }
 503                if (in) {
 504                        vq_err(vq, "Unexpected descriptor format for TX: "
 505                               "out %d, int %d\n", out, in);
 506                        break;
 507                }
 508                /* Skip header. TODO: support TSO. */
 509                len = iov_length(vq->iov, out);
 510                iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len);
 511                iov_iter_advance(&msg.msg_iter, hdr_size);
 512                /* Sanity check */
 513                if (!msg_data_left(&msg)) {
 514                        vq_err(vq, "Unexpected header len for TX: "
 515                               "%zd expected %zd\n",
 516                               len, hdr_size);
 517                        break;
 518                }
 519                len = msg_data_left(&msg);
 520
 521                zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
 522                                   && (nvq->upend_idx + 1) % UIO_MAXIOV !=
 523                                      nvq->done_idx
 524                                   && vhost_net_tx_select_zcopy(net);
 525
 526                /* use msg_control to pass vhost zerocopy ubuf info to skb */
 527                if (zcopy_used) {
 528                        struct ubuf_info *ubuf;
 529                        ubuf = nvq->ubuf_info + nvq->upend_idx;
 530
 531                        vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
 532                        vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
 533                        ubuf->callback = vhost_zerocopy_callback;
 534                        ubuf->ctx = nvq->ubufs;
 535                        ubuf->desc = nvq->upend_idx;
 536                        msg.msg_control = ubuf;
 537                        msg.msg_controllen = sizeof(ubuf);
 538                        ubufs = nvq->ubufs;
 539                        atomic_inc(&ubufs->refcount);
 540                        nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
 541                } else {
 542                        msg.msg_control = NULL;
 543                        ubufs = NULL;
 544                }
 545
 546                total_len += len;
 547                if (total_len < VHOST_NET_WEIGHT &&
 548                    !vhost_vq_avail_empty(&net->dev, vq) &&
 549                    likely(!vhost_exceeds_maxpend(net))) {
 550                        msg.msg_flags |= MSG_MORE;
 551                } else {
 552                        msg.msg_flags &= ~MSG_MORE;
 553                }
 554
 555                /* TODO: Check specific error and bomb out unless ENOBUFS? */
 556                err = sock->ops->sendmsg(sock, &msg, len);
 557                if (unlikely(err < 0)) {
 558                        if (zcopy_used) {
 559                                vhost_net_ubuf_put(ubufs);
 560                                nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
 561                                        % UIO_MAXIOV;
 562                        }
 563                        vhost_discard_vq_desc(vq, 1);
 564                        break;
 565                }
 566                if (err != len)
 567                        pr_debug("Truncated TX packet: "
 568                                 " len %d != %zd\n", err, len);
 569                if (!zcopy_used)
 570                        vhost_add_used_and_signal(&net->dev, vq, head, 0);
 571                else
 572                        vhost_zerocopy_signal_used(net, vq);
 573                vhost_net_tx_packet(net);
 574                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
 575                        vhost_poll_queue(&vq->poll);
 576                        break;
 577                }
 578        }
 579out:
 580        mutex_unlock(&vq->mutex);
 581}
 582
 583static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk)
 584{
 585        struct sk_buff *head;
 586        int len = 0;
 587        unsigned long flags;
 588
 589        if (rvq->rx_array)
 590                return vhost_net_buf_peek(rvq);
 591
 592        spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
 593        head = skb_peek(&sk->sk_receive_queue);
 594        if (likely(head)) {
 595                len = head->len;
 596                if (skb_vlan_tag_present(head))
 597                        len += VLAN_HLEN;
 598        }
 599
 600        spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
 601        return len;
 602}
 603
 604static int sk_has_rx_data(struct sock *sk)
 605{
 606        struct socket *sock = sk->sk_socket;
 607
 608        if (sock->ops->peek_len)
 609                return sock->ops->peek_len(sock);
 610
 611        return skb_queue_empty(&sk->sk_receive_queue);
 612}
 613
 614static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
 615{
 616        struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
 617        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
 618        struct vhost_virtqueue *vq = &nvq->vq;
 619        unsigned long uninitialized_var(endtime);
 620        int len = peek_head_len(rvq, sk);
 621
 622        if (!len && vq->busyloop_timeout) {
 623                /* Both tx vq and rx socket were polled here */
 624                mutex_lock(&vq->mutex);
 625                vhost_disable_notify(&net->dev, vq);
 626
 627                preempt_disable();
 628                endtime = busy_clock() + vq->busyloop_timeout;
 629
 630                while (vhost_can_busy_poll(&net->dev, endtime) &&
 631                       !sk_has_rx_data(sk) &&
 632                       vhost_vq_avail_empty(&net->dev, vq))
 633                        cpu_relax();
 634
 635                preempt_enable();
 636
 637                if (vhost_enable_notify(&net->dev, vq))
 638                        vhost_poll_queue(&vq->poll);
 639                mutex_unlock(&vq->mutex);
 640
 641                len = peek_head_len(rvq, sk);
 642        }
 643
 644        return len;
 645}
 646
 647/* This is a multi-buffer version of vhost_get_desc, that works if
 648 *      vq has read descriptors only.
 649 * @vq          - the relevant virtqueue
 650 * @datalen     - data length we'll be reading
 651 * @iovcount    - returned count of io vectors we fill
 652 * @log         - vhost log
 653 * @log_num     - log offset
 654 * @quota       - headcount quota, 1 for big buffer
 655 *      returns number of buffer heads allocated, negative on error
 656 */
 657static int get_rx_bufs(struct vhost_virtqueue *vq,
 658                       struct vring_used_elem *heads,
 659                       int datalen,
 660                       unsigned *iovcount,
 661                       struct vhost_log *log,
 662                       unsigned *log_num,
 663                       unsigned int quota)
 664{
 665        unsigned int out, in;
 666        int seg = 0;
 667        int headcount = 0;
 668        unsigned d;
 669        int r, nlogs = 0;
 670        /* len is always initialized before use since we are always called with
 671         * datalen > 0.
 672         */
 673        u32 uninitialized_var(len);
 674
 675        while (datalen > 0 && headcount < quota) {
 676                if (unlikely(seg >= UIO_MAXIOV)) {
 677                        r = -ENOBUFS;
 678                        goto err;
 679                }
 680                r = vhost_get_vq_desc(vq, vq->iov + seg,
 681                                      ARRAY_SIZE(vq->iov) - seg, &out,
 682                                      &in, log, log_num);
 683                if (unlikely(r < 0))
 684                        goto err;
 685
 686                d = r;
 687                if (d == vq->num) {
 688                        r = 0;
 689                        goto err;
 690                }
 691                if (unlikely(out || in <= 0)) {
 692                        vq_err(vq, "unexpected descriptor format for RX: "
 693                                "out %d, in %d\n", out, in);
 694                        r = -EINVAL;
 695                        goto err;
 696                }
 697                if (unlikely(log)) {
 698                        nlogs += *log_num;
 699                        log += *log_num;
 700                }
 701                heads[headcount].id = cpu_to_vhost32(vq, d);
 702                len = iov_length(vq->iov + seg, in);
 703                heads[headcount].len = cpu_to_vhost32(vq, len);
 704                datalen -= len;
 705                ++headcount;
 706                seg += in;
 707        }
 708        heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
 709        *iovcount = seg;
 710        if (unlikely(log))
 711                *log_num = nlogs;
 712
 713        /* Detect overrun */
 714        if (unlikely(datalen > 0)) {
 715                r = UIO_MAXIOV + 1;
 716                goto err;
 717        }
 718        return headcount;
 719err:
 720        vhost_discard_vq_desc(vq, headcount);
 721        return r;
 722}
 723
 724/* Expects to be always run from workqueue - which acts as
 725 * read-size critical section for our kind of RCU. */
 726static void handle_rx(struct vhost_net *net)
 727{
 728        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
 729        struct vhost_virtqueue *vq = &nvq->vq;
 730        unsigned uninitialized_var(in), log;
 731        struct vhost_log *vq_log;
 732        struct msghdr msg = {
 733                .msg_name = NULL,
 734                .msg_namelen = 0,
 735                .msg_control = NULL, /* FIXME: get and handle RX aux data. */
 736                .msg_controllen = 0,
 737                .msg_flags = MSG_DONTWAIT,
 738        };
 739        struct virtio_net_hdr hdr = {
 740                .flags = 0,
 741                .gso_type = VIRTIO_NET_HDR_GSO_NONE
 742        };
 743        size_t total_len = 0;
 744        int err, mergeable;
 745        s16 headcount;
 746        size_t vhost_hlen, sock_hlen;
 747        size_t vhost_len, sock_len;
 748        struct socket *sock;
 749        struct iov_iter fixup;
 750        __virtio16 num_buffers;
 751
 752        mutex_lock(&vq->mutex);
 753        sock = vq->private_data;
 754        if (!sock)
 755                goto out;
 756
 757        if (!vq_iotlb_prefetch(vq))
 758                goto out;
 759
 760        vhost_disable_notify(&net->dev, vq);
 761        vhost_net_disable_vq(net, vq);
 762
 763        vhost_hlen = nvq->vhost_hlen;
 764        sock_hlen = nvq->sock_hlen;
 765
 766        vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
 767                vq->log : NULL;
 768        mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
 769
 770        while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
 771                sock_len += sock_hlen;
 772                vhost_len = sock_len + vhost_hlen;
 773                headcount = get_rx_bufs(vq, vq->heads, vhost_len,
 774                                        &in, vq_log, &log,
 775                                        likely(mergeable) ? UIO_MAXIOV : 1);
 776                /* On error, stop handling until the next kick. */
 777                if (unlikely(headcount < 0))
 778                        goto out;
 779                if (nvq->rx_array)
 780                        msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
 781                /* On overrun, truncate and discard */
 782                if (unlikely(headcount > UIO_MAXIOV)) {
 783                        iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
 784                        err = sock->ops->recvmsg(sock, &msg,
 785                                                 1, MSG_DONTWAIT | MSG_TRUNC);
 786                        pr_debug("Discarded rx packet: len %zd\n", sock_len);
 787                        continue;
 788                }
 789                /* OK, now we need to know about added descriptors. */
 790                if (!headcount) {
 791                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
 792                                /* They have slipped one in as we were
 793                                 * doing that: check again. */
 794                                vhost_disable_notify(&net->dev, vq);
 795                                continue;
 796                        }
 797                        /* Nothing new?  Wait for eventfd to tell us
 798                         * they refilled. */
 799                        goto out;
 800                }
 801                /* We don't need to be notified again. */
 802                iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
 803                fixup = msg.msg_iter;
 804                if (unlikely((vhost_hlen))) {
 805                        /* We will supply the header ourselves
 806                         * TODO: support TSO.
 807                         */
 808                        iov_iter_advance(&msg.msg_iter, vhost_hlen);
 809                }
 810                err = sock->ops->recvmsg(sock, &msg,
 811                                         sock_len, MSG_DONTWAIT | MSG_TRUNC);
 812                /* Userspace might have consumed the packet meanwhile:
 813                 * it's not supposed to do this usually, but might be hard
 814                 * to prevent. Discard data we got (if any) and keep going. */
 815                if (unlikely(err != sock_len)) {
 816                        pr_debug("Discarded rx packet: "
 817                                 " len %d, expected %zd\n", err, sock_len);
 818                        vhost_discard_vq_desc(vq, headcount);
 819                        continue;
 820                }
 821                /* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
 822                if (unlikely(vhost_hlen)) {
 823                        if (copy_to_iter(&hdr, sizeof(hdr),
 824                                         &fixup) != sizeof(hdr)) {
 825                                vq_err(vq, "Unable to write vnet_hdr "
 826                                       "at addr %p\n", vq->iov->iov_base);
 827                                goto out;
 828                        }
 829                } else {
 830                        /* Header came from socket; we'll need to patch
 831                         * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
 832                         */
 833                        iov_iter_advance(&fixup, sizeof(hdr));
 834                }
 835                /* TODO: Should check and handle checksum. */
 836
 837                num_buffers = cpu_to_vhost16(vq, headcount);
 838                if (likely(mergeable) &&
 839                    copy_to_iter(&num_buffers, sizeof num_buffers,
 840                                 &fixup) != sizeof num_buffers) {
 841                        vq_err(vq, "Failed num_buffers write");
 842                        vhost_discard_vq_desc(vq, headcount);
 843                        goto out;
 844                }
 845                vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
 846                                            headcount);
 847                if (unlikely(vq_log))
 848                        vhost_log_write(vq, vq_log, log, vhost_len);
 849                total_len += vhost_len;
 850                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
 851                        vhost_poll_queue(&vq->poll);
 852                        goto out;
 853                }
 854        }
 855        vhost_net_enable_vq(net, vq);
 856out:
 857        mutex_unlock(&vq->mutex);
 858}
 859
 860static void handle_tx_kick(struct vhost_work *work)
 861{
 862        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
 863                                                  poll.work);
 864        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
 865
 866        handle_tx(net);
 867}
 868
 869static void handle_rx_kick(struct vhost_work *work)
 870{
 871        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
 872                                                  poll.work);
 873        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);
 874
 875        handle_rx(net);
 876}
 877
 878static void handle_tx_net(struct vhost_work *work)
 879{
 880        struct vhost_net *net = container_of(work, struct vhost_net,
 881                                             poll[VHOST_NET_VQ_TX].work);
 882        handle_tx(net);
 883}
 884
 885static void handle_rx_net(struct vhost_work *work)
 886{
 887        struct vhost_net *net = container_of(work, struct vhost_net,
 888                                             poll[VHOST_NET_VQ_RX].work);
 889        handle_rx(net);
 890}
 891
 892static int vhost_net_open(struct inode *inode, struct file *f)
 893{
 894        struct vhost_net *n;
 895        struct vhost_dev *dev;
 896        struct vhost_virtqueue **vqs;
 897        struct sk_buff **queue;
 898        int i;
 899
 900        n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
 901        if (!n)
 902                return -ENOMEM;
 903        vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
 904        if (!vqs) {
 905                kvfree(n);
 906                return -ENOMEM;
 907        }
 908
 909        queue = kmalloc_array(VHOST_RX_BATCH, sizeof(struct sk_buff *),
 910                              GFP_KERNEL);
 911        if (!queue) {
 912                kfree(vqs);
 913                kvfree(n);
 914                return -ENOMEM;
 915        }
 916        n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
 917
 918        dev = &n->dev;
 919        vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
 920        vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
 921        n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
 922        n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
 923        for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
 924                n->vqs[i].ubufs = NULL;
 925                n->vqs[i].ubuf_info = NULL;
 926                n->vqs[i].upend_idx = 0;
 927                n->vqs[i].done_idx = 0;
 928                n->vqs[i].vhost_hlen = 0;
 929                n->vqs[i].sock_hlen = 0;
 930                vhost_net_buf_init(&n->vqs[i].rxq);
 931        }
 932        vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
 933
 934        vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
 935        vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
 936
 937        f->private_data = n;
 938
 939        return 0;
 940}
 941
 942static struct socket *vhost_net_stop_vq(struct vhost_net *n,
 943                                        struct vhost_virtqueue *vq)
 944{
 945        struct socket *sock;
 946        struct vhost_net_virtqueue *nvq =
 947                container_of(vq, struct vhost_net_virtqueue, vq);
 948
 949        mutex_lock(&vq->mutex);
 950        sock = vq->private_data;
 951        vhost_net_disable_vq(n, vq);
 952        vq->private_data = NULL;
 953        vhost_net_buf_unproduce(nvq);
 954        mutex_unlock(&vq->mutex);
 955        return sock;
 956}
 957
 958static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
 959                           struct socket **rx_sock)
 960{
 961        *tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
 962        *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
 963}
 964
 965static void vhost_net_flush_vq(struct vhost_net *n, int index)
 966{
 967        vhost_poll_flush(n->poll + index);
 968        vhost_poll_flush(&n->vqs[index].vq.poll);
 969}
 970
 971static void vhost_net_flush(struct vhost_net *n)
 972{
 973        vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
 974        vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
 975        if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
 976                mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 977                n->tx_flush = true;
 978                mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 979                /* Wait for all lower device DMAs done. */
 980                vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
 981                mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 982                n->tx_flush = false;
 983                atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
 984                mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 985        }
 986}
 987
 988static int vhost_net_release(struct inode *inode, struct file *f)
 989{
 990        struct vhost_net *n = f->private_data;
 991        struct socket *tx_sock;
 992        struct socket *rx_sock;
 993
 994        vhost_net_stop(n, &tx_sock, &rx_sock);
 995        vhost_net_flush(n);
 996        vhost_dev_stop(&n->dev);
 997        vhost_dev_cleanup(&n->dev, false);
 998        vhost_net_vq_reset(n);
 999        if (tx_sock)
1000                sockfd_put(tx_sock);
1001        if (rx_sock)
1002                sockfd_put(rx_sock);
1003        /* Make sure no callbacks are outstanding */
1004        synchronize_rcu_bh();
1005        /* We do an extra flush before freeing memory,
1006         * since jobs can re-queue themselves. */
1007        vhost_net_flush(n);
1008        kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
1009        kfree(n->dev.vqs);
1010        kvfree(n);
1011        return 0;
1012}
1013
1014static struct socket *get_raw_socket(int fd)
1015{
1016        struct {
1017                struct sockaddr_ll sa;
1018                char  buf[MAX_ADDR_LEN];
1019        } uaddr;
1020        int uaddr_len = sizeof uaddr, r;
1021        struct socket *sock = sockfd_lookup(fd, &r);
1022
1023        if (!sock)
1024                return ERR_PTR(-ENOTSOCK);
1025
1026        /* Parameter checking */
1027        if (sock->sk->sk_type != SOCK_RAW) {
1028                r = -ESOCKTNOSUPPORT;
1029                goto err;
1030        }
1031
1032        r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
1033                               &uaddr_len, 0);
1034        if (r)
1035                goto err;
1036
1037        if (uaddr.sa.sll_family != AF_PACKET) {
1038                r = -EPFNOSUPPORT;
1039                goto err;
1040        }
1041        return sock;
1042err:
1043        sockfd_put(sock);
1044        return ERR_PTR(r);
1045}
1046
1047static struct skb_array *get_tap_skb_array(int fd)
1048{
1049        struct skb_array *array;
1050        struct file *file = fget(fd);
1051
1052        if (!file)
1053                return NULL;
1054        array = tun_get_skb_array(file);
1055        if (!IS_ERR(array))
1056                goto out;
1057        array = tap_get_skb_array(file);
1058        if (!IS_ERR(array))
1059                goto out;
1060        array = NULL;
1061out:
1062        fput(file);
1063        return array;
1064}
1065
1066static struct socket *get_tap_socket(int fd)
1067{
1068        struct file *file = fget(fd);
1069        struct socket *sock;
1070
1071        if (!file)
1072                return ERR_PTR(-EBADF);
1073        sock = tun_get_socket(file);
1074        if (!IS_ERR(sock))
1075                return sock;
1076        sock = tap_get_socket(file);
1077        if (IS_ERR(sock))
1078                fput(file);
1079        return sock;
1080}
1081
1082static struct socket *get_socket(int fd)
1083{
1084        struct socket *sock;
1085
1086        /* special case to disable backend */
1087        if (fd == -1)
1088                return NULL;
1089        sock = get_raw_socket(fd);
1090        if (!IS_ERR(sock))
1091                return sock;
1092        sock = get_tap_socket(fd);
1093        if (!IS_ERR(sock))
1094                return sock;
1095        return ERR_PTR(-ENOTSOCK);
1096}
1097
1098static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
1099{
1100        struct socket *sock, *oldsock;
1101        struct vhost_virtqueue *vq;
1102        struct vhost_net_virtqueue *nvq;
1103        struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
1104        int r;
1105
1106        mutex_lock(&n->dev.mutex);
1107        r = vhost_dev_check_owner(&n->dev);
1108        if (r)
1109                goto err;
1110
1111        if (index >= VHOST_NET_VQ_MAX) {
1112                r = -ENOBUFS;
1113                goto err;
1114        }
1115        vq = &n->vqs[index].vq;
1116        nvq = &n->vqs[index];
1117        mutex_lock(&vq->mutex);
1118
1119        /* Verify that ring has been setup correctly. */
1120        if (!vhost_vq_access_ok(vq)) {
1121                r = -EFAULT;
1122                goto err_vq;
1123        }
1124        sock = get_socket(fd);
1125        if (IS_ERR(sock)) {
1126                r = PTR_ERR(sock);
1127                goto err_vq;
1128        }
1129
1130        /* start polling new socket */
1131        oldsock = vq->private_data;
1132        if (sock != oldsock) {
1133                ubufs = vhost_net_ubuf_alloc(vq,
1134                                             sock && vhost_sock_zcopy(sock));
1135                if (IS_ERR(ubufs)) {
1136                        r = PTR_ERR(ubufs);
1137                        goto err_ubufs;
1138                }
1139
1140                vhost_net_disable_vq(n, vq);
1141                vq->private_data = sock;
1142                vhost_net_buf_unproduce(nvq);
1143                if (index == VHOST_NET_VQ_RX)
1144                        nvq->rx_array = get_tap_skb_array(fd);
1145                r = vhost_vq_init_access(vq);
1146                if (r)
1147                        goto err_used;
1148                r = vhost_net_enable_vq(n, vq);
1149                if (r)
1150                        goto err_used;
1151
1152                oldubufs = nvq->ubufs;
1153                nvq->ubufs = ubufs;
1154
1155                n->tx_packets = 0;
1156                n->tx_zcopy_err = 0;
1157                n->tx_flush = false;
1158        }
1159
1160        mutex_unlock(&vq->mutex);
1161
1162        if (oldubufs) {
1163                vhost_net_ubuf_put_wait_and_free(oldubufs);
1164                mutex_lock(&vq->mutex);
1165                vhost_zerocopy_signal_used(n, vq);
1166                mutex_unlock(&vq->mutex);
1167        }
1168
1169        if (oldsock) {
1170                vhost_net_flush_vq(n, index);
1171                sockfd_put(oldsock);
1172        }
1173
1174        mutex_unlock(&n->dev.mutex);
1175        return 0;
1176
1177err_used:
1178        vq->private_data = oldsock;
1179        vhost_net_enable_vq(n, vq);
1180        if (ubufs)
1181                vhost_net_ubuf_put_wait_and_free(ubufs);
1182err_ubufs:
1183        sockfd_put(sock);
1184err_vq:
1185        mutex_unlock(&vq->mutex);
1186err:
1187        mutex_unlock(&n->dev.mutex);
1188        return r;
1189}
1190
1191static long vhost_net_reset_owner(struct vhost_net *n)
1192{
1193        struct socket *tx_sock = NULL;
1194        struct socket *rx_sock = NULL;
1195        long err;
1196        struct vhost_umem *umem;
1197
1198        mutex_lock(&n->dev.mutex);
1199        err = vhost_dev_check_owner(&n->dev);
1200        if (err)
1201                goto done;
1202        umem = vhost_dev_reset_owner_prepare();
1203        if (!umem) {
1204                err = -ENOMEM;
1205                goto done;
1206        }
1207        vhost_net_stop(n, &tx_sock, &rx_sock);
1208        vhost_net_flush(n);
1209        vhost_dev_reset_owner(&n->dev, umem);
1210        vhost_net_vq_reset(n);
1211done:
1212        mutex_unlock(&n->dev.mutex);
1213        if (tx_sock)
1214                sockfd_put(tx_sock);
1215        if (rx_sock)
1216                sockfd_put(rx_sock);
1217        return err;
1218}
1219
1220static int vhost_net_set_features(struct vhost_net *n, u64 features)
1221{
1222        size_t vhost_hlen, sock_hlen, hdr_len;
1223        int i;
1224
1225        hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
1226                               (1ULL << VIRTIO_F_VERSION_1))) ?
1227                        sizeof(struct virtio_net_hdr_mrg_rxbuf) :
1228                        sizeof(struct virtio_net_hdr);
1229        if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
1230                /* vhost provides vnet_hdr */
1231                vhost_hlen = hdr_len;
1232                sock_hlen = 0;
1233        } else {
1234                /* socket provides vnet_hdr */
1235                vhost_hlen = 0;
1236                sock_hlen = hdr_len;
1237        }
1238        mutex_lock(&n->dev.mutex);
1239        if ((features & (1 << VHOST_F_LOG_ALL)) &&
1240            !vhost_log_access_ok(&n->dev))
1241                goto out_unlock;
1242
1243        if ((features & (1ULL << VIRTIO_F_IOMMU_PLATFORM))) {
1244                if (vhost_init_device_iotlb(&n->dev, true))
1245                        goto out_unlock;
1246        }
1247
1248        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
1249                mutex_lock(&n->vqs[i].vq.mutex);
1250                n->vqs[i].vq.acked_features = features;
1251                n->vqs[i].vhost_hlen = vhost_hlen;
1252                n->vqs[i].sock_hlen = sock_hlen;
1253                mutex_unlock(&n->vqs[i].vq.mutex);
1254        }
1255        mutex_unlock(&n->dev.mutex);
1256        return 0;
1257
1258out_unlock:
1259        mutex_unlock(&n->dev.mutex);
1260        return -EFAULT;
1261}
1262
1263static long vhost_net_set_owner(struct vhost_net *n)
1264{
1265        int r;
1266
1267        mutex_lock(&n->dev.mutex);
1268        if (vhost_dev_has_owner(&n->dev)) {
1269                r = -EBUSY;
1270                goto out;
1271        }
1272        r = vhost_net_set_ubuf_info(n);
1273        if (r)
1274                goto out;
1275        r = vhost_dev_set_owner(&n->dev);
1276        if (r)
1277                vhost_net_clear_ubuf_info(n);
1278        vhost_net_flush(n);
1279out:
1280        mutex_unlock(&n->dev.mutex);
1281        return r;
1282}
1283
1284static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
1285                            unsigned long arg)
1286{
1287        struct vhost_net *n = f->private_data;
1288        void __user *argp = (void __user *)arg;
1289        u64 __user *featurep = argp;
1290        struct vhost_vring_file backend;
1291        u64 features;
1292        int r;
1293
1294        switch (ioctl) {
1295        case VHOST_NET_SET_BACKEND:
1296                if (copy_from_user(&backend, argp, sizeof backend))
1297                        return -EFAULT;
1298                return vhost_net_set_backend(n, backend.index, backend.fd);
1299        case VHOST_GET_FEATURES:
1300                features = VHOST_NET_FEATURES;
1301                if (copy_to_user(featurep, &features, sizeof features))
1302                        return -EFAULT;
1303                return 0;
1304        case VHOST_SET_FEATURES:
1305                if (copy_from_user(&features, featurep, sizeof features))
1306                        return -EFAULT;
1307                if (features & ~VHOST_NET_FEATURES)
1308                        return -EOPNOTSUPP;
1309                return vhost_net_set_features(n, features);
1310        case VHOST_RESET_OWNER:
1311                return vhost_net_reset_owner(n);
1312        case VHOST_SET_OWNER:
1313                return vhost_net_set_owner(n);
1314        default:
1315                mutex_lock(&n->dev.mutex);
1316                r = vhost_dev_ioctl(&n->dev, ioctl, argp);
1317                if (r == -ENOIOCTLCMD)
1318                        r = vhost_vring_ioctl(&n->dev, ioctl, argp);
1319                else
1320                        vhost_net_flush(n);
1321                mutex_unlock(&n->dev.mutex);
1322                return r;
1323        }
1324}
1325
#ifdef CONFIG_COMPAT
/*
 * 32-bit compat entry point: convert the compat user pointer and hand off
 * to the native ioctl handler.
 */
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
				   unsigned long arg)
{
	void __user *uptr = compat_ptr(arg);

	return vhost_net_ioctl(f, ioctl, (unsigned long)uptr);
}
#endif
1333
1334static ssize_t vhost_net_chr_read_iter(struct kiocb *iocb, struct iov_iter *to)
1335{
1336        struct file *file = iocb->ki_filp;
1337        struct vhost_net *n = file->private_data;
1338        struct vhost_dev *dev = &n->dev;
1339        int noblock = file->f_flags & O_NONBLOCK;
1340
1341        return vhost_chr_read_iter(dev, to, noblock);
1342}
1343
1344static ssize_t vhost_net_chr_write_iter(struct kiocb *iocb,
1345                                        struct iov_iter *from)
1346{
1347        struct file *file = iocb->ki_filp;
1348        struct vhost_net *n = file->private_data;
1349        struct vhost_dev *dev = &n->dev;
1350
1351        return vhost_chr_write_iter(dev, from);
1352}
1353
1354static unsigned int vhost_net_chr_poll(struct file *file, poll_table *wait)
1355{
1356        struct vhost_net *n = file->private_data;
1357        struct vhost_dev *dev = &n->dev;
1358
1359        return vhost_chr_poll(file, dev, wait);
1360}
1361
1362static const struct file_operations vhost_net_fops = {
1363        .owner          = THIS_MODULE,
1364        .release        = vhost_net_release,
1365        .read_iter      = vhost_net_chr_read_iter,
1366        .write_iter     = vhost_net_chr_write_iter,
1367        .poll           = vhost_net_chr_poll,
1368        .unlocked_ioctl = vhost_net_ioctl,
1369#ifdef CONFIG_COMPAT
1370        .compat_ioctl   = vhost_net_compat_ioctl,
1371#endif
1372        .open           = vhost_net_open,
1373        .llseek         = noop_llseek,
1374};
1375
1376static struct miscdevice vhost_net_misc = {
1377        .minor = VHOST_NET_MINOR,
1378        .name = "vhost-net",
1379        .fops = &vhost_net_fops,
1380};
1381
1382static int vhost_net_init(void)
1383{
1384        if (experimental_zcopytx)
1385                vhost_net_enable_zcopy(VHOST_NET_VQ_TX);
1386        return misc_register(&vhost_net_misc);
1387}
1388module_init(vhost_net_init);
1389
1390static void vhost_net_exit(void)
1391{
1392        misc_deregister(&vhost_net_misc);
1393}
1394module_exit(vhost_net_exit);
1395
1396MODULE_VERSION("0.0.1");
1397MODULE_LICENSE("GPL v2");
1398MODULE_AUTHOR("Michael S. Tsirkin");
1399MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
1400MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
1401MODULE_ALIAS("devname:vhost-net");
1402