linux/drivers/vhost/net.c
/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/file.h>
#include <linux/slab.h>
#include <linux/vmalloc.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>
#include <linux/if_vlan.h>

#include <net/sock.h>

#include "vhost.h"

static int experimental_zcopytx = 1;
module_param(experimental_zcopytx, int, 0444);
MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
                                       " 1 - Enable; 0 - Disable");

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

/* MAX number of TX used buffers for outstanding zerocopy */
#define VHOST_MAX_PEND 128
#define VHOST_GOODCOPY_LEN 256

/*
 * For transmit, used buffer len is unused; we override it to track buffer
 * status internally; used for zerocopy tx only.
 */
/* Lower device DMA failed */
#define VHOST_DMA_FAILED_LEN    ((__force __virtio32)3)
/* Lower device DMA done */
#define VHOST_DMA_DONE_LEN      ((__force __virtio32)2)
/* Lower device DMA in progress */
#define VHOST_DMA_IN_PROGRESS   ((__force __virtio32)1)
/* Buffer unused */
#define VHOST_DMA_CLEAR_LEN     ((__force __virtio32)0)

#define VHOST_DMA_IS_DONE(len) ((__force u32)(len) >= (__force u32)VHOST_DMA_DONE_LEN)
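/* Note that both VHOST_DMA_DONE_LEN and VHOST_DMA_FAILED_LEN satisfy
 * VHOST_DMA_IS_DONE(): a failed DMA still releases the buffer back to the
 * guest, it merely bumps the zerocopy error counter in addition.
 */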

enum {
        VHOST_NET_FEATURES = VHOST_FEATURES |
                         (1ULL << VHOST_NET_F_VIRTIO_NET_HDR) |
                         (1ULL << VIRTIO_NET_F_MRG_RXBUF)
};

enum {
        VHOST_NET_VQ_RX = 0,
        VHOST_NET_VQ_TX = 1,
        VHOST_NET_VQ_MAX = 2,
};

struct vhost_net_ubuf_ref {
        /* refcount follows semantics similar to kref:
         *  0: object is released
         *  1: no outstanding ubufs
         * >1: outstanding ubufs
         */
        atomic_t refcount;
        wait_queue_head_t wait;
        struct vhost_virtqueue *vq;
};

struct vhost_net_virtqueue {
        struct vhost_virtqueue vq;
        size_t vhost_hlen;
        size_t sock_hlen;
        /* vhost zerocopy support fields below: */
        /* last used idx for outstanding DMA zerocopy buffers */
        int upend_idx;
        /* first used idx for DMA done zerocopy buffers */
        int done_idx;
        /* an array of userspace buffers info */
        struct ubuf_info *ubuf_info;
        /* Reference counting for outstanding ubufs.
         * Protected by vq mutex. Writers must also take device mutex. */
        struct vhost_net_ubuf_ref *ubufs;
};

struct vhost_net {
        struct vhost_dev dev;
        struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
        struct vhost_poll poll[VHOST_NET_VQ_MAX];
        /* Number of TX recently submitted.
         * Protected by tx vq lock. */
        unsigned tx_packets;
        /* Number of times zerocopy TX recently failed.
         * Protected by tx vq lock. */
        unsigned tx_zcopy_err;
        /* Flush in progress. Protected by tx vq lock. */
        bool tx_flush;
};

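/* Bitmask of virtqueues for which zerocopy TX is enabled.  Set once from
 * vhost_net_init() based on the experimental_zcopytx module parameter,
 * before the misc device is registered, so no locking is needed.
 */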
static unsigned vhost_net_zcopy_mask __read_mostly;

static void vhost_net_enable_zcopy(int vq)
{
        vhost_net_zcopy_mask |= 0x1 << vq;
}

static struct vhost_net_ubuf_ref *
vhost_net_ubuf_alloc(struct vhost_virtqueue *vq, bool zcopy)
{
        struct vhost_net_ubuf_ref *ubufs;
        /* No zero copy backend? Nothing to count. */
        if (!zcopy)
                return NULL;
        ubufs = kmalloc(sizeof(*ubufs), GFP_KERNEL);
        if (!ubufs)
                return ERR_PTR(-ENOMEM);
        atomic_set(&ubufs->refcount, 1);
        init_waitqueue_head(&ubufs->wait);
        ubufs->vq = vq;
        return ubufs;
}

static int vhost_net_ubuf_put(struct vhost_net_ubuf_ref *ubufs)
{
        int r = atomic_sub_return(1, &ubufs->refcount);
        if (unlikely(!r))
                wake_up(&ubufs->wait);
        return r;
}

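/* Drop the base reference (set to 1 in vhost_net_ubuf_alloc()) and wait for
 * all outstanding zerocopy callbacks to drop theirs, i.e. for the refcount
 * to reach zero.
 */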
static void vhost_net_ubuf_put_and_wait(struct vhost_net_ubuf_ref *ubufs)
{
        vhost_net_ubuf_put(ubufs);
        wait_event(ubufs->wait, !atomic_read(&ubufs->refcount));
}

static void vhost_net_ubuf_put_wait_and_free(struct vhost_net_ubuf_ref *ubufs)
{
        vhost_net_ubuf_put_and_wait(ubufs);
        kfree(ubufs);
}

static void vhost_net_clear_ubuf_info(struct vhost_net *n)
{
        int i;

        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
                kfree(n->vqs[i].ubuf_info);
                n->vqs[i].ubuf_info = NULL;
        }
}

static int vhost_net_set_ubuf_info(struct vhost_net *n)
{
        bool zcopy;
        int i;

        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
                zcopy = vhost_net_zcopy_mask & (0x1 << i);
                if (!zcopy)
                        continue;
                n->vqs[i].ubuf_info = kmalloc(sizeof(*n->vqs[i].ubuf_info) *
                                              UIO_MAXIOV, GFP_KERNEL);
                if (!n->vqs[i].ubuf_info)
                        goto err;
        }
        return 0;

err:
        vhost_net_clear_ubuf_info(n);
        return -ENOMEM;
}

static void vhost_net_vq_reset(struct vhost_net *n)
{
        int i;

        vhost_net_clear_ubuf_info(n);

        for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
                n->vqs[i].done_idx = 0;
                n->vqs[i].upend_idx = 0;
                n->vqs[i].ubufs = NULL;
                n->vqs[i].vhost_hlen = 0;
                n->vqs[i].sock_hlen = 0;
        }

}

static void vhost_net_tx_packet(struct vhost_net *net)
{
        ++net->tx_packets;
        if (net->tx_packets < 1024)
                return;
        net->tx_packets = 0;
        net->tx_zcopy_err = 0;
}

static void vhost_net_tx_err(struct vhost_net *net)
{
        ++net->tx_zcopy_err;
}

static bool vhost_net_tx_select_zcopy(struct vhost_net *net)
{
        /* TX flush waits for outstanding DMAs to be done.
         * Don't start new DMAs.
         */
        return !net->tx_flush &&
                net->tx_packets / 64 >= net->tx_zcopy_err;
}

static bool vhost_sock_zcopy(struct socket *sock)
{
        return unlikely(experimental_zcopytx) &&
                sock_flag(sock->sk, SOCK_ZEROCOPY);
}

/* DMA completions from the lower device driver may arrive out of order.
 * upend_idx tracks the tail of the outstanding range, done_idx tracks its
 * head. Once the lower device has completed DMA for a contiguous run of
 * buffers starting at done_idx, we signal the used idx to the guest.
 */
static void vhost_zerocopy_signal_used(struct vhost_net *net,
                                       struct vhost_virtqueue *vq)
{
        struct vhost_net_virtqueue *nvq =
                container_of(vq, struct vhost_net_virtqueue, vq);
        int i, add;
        int j = 0;

        for (i = nvq->done_idx; i != nvq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
                if (vq->heads[i].len == VHOST_DMA_FAILED_LEN)
                        vhost_net_tx_err(net);
                if (VHOST_DMA_IS_DONE(vq->heads[i].len)) {
                        vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
                        ++j;
                } else
                        break;
        }
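        /* vq->heads[] is indexed modulo UIO_MAXIOV, so the j completed
         * entries starting at done_idx may wrap around the end of the
         * array; hand them to vhost_add_used_and_signal_n() in at most
         * two contiguous chunks.
         */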
        while (j) {
                add = min(UIO_MAXIOV - nvq->done_idx, j);
                vhost_add_used_and_signal_n(vq->dev, vq,
                                            &vq->heads[nvq->done_idx], add);
                nvq->done_idx = (nvq->done_idx + add) % UIO_MAXIOV;
                j -= add;
        }
}

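/* Zerocopy completion callback, invoked by the lower device (e.g. tun or
 * macvtap) once it is done with the guest buffers backing a transmitted
 * skb.  It may run in softirq context; the BH-flavoured RCU read lock
 * pairs with the synchronize_rcu_bh() in vhost_net_release().
 */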
static void vhost_zerocopy_callback(struct ubuf_info *ubuf, bool success)
{
        struct vhost_net_ubuf_ref *ubufs = ubuf->ctx;
        struct vhost_virtqueue *vq = ubufs->vq;
        int cnt;

        rcu_read_lock_bh();

        /* set len to mark this desc's buffers as done with DMA */
        vq->heads[ubuf->desc].len = success ?
                VHOST_DMA_DONE_LEN : VHOST_DMA_FAILED_LEN;
        cnt = vhost_net_ubuf_put(ubufs);

        /*
         * Trigger polling thread if guest stopped submitting new buffers:
         * in this case, the refcount after decrement will eventually reach 1.
         * We also trigger polling periodically after each 16 packets
         * (the value 16 here is more or less arbitrary, it's tuned to trigger
         * less than 10% of times).
         */
        if (cnt <= 1 || !(cnt % 16))
                vhost_poll_queue(&vq->poll);

        rcu_read_unlock_bh();
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
        struct vhost_virtqueue *vq = &nvq->vq;
        unsigned out, in;
        int head;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL,
                .msg_controllen = 0,
                .msg_flags = MSG_DONTWAIT,
        };
        size_t len, total_len = 0;
        int err;
        size_t hdr_size;
        struct socket *sock;
        struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
        bool zcopy, zcopy_used;

        mutex_lock(&vq->mutex);
        sock = vq->private_data;
        if (!sock)
                goto out;

        vhost_disable_notify(&net->dev, vq);

        hdr_size = nvq->vhost_hlen;
        zcopy = nvq->ubufs;

        for (;;) {
                /* Release DMAs done buffers first */
                if (zcopy)
                        vhost_zerocopy_signal_used(net, vq);

                /* If more outstanding DMAs, queue the work.
                 * Handle upend_idx wrap around
                 */
                if (unlikely((nvq->upend_idx + vq->num - VHOST_MAX_PEND)
                              % UIO_MAXIOV == nvq->done_idx))
                        break;

                head = vhost_get_vq_desc(vq, vq->iov,
                                         ARRAY_SIZE(vq->iov),
                                         &out, &in,
                                         NULL, NULL);
                /* On error, stop handling until the next kick. */
                if (unlikely(head < 0))
                        break;
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                vhost_disable_notify(&net->dev, vq);
                                continue;
                        }
                        break;
                }
                if (in) {
                        vq_err(vq, "Unexpected descriptor format for TX: "
                               "out %d, in %d\n", out, in);
                        break;
                }
                /* Skip header. TODO: support TSO. */
                len = iov_length(vq->iov, out);
                iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len);
                iov_iter_advance(&msg.msg_iter, hdr_size);
                /* Sanity check */
                if (!msg_data_left(&msg)) {
                        vq_err(vq, "Unexpected header len for TX: "
                               "%zd expected %zd\n",
                               len, hdr_size);
                        break;
                }
                len = msg_data_left(&msg);

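                /* Use zerocopy only when it is worth it: the payload is
                 * large enough to beat a copy, there is a free slot in the
                 * heads ring for the outstanding buffer, and the recent
                 * zerocopy error rate is low enough (see
                 * vhost_net_tx_select_zcopy()).
                 */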
                zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
                                   && (nvq->upend_idx + 1) % UIO_MAXIOV !=
                                      nvq->done_idx
                                   && vhost_net_tx_select_zcopy(net);

                /* use msg_control to pass vhost zerocopy ubuf info to skb */
                if (zcopy_used) {
                        struct ubuf_info *ubuf;
                        ubuf = nvq->ubuf_info + nvq->upend_idx;

                        vq->heads[nvq->upend_idx].id = cpu_to_vhost32(vq, head);
                        vq->heads[nvq->upend_idx].len = VHOST_DMA_IN_PROGRESS;
                        ubuf->callback = vhost_zerocopy_callback;
                        ubuf->ctx = nvq->ubufs;
                        ubuf->desc = nvq->upend_idx;
                        msg.msg_control = ubuf;
                        msg.msg_controllen = sizeof(ubuf);
                        ubufs = nvq->ubufs;
                        atomic_inc(&ubufs->refcount);
                        nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
                } else {
                        msg.msg_control = NULL;
                        ubufs = NULL;
                }
                /* TODO: Check specific error and bomb out unless ENOBUFS? */
                err = sock->ops->sendmsg(sock, &msg, len);
                if (unlikely(err < 0)) {
                        if (zcopy_used) {
                                vhost_net_ubuf_put(ubufs);
                                nvq->upend_idx = ((unsigned)nvq->upend_idx - 1)
                                        % UIO_MAXIOV;
                        }
                        vhost_discard_vq_desc(vq, 1);
                        break;
                }
                if (err != len)
                        pr_debug("Truncated TX packet: "
                                 " len %d != %zd\n", err, len);
                if (!zcopy_used)
                        vhost_add_used_and_signal(&net->dev, vq, head, 0);
                else
                        vhost_zerocopy_signal_used(net, vq);
                total_len += len;
                vhost_net_tx_packet(net);
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }
out:
        mutex_unlock(&vq->mutex);
}

static int peek_head_len(struct sock *sk)
{
        struct sk_buff *head;
        int len = 0;
        unsigned long flags;

        spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
        head = skb_peek(&sk->sk_receive_queue);
        if (likely(head)) {
                len = head->len;
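                /* An out-of-band vlan tag is not counted in head->len but
                 * will be part of the packet the guest receives, so account
                 * for it in the reported length.
                 */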
                if (skb_vlan_tag_present(head))
                        len += VLAN_HLEN;
        }

        spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
        return len;
}

/* This is a multi-buffer version of vhost_get_vq_desc, which works when the
 *      vq has read descriptors only.
 * @vq          - the relevant virtqueue
 * @datalen     - data length we'll be reading
 * @iovcount    - returned count of io vectors we fill
 * @log         - vhost log
 * @log_num     - log offset
 * @quota       - headcount quota, 1 for big buffer
 *      returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_virtqueue *vq,
                       struct vring_used_elem *heads,
                       int datalen,
                       unsigned *iovcount,
                       struct vhost_log *log,
                       unsigned *log_num,
                       unsigned int quota)
{
        unsigned int out, in;
        int seg = 0;
        int headcount = 0;
        unsigned d;
        int r, nlogs = 0;
        /* len is always initialized before use since we are always called with
         * datalen > 0.
         */
        u32 uninitialized_var(len);

        while (datalen > 0 && headcount < quota) {
                if (unlikely(seg >= UIO_MAXIOV)) {
                        r = -ENOBUFS;
                        goto err;
                }
                r = vhost_get_vq_desc(vq, vq->iov + seg,
                                      ARRAY_SIZE(vq->iov) - seg, &out,
                                      &in, log, log_num);
                if (unlikely(r < 0))
                        goto err;

                d = r;
                if (d == vq->num) {
                        r = 0;
                        goto err;
                }
                if (unlikely(out || in <= 0)) {
                        vq_err(vq, "unexpected descriptor format for RX: "
                                "out %d, in %d\n", out, in);
                        r = -EINVAL;
                        goto err;
                }
                if (unlikely(log)) {
                        nlogs += *log_num;
                        log += *log_num;
                }
                heads[headcount].id = cpu_to_vhost32(vq, d);
                len = iov_length(vq->iov + seg, in);
                heads[headcount].len = cpu_to_vhost32(vq, len);
                datalen -= len;
                ++headcount;
                seg += in;
        }
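        /* On success the loop overshoots: datalen is now <= 0 and the last
         * buffer is typically only partly used, so trim its recorded length
         * to the bytes actually consumed (len + datalen).  If datalen is
         * still positive we ran out of quota; the overrun is reported below.
         */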
        heads[headcount - 1].len = cpu_to_vhost32(vq, len + datalen);
        *iovcount = seg;
        if (unlikely(log))
                *log_num = nlogs;

        /* Detect overrun */
        if (unlikely(datalen > 0)) {
                r = UIO_MAXIOV + 1;
                goto err;
        }
        return headcount;
err:
        vhost_discard_vq_desc(vq, headcount);
        return r;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
{
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
        struct vhost_virtqueue *vq = &nvq->vq;
        unsigned uninitialized_var(in), log;
        struct vhost_log *vq_log;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL, /* FIXME: get and handle RX aux data. */
                .msg_controllen = 0,
                .msg_flags = MSG_DONTWAIT,
        };
        struct virtio_net_hdr hdr = {
                .flags = 0,
                .gso_type = VIRTIO_NET_HDR_GSO_NONE
        };
        size_t total_len = 0;
        int err, mergeable;
        s16 headcount;
        size_t vhost_hlen, sock_hlen;
        size_t vhost_len, sock_len;
        struct socket *sock;
        struct iov_iter fixup;
        __virtio16 num_buffers;

        mutex_lock(&vq->mutex);
        sock = vq->private_data;
        if (!sock)
                goto out;
        vhost_disable_notify(&net->dev, vq);

        vhost_hlen = nvq->vhost_hlen;
        sock_hlen = nvq->sock_hlen;

        vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
                vq->log : NULL;
        mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);

        while ((sock_len = peek_head_len(sock->sk))) {
                sock_len += sock_hlen;
                vhost_len = sock_len + vhost_hlen;
                headcount = get_rx_bufs(vq, vq->heads, vhost_len,
                                        &in, vq_log, &log,
                                        likely(mergeable) ? UIO_MAXIOV : 1);
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        break;
                /* On overrun, truncate and discard */
                if (unlikely(headcount > UIO_MAXIOV)) {
                        iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
                        err = sock->ops->recvmsg(sock, &msg,
                                                 1, MSG_DONTWAIT | MSG_TRUNC);
                        pr_debug("Discarded rx packet: len %zd\n", sock_len);
                        continue;
                }
                /* OK, now we need to know about added descriptors. */
                if (!headcount) {
                        if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                /* They have slipped one in as we were
                                 * doing that: check again. */
                                vhost_disable_notify(&net->dev, vq);
                                continue;
                        }
                        /* Nothing new?  Wait for eventfd to tell us
                         * they refilled. */
                        break;
                }
                /* We don't need to be notified again. */
                iov_iter_init(&msg.msg_iter, READ, vq->iov, in, vhost_len);
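                /* Remember an iterator positioned at the start of the vnet
                 * header area: recvmsg() below advances msg.msg_iter, and we
                 * still need to write the header and/or patch num_buffers
                 * back into the buffer afterwards.
                 */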
                fixup = msg.msg_iter;
                if (unlikely((vhost_hlen))) {
                        /* We will supply the header ourselves
                         * TODO: support TSO.
                         */
                        iov_iter_advance(&msg.msg_iter, vhost_hlen);
                }
                err = sock->ops->recvmsg(sock, &msg,
                                         sock_len, MSG_DONTWAIT | MSG_TRUNC);
                /* Userspace might have consumed the packet meanwhile:
                 * it's not supposed to do this usually, but might be hard
                 * to prevent. Discard data we got (if any) and keep going. */
                if (unlikely(err != sock_len)) {
                        pr_debug("Discarded rx packet: "
                                 " len %d, expected %zd\n", err, sock_len);
                        vhost_discard_vq_desc(vq, headcount);
                        continue;
                }
                /* Supply virtio_net_hdr if VHOST_NET_F_VIRTIO_NET_HDR */
                if (unlikely(vhost_hlen)) {
                        if (copy_to_iter(&hdr, sizeof(hdr),
                                         &fixup) != sizeof(hdr)) {
                                vq_err(vq, "Unable to write vnet_hdr "
                                       "at addr %p\n", vq->iov->iov_base);
                                break;
                        }
                } else {
                        /* Header came from socket; we'll need to patch
                         * ->num_buffers over if VIRTIO_NET_F_MRG_RXBUF
                         */
                        iov_iter_advance(&fixup, sizeof(hdr));
                }
                /* TODO: Should check and handle checksum. */

                num_buffers = cpu_to_vhost16(vq, headcount);
                if (likely(mergeable) &&
                    copy_to_iter(&num_buffers, sizeof num_buffers,
                                 &fixup) != sizeof num_buffers) {
                        vq_err(vq, "Failed num_buffers write");
                        vhost_discard_vq_desc(vq, headcount);
                        break;
                }
                vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
                                            headcount);
                if (unlikely(vq_log))
                        vhost_log_write(vq, vq_log, log, vhost_len);
                total_len += vhost_len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }
out:
        mutex_unlock(&vq->mutex);
}

static void handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

        handle_tx(net);
}

static void handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

        handle_rx(net);
}

static void handle_tx_net(struct vhost_work *work)
{
        struct vhost_net *net = container_of(work, struct vhost_net,
                                             poll[VHOST_NET_VQ_TX].work);
        handle_tx(net);
}

static void handle_rx_net(struct vhost_work *work)
{
        struct vhost_net *net = container_of(work, struct vhost_net,
                                             poll[VHOST_NET_VQ_RX].work);
        handle_rx(net);
}

static int vhost_net_open(struct inode *inode, struct file *f)
{
        struct vhost_net *n;
        struct vhost_dev *dev;
        struct vhost_virtqueue **vqs;
        int i;

        n = kmalloc(sizeof *n, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
        if (!n) {
                n = vmalloc(sizeof *n);
                if (!n)
                        return -ENOMEM;
        }
        vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                kvfree(n);
                return -ENOMEM;
        }

        dev = &n->dev;
        vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
        vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
        n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
        n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
        for (i = 0; i < VHOST_NET_VQ_MAX; i++) {
                n->vqs[i].ubufs = NULL;
                n->vqs[i].ubuf_info = NULL;
                n->vqs[i].upend_idx = 0;
                n->vqs[i].done_idx = 0;
                n->vqs[i].vhost_hlen = 0;
                n->vqs[i].sock_hlen = 0;
        }
        vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);

        vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
        vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);

        f->private_data = n;

        return 0;
}

static void vhost_net_disable_vq(struct vhost_net *n,
                                 struct vhost_virtqueue *vq)
{
        struct vhost_net_virtqueue *nvq =
                container_of(vq, struct vhost_net_virtqueue, vq);
        struct vhost_poll *poll = n->poll + (nvq - n->vqs);
        if (!vq->private_data)
                return;
        vhost_poll_stop(poll);
}

static int vhost_net_enable_vq(struct vhost_net *n,
                                struct vhost_virtqueue *vq)
{
        struct vhost_net_virtqueue *nvq =
                container_of(vq, struct vhost_net_virtqueue, vq);
        struct vhost_poll *poll = n->poll + (nvq - n->vqs);
        struct socket *sock;

        sock = vq->private_data;
        if (!sock)
                return 0;

        return vhost_poll_start(poll, sock->file);
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
                                        struct vhost_virtqueue *vq)
{
        struct socket *sock;

        mutex_lock(&vq->mutex);
        sock = vq->private_data;
        vhost_net_disable_vq(n, vq);
        vq->private_data = NULL;
        mutex_unlock(&vq->mutex);
        return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
                           struct socket **rx_sock)
{
        *tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
        *rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
        vhost_poll_flush(n->poll + index);
        vhost_poll_flush(&n->vqs[index].vq.poll);
}

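/* Flush both virtqueues and, if zerocopy is in use, wait for all outstanding
 * zerocopy DMAs to complete.  tx_flush stops handle_tx() from starting new
 * zerocopy transmissions while we wait; the refcount is re-armed to 1 (the
 * base reference) once the wait is over.
 */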
static void vhost_net_flush(struct vhost_net *n)
{
        vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
        vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
        if (n->vqs[VHOST_NET_VQ_TX].ubufs) {
                mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                n->tx_flush = true;
                mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                /* Wait for all lower device DMAs done. */
                vhost_net_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].ubufs);
                mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
                n->tx_flush = false;
                atomic_set(&n->vqs[VHOST_NET_VQ_TX].ubufs->refcount, 1);
                mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
        }
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
        struct vhost_net *n = f->private_data;
        struct socket *tx_sock;
        struct socket *rx_sock;

        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        vhost_dev_stop(&n->dev);
        vhost_dev_cleanup(&n->dev, false);
        vhost_net_vq_reset(n);
        if (tx_sock)
                sockfd_put(tx_sock);
        if (rx_sock)
                sockfd_put(rx_sock);
        /* Make sure no callbacks are outstanding */
        synchronize_rcu_bh();
        /* We do an extra flush before freeing memory,
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);
        kfree(n->dev.vqs);
        kvfree(n);
        return 0;
}

static struct socket *get_raw_socket(int fd)
{
        struct {
                struct sockaddr_ll sa;
                char  buf[MAX_ADDR_LEN];
        } uaddr;
        int uaddr_len = sizeof uaddr, r;
        struct socket *sock = sockfd_lookup(fd, &r);

        if (!sock)
                return ERR_PTR(-ENOTSOCK);

        /* Parameter checking */
        if (sock->sk->sk_type != SOCK_RAW) {
                r = -ESOCKTNOSUPPORT;
                goto err;
        }

        r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
                               &uaddr_len, 0);
        if (r)
                goto err;

        if (uaddr.sa.sll_family != AF_PACKET) {
                r = -EPFNOSUPPORT;
                goto err;
        }
        return sock;
err:
        sockfd_put(sock);
        return ERR_PTR(r);
}

static struct socket *get_tap_socket(int fd)
{
        struct file *file = fget(fd);
        struct socket *sock;

        if (!file)
                return ERR_PTR(-EBADF);
        sock = tun_get_socket(file);
        if (!IS_ERR(sock))
                return sock;
        sock = macvtap_get_socket(file);
        if (IS_ERR(sock))
                fput(file);
        return sock;
}

static struct socket *get_socket(int fd)
{
        struct socket *sock;

        /* special case to disable backend */
        if (fd == -1)
                return NULL;
        sock = get_raw_socket(fd);
        if (!IS_ERR(sock))
                return sock;
        sock = get_tap_socket(fd);
        if (!IS_ERR(sock))
                return sock;
        return ERR_PTR(-ENOTSOCK);
}

static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
        struct socket *sock, *oldsock;
        struct vhost_virtqueue *vq;
        struct vhost_net_virtqueue *nvq;
        struct vhost_net_ubuf_ref *ubufs, *oldubufs = NULL;
        int r;

        mutex_lock(&n->dev.mutex);
        r = vhost_dev_check_owner(&n->dev);
        if (r)
                goto err;

        if (index >= VHOST_NET_VQ_MAX) {
                r = -ENOBUFS;
                goto err;
        }
        vq = &n->vqs[index].vq;
        nvq = &n->vqs[index];
        mutex_lock(&vq->mutex);

        /* Verify that ring has been setup correctly. */
        if (!vhost_vq_access_ok(vq)) {
                r = -EFAULT;
                goto err_vq;
        }
        sock = get_socket(fd);
        if (IS_ERR(sock)) {
                r = PTR_ERR(sock);
                goto err_vq;
        }

        /* start polling new socket */
        oldsock = vq->private_data;
        if (sock != oldsock) {
                ubufs = vhost_net_ubuf_alloc(vq,
                                             sock && vhost_sock_zcopy(sock));
                if (IS_ERR(ubufs)) {
                        r = PTR_ERR(ubufs);
                        goto err_ubufs;
                }

                vhost_net_disable_vq(n, vq);
                vq->private_data = sock;
                r = vhost_init_used(vq);
                if (r)
                        goto err_used;
                r = vhost_net_enable_vq(n, vq);
                if (r)
                        goto err_used;

                oldubufs = nvq->ubufs;
                nvq->ubufs = ubufs;

                n->tx_packets = 0;
                n->tx_zcopy_err = 0;
                n->tx_flush = false;
        }

        mutex_unlock(&vq->mutex);

        if (oldubufs) {
                vhost_net_ubuf_put_wait_and_free(oldubufs);
                mutex_lock(&vq->mutex);
                vhost_zerocopy_signal_used(n, vq);
                mutex_unlock(&vq->mutex);
        }

        if (oldsock) {
                vhost_net_flush_vq(n, index);
                sockfd_put(oldsock);
        }

        mutex_unlock(&n->dev.mutex);
        return 0;

err_used:
        vq->private_data = oldsock;
        vhost_net_enable_vq(n, vq);
        if (ubufs)
                vhost_net_ubuf_put_wait_and_free(ubufs);
err_ubufs:
        sockfd_put(sock);
err_vq:
        mutex_unlock(&vq->mutex);
err:
        mutex_unlock(&n->dev.mutex);
        return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
        struct socket *tx_sock = NULL;
        struct socket *rx_sock = NULL;
        long err;
        struct vhost_memory *memory;

        mutex_lock(&n->dev.mutex);
        err = vhost_dev_check_owner(&n->dev);
        if (err)
                goto done;
        memory = vhost_dev_reset_owner_prepare();
        if (!memory) {
                err = -ENOMEM;
                goto done;
        }
        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        vhost_dev_reset_owner(&n->dev, memory);
        vhost_net_vq_reset(n);
done:
        mutex_unlock(&n->dev.mutex);
        if (tx_sock)
                sockfd_put(tx_sock);
        if (rx_sock)
                sockfd_put(rx_sock);
        return err;
}

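/* The vnet header is the larger mergeable-rxbuf layout (which carries
 * num_buffers) when either VIRTIO_NET_F_MRG_RXBUF or VIRTIO_F_VERSION_1 is
 * negotiated.  VHOST_NET_F_VIRTIO_NET_HDR selects whether vhost prepends
 * the header itself (vhost_hlen) or expects the backend socket to supply
 * it (sock_hlen).
 */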
static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
        size_t vhost_hlen, sock_hlen, hdr_len;
        int i;

        hdr_len = (features & ((1ULL << VIRTIO_NET_F_MRG_RXBUF) |
                               (1ULL << VIRTIO_F_VERSION_1))) ?
                        sizeof(struct virtio_net_hdr_mrg_rxbuf) :
                        sizeof(struct virtio_net_hdr);
        if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
                /* vhost provides vnet_hdr */
                vhost_hlen = hdr_len;
                sock_hlen = 0;
        } else {
                /* socket provides vnet_hdr */
                vhost_hlen = 0;
                sock_hlen = hdr_len;
        }
        mutex_lock(&n->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&n->dev)) {
                mutex_unlock(&n->dev.mutex);
                return -EFAULT;
        }
        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
                mutex_lock(&n->vqs[i].vq.mutex);
                n->vqs[i].vq.acked_features = features;
                n->vqs[i].vhost_hlen = vhost_hlen;
                n->vqs[i].sock_hlen = sock_hlen;
                mutex_unlock(&n->vqs[i].vq.mutex);
        }
        mutex_unlock(&n->dev.mutex);
        return 0;
}

static long vhost_net_set_owner(struct vhost_net *n)
{
        int r;

        mutex_lock(&n->dev.mutex);
        if (vhost_dev_has_owner(&n->dev)) {
                r = -EBUSY;
                goto out;
        }
        r = vhost_net_set_ubuf_info(n);
        if (r)
                goto out;
        r = vhost_dev_set_owner(&n->dev);
        if (r)
                vhost_net_clear_ubuf_info(n);
        vhost_net_flush(n);
out:
        mutex_unlock(&n->dev.mutex);
        return r;
}

static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
                            unsigned long arg)
{
        struct vhost_net *n = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 __user *featurep = argp;
        struct vhost_vring_file backend;
        u64 features;
        int r;

        switch (ioctl) {
        case VHOST_NET_SET_BACKEND:
                if (copy_from_user(&backend, argp, sizeof backend))
                        return -EFAULT;
                return vhost_net_set_backend(n, backend.index, backend.fd);
        case VHOST_GET_FEATURES:
                features = VHOST_NET_FEATURES;
                if (copy_to_user(featurep, &features, sizeof features))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, featurep, sizeof features))
                        return -EFAULT;
                if (features & ~VHOST_NET_FEATURES)
                        return -EOPNOTSUPP;
                return vhost_net_set_features(n, features);
        case VHOST_RESET_OWNER:
                return vhost_net_reset_owner(n);
        case VHOST_SET_OWNER:
                return vhost_net_set_owner(n);
        default:
                mutex_lock(&n->dev.mutex);
                r = vhost_dev_ioctl(&n->dev, ioctl, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vring_ioctl(&n->dev, ioctl, argp);
                else
                        vhost_net_flush(n);
                mutex_unlock(&n->dev.mutex);
                return r;
        }
}

#ifdef CONFIG_COMPAT
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
                                   unsigned long arg)
{
        return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_net_fops = {
        .owner          = THIS_MODULE,
        .release        = vhost_net_release,
        .unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = vhost_net_compat_ioctl,
#endif
        .open           = vhost_net_open,
        .llseek         = noop_llseek,
};

static struct miscdevice vhost_net_misc = {
        .minor = VHOST_NET_MINOR,
        .name = "vhost-net",
        .fops = &vhost_net_fops,
};

static int vhost_net_init(void)
{
        if (experimental_zcopytx)
                vhost_net_enable_zcopy(VHOST_NET_VQ_TX);
        return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);

static void vhost_net_exit(void)
{
        misc_deregister(&vhost_net_misc);
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");
MODULE_ALIAS_MISCDEV(VHOST_NET_MINOR);
MODULE_ALIAS("devname:vhost-net");