linux/drivers/vhost/net.c
/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>

#include <net/sock.h>

#include "vhost.h"

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

enum {
        VHOST_NET_VQ_RX = 0,
        VHOST_NET_VQ_TX = 1,
        VHOST_NET_VQ_MAX = 2,
};

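/* State of the optional "poll the socket for TX" machinery: polling is
 * DISABLED until a backend socket is attached, STARTED while we wait for
 * the socket's send buffer to drain (handle_tx found it full), and
 * STOPPED once there is room again and TX runs off virtqueue kicks
 * alone. */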
enum vhost_net_poll_state {
        VHOST_NET_POLL_DISABLED = 0,
        VHOST_NET_POLL_STARTED = 1,
        VHOST_NET_POLL_STOPPED = 2,
};

struct vhost_net {
        struct vhost_dev dev;
        struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
        struct vhost_poll poll[VHOST_NET_VQ_MAX];
        /* Tells us whether we are polling a socket for TX.
         * We only do this when socket buffer fills up.
         * Protected by tx vq lock. */
        enum vhost_net_poll_state tx_poll_state;
};

/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
                          size_t len, int iov_count)
{
        int seg = 0;
        size_t size;

        while (len && seg < iov_count) {
                size = min(from->iov_len, len);
                to->iov_base = from->iov_base;
                to->iov_len = size;
                from->iov_len -= size;
                from->iov_base += size;
                len -= size;
                ++from;
                ++to;
                ++seg;
        }
        return seg;
}

/* Copy iovec entries for len bytes from iovec. */
static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
                           size_t len, int iovcount)
{
        int seg = 0;
        size_t size;

        while (len && seg < iovcount) {
                size = min(from->iov_len, len);
                to->iov_base = from->iov_base;
                to->iov_len = size;
                len -= size;
                ++from;
                ++to;
                ++seg;
        }
}
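
/* Note the asymmetry above: move_iovec_hdr() consumes the header bytes
 * from the source iovec (used to strip a header the socket will not
 * consume), while copy_iovec_hdr() leaves the source iovec intact, keeping
 * a second reference to the header segments so they can still be written
 * after recvmsg(), which may modify msg_iov. */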

/* Caller must have TX VQ lock */
static void tx_poll_stop(struct vhost_net *net)
{
        if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
                return;
        vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
        net->tx_poll_state = VHOST_NET_POLL_STOPPED;
}

/* Caller must have TX VQ lock */
static void tx_poll_start(struct vhost_net *net, struct socket *sock)
{
        if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
                return;
        vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
        net->tx_poll_state = VHOST_NET_POLL_STARTED;
}

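/* TX flow in brief: pull available descriptors off the TX virtqueue,
 * strip the virtio-net header, hand the payload to the backing socket
 * with sendmsg(MSG_DONTWAIT), then mark the descriptors used and signal
 * the guest. If the socket's send buffer is full we start polling the
 * socket for POLLOUT instead of spinning, and tx_poll_stop() is called
 * once there is room again. */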
/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
        unsigned out, in, s;
        int head;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL,
                .msg_controllen = 0,
                .msg_iov = vq->iov,
                .msg_flags = MSG_DONTWAIT,
        };
        size_t len, total_len = 0;
        int err, wmem;
        size_t hdr_size;
        struct socket *sock;

        /* TODO: check that we are running from vhost_worker? */
        sock = rcu_dereference_check(vq->private_data, 1);
        if (!sock)
                return;

        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
        if (wmem >= sock->sk->sk_sndbuf) {
                mutex_lock(&vq->mutex);
                tx_poll_start(net, sock);
                mutex_unlock(&vq->mutex);
                return;
        }

        mutex_lock(&vq->mutex);
        vhost_disable_notify(vq);

        if (wmem < sock->sk->sk_sndbuf / 2)
                tx_poll_stop(net);
        hdr_size = vq->vhost_hlen;

        for (;;) {
                head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
                                         ARRAY_SIZE(vq->iov),
                                         &out, &in,
                                         NULL, NULL);
                /* On error, stop handling until the next kick. */
                if (unlikely(head < 0))
                        break;
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
                if (head == vq->num) {
                        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
                        if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
                                tx_poll_start(net, sock);
                                set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
                                break;
                        }
                        if (unlikely(vhost_enable_notify(vq))) {
                                vhost_disable_notify(vq);
                                continue;
                        }
                        break;
                }
                if (in) {
                        vq_err(vq, "Unexpected descriptor format for TX: "
                               "out %d, in %d\n", out, in);
                        break;
                }
                /* Skip header. TODO: support TSO. */
                s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
                msg.msg_iovlen = out;
                len = iov_length(vq->iov, out);
                /* Sanity check: nothing left after the header means a
                 * malformed descriptor. */
                if (!len) {
                        vq_err(vq, "Unexpected header len for TX: "
                               "%zd expected %zd\n",
                               iov_length(vq->hdr, s), hdr_size);
                        break;
                }
                /* TODO: Check specific error and bomb out unless ENOBUFS? */
                err = sock->ops->sendmsg(NULL, sock, &msg, len);
                if (unlikely(err < 0)) {
                        vhost_discard_vq_desc(vq, 1);
                        tx_poll_start(net, sock);
                        break;
                }
                if (err != len)
                        pr_debug("Truncated TX packet: "
                                 " len %d != %zd\n", err, len);
                vhost_add_used_and_signal(&net->dev, vq, head, 0);
                total_len += len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }

        mutex_unlock(&vq->mutex);
}

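/* Return the length of the first packet queued on the socket's receive
 * queue, or 0 if the queue is empty. handle_rx() uses this to size the
 * descriptor chain before calling recvmsg(). */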
static int peek_head_len(struct sock *sk)
{
        struct sk_buff *head;
        int len = 0;
        unsigned long flags;

        spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
        head = skb_peek(&sk->sk_receive_queue);
        if (likely(head))
                len = head->len;
        spin_unlock_irqrestore(&sk->sk_receive_queue.lock, flags);
        return len;
}

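/* Background for get_rx_bufs(): with VIRTIO_NET_F_MRG_RXBUF negotiated,
 * a received packet may be spread across several buffer heads, and the
 * num_buffers field of struct virtio_net_hdr_mrg_rxbuf tells the guest
 * how many heads to collect. For example, a 4000-byte packet landing in
 * 2048-byte guest buffers needs headcount = 2, and handle_rx() later
 * writes 2 into num_buffers. Without mergeable buffers the quota is 1,
 * so the packet must fit in a single (possibly chained) buffer. */
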
/* This is a multi-buffer version of vhost_get_vq_desc, that works if
 *      vq has read descriptors only.
 * @vq          - the relevant virtqueue
 * @datalen     - data length we'll be reading
 * @iovcount    - returned count of io vectors we fill
 * @log         - vhost log
 * @log_num     - returned number of log entries used
 * @quota       - headcount quota, 1 for big buffer
 *      returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_virtqueue *vq,
                       struct vring_used_elem *heads,
                       int datalen,
                       unsigned *iovcount,
                       struct vhost_log *log,
                       unsigned *log_num,
                       unsigned int quota)
{
        unsigned int out, in;
        int seg = 0;
        int headcount = 0;
        unsigned d;
        int r, nlogs = 0;

        while (datalen > 0 && headcount < quota) {
                if (unlikely(seg >= UIO_MAXIOV)) {
                        r = -ENOBUFS;
                        goto err;
                }
                d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
                                      ARRAY_SIZE(vq->iov) - seg, &out,
                                      &in, log, log_num);
                if (d == vq->num) {
                        r = 0;
                        goto err;
                }
                if (unlikely(out || in <= 0)) {
                        vq_err(vq, "unexpected descriptor format for RX: "
                                "out %d, in %d\n", out, in);
                        r = -EINVAL;
                        goto err;
                }
                if (unlikely(log)) {
                        nlogs += *log_num;
                        log += *log_num;
                }
                heads[headcount].id = d;
                heads[headcount].len = iov_length(vq->iov + seg, in);
                datalen -= heads[headcount].len;
                ++headcount;
                seg += in;
        }
        /* The last head may cover more than we asked for; adjust its
         * reported length so the total seen by the guest is exactly the
         * original datalen. */
        heads[headcount - 1].len += datalen;
        *iovcount = seg;
        if (unlikely(log))
                *log_num = nlogs;
        return headcount;
err:
        vhost_discard_vq_desc(vq, headcount);
        return r;
}

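/* RX flow in brief: as long as the backing socket has queued packets,
 * reserve enough guest buffers for the next packet (plus the vnet
 * header), recvmsg() directly into them, fill in the virtio-net header
 * and, for mergeable buffers, the num_buffers count, then mark the
 * heads used and signal the guest. */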
/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
{
        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
        unsigned uninitialized_var(in), log;
        struct vhost_log *vq_log;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL, /* FIXME: get and handle RX aux data. */
                .msg_controllen = 0,
                .msg_iov = vq->iov,
                .msg_flags = MSG_DONTWAIT,
        };
        struct virtio_net_hdr_mrg_rxbuf hdr = {
                .hdr.flags = 0,
                .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
        };
        size_t total_len = 0;
        int err, headcount, mergeable;
        size_t vhost_hlen, sock_hlen;
        size_t vhost_len, sock_len;
        /* TODO: check that we are running from vhost_worker? */
        struct socket *sock = rcu_dereference_check(vq->private_data, 1);

        if (!sock)
                return;

        mutex_lock(&vq->mutex);
        vhost_disable_notify(vq);
        vhost_hlen = vq->vhost_hlen;
        sock_hlen = vq->sock_hlen;

        vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
                vq->log : NULL;
        mergeable = vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF);

        while ((sock_len = peek_head_len(sock->sk))) {
                sock_len += sock_hlen;
                vhost_len = sock_len + vhost_hlen;
                headcount = get_rx_bufs(vq, vq->heads, vhost_len,
                                        &in, vq_log, &log,
                                        likely(mergeable) ? UIO_MAXIOV : 1);
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        break;
                /* OK, now we need to know about added descriptors. */
                if (!headcount) {
                        if (unlikely(vhost_enable_notify(vq))) {
                                /* They have slipped one in as we were
                                 * doing that: check again. */
                                vhost_disable_notify(vq);
                                continue;
                        }
                        /* Nothing new?  Wait for eventfd to tell us
                         * they refilled. */
                        break;
                }
                /* We don't need to be notified again. */
                if (unlikely(vhost_hlen))
                        /* Skip header. TODO: support TSO. */
                        move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
                else
                        /* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
                         * needed because recvmsg can modify msg_iov. */
                        copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
                msg.msg_iovlen = in;
                err = sock->ops->recvmsg(NULL, sock, &msg,
                                         sock_len, MSG_DONTWAIT | MSG_TRUNC);
                /* Userspace might have consumed the packet meanwhile:
                 * it's not supposed to do this usually, but might be hard
                 * to prevent. Discard data we got (if any) and keep going. */
                if (unlikely(err != sock_len)) {
                        pr_debug("Discarded rx packet: "
                                 " len %d, expected %zd\n", err, sock_len);
                        vhost_discard_vq_desc(vq, headcount);
                        continue;
                }
                if (unlikely(vhost_hlen) &&
                    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
                                      vhost_hlen)) {
                        vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
                               vq->iov->iov_base);
                        break;
                }
                /* TODO: Should check and handle checksum. */
                if (likely(mergeable) &&
                    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
                                      offsetof(typeof(hdr), num_buffers),
                                      sizeof hdr.num_buffers)) {
                        vq_err(vq, "Failed num_buffers write");
                        vhost_discard_vq_desc(vq, headcount);
                        break;
                }
                vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
                                            headcount);
                if (unlikely(vq_log))
                        vhost_log_write(vq, vq_log, log, vhost_len);
                total_len += vhost_len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }

        mutex_unlock(&vq->mutex);
}

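/* Two kinds of work feed the handlers above: handle_*_kick() run when
 * the guest kicks a virtqueue (via the vring's kick eventfd), while
 * handle_*_net() run from vhost_poll when the backing socket becomes
 * readable (RX) or writable again (TX). Both end up in the same
 * handle_tx()/handle_rx() paths. */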
static void handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

        handle_tx(net);
}

static void handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

        handle_rx(net);
}

static void handle_tx_net(struct vhost_work *work)
{
        struct vhost_net *net = container_of(work, struct vhost_net,
                                             poll[VHOST_NET_VQ_TX].work);
        handle_tx(net);
}

static void handle_rx_net(struct vhost_work *work)
{
        struct vhost_net *net = container_of(work, struct vhost_net,
                                             poll[VHOST_NET_VQ_RX].work);
        handle_rx(net);
}

static int vhost_net_open(struct inode *inode, struct file *f)
{
        struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
        struct vhost_dev *dev;
        int r;

        if (!n)
                return -ENOMEM;

        dev = &n->dev;
        n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
        n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
        r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
        if (r < 0) {
                kfree(n);
                return r;
        }

        vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
        vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
        n->tx_poll_state = VHOST_NET_POLL_DISABLED;

        f->private_data = n;

        return 0;
}

static void vhost_net_disable_vq(struct vhost_net *n,
                                 struct vhost_virtqueue *vq)
{
        if (!vq->private_data)
                return;
        if (vq == n->vqs + VHOST_NET_VQ_TX) {
                tx_poll_stop(n);
                n->tx_poll_state = VHOST_NET_POLL_DISABLED;
        } else
                vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
}

static void vhost_net_enable_vq(struct vhost_net *n,
                                struct vhost_virtqueue *vq)
{
        struct socket *sock;

        sock = rcu_dereference_protected(vq->private_data,
                                         lockdep_is_held(&vq->mutex));
        if (!sock)
                return;
        if (vq == n->vqs + VHOST_NET_VQ_TX) {
                n->tx_poll_state = VHOST_NET_POLL_STOPPED;
                tx_poll_start(n, sock);
        } else
                vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
                                        struct vhost_virtqueue *vq)
{
        struct socket *sock;

        mutex_lock(&vq->mutex);
        sock = rcu_dereference_protected(vq->private_data,
                                         lockdep_is_held(&vq->mutex));
        vhost_net_disable_vq(n, vq);
        rcu_assign_pointer(vq->private_data, NULL);
        mutex_unlock(&vq->mutex);
        return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
                           struct socket **rx_sock)
{
        *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
        *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
        vhost_poll_flush(n->poll + index);
        vhost_poll_flush(&n->dev.vqs[index].poll);
}

static void vhost_net_flush(struct vhost_net *n)
{
        vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
        vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
        struct vhost_net *n = f->private_data;
        struct socket *tx_sock;
        struct socket *rx_sock;

        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        vhost_dev_cleanup(&n->dev);
        if (tx_sock)
                fput(tx_sock->file);
        if (rx_sock)
                fput(rx_sock->file);
        /* We do an extra flush before freeing memory,
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);
        kfree(n);
        return 0;
}

static struct socket *get_raw_socket(int fd)
{
        struct {
                struct sockaddr_ll sa;
                char  buf[MAX_ADDR_LEN];
        } uaddr;
        int uaddr_len = sizeof uaddr, r;
        struct socket *sock = sockfd_lookup(fd, &r);

        if (!sock)
                return ERR_PTR(-ENOTSOCK);

        /* Parameter checking */
        if (sock->sk->sk_type != SOCK_RAW) {
                r = -ESOCKTNOSUPPORT;
                goto err;
        }

        r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
                               &uaddr_len, 0);
        if (r)
                goto err;

        if (uaddr.sa.sll_family != AF_PACKET) {
                r = -EPFNOSUPPORT;
                goto err;
        }
        return sock;
err:
        fput(sock->file);
        return ERR_PTR(r);
}

static struct socket *get_tap_socket(int fd)
{
        struct file *file = fget(fd);
        struct socket *sock;

        if (!file)
                return ERR_PTR(-EBADF);
        sock = tun_get_socket(file);
        if (!IS_ERR(sock))
                return sock;
        sock = macvtap_get_socket(file);
        if (IS_ERR(sock))
                fput(file);
        return sock;
}

static struct socket *get_socket(int fd)
{
        struct socket *sock;

        /* special case to disable backend */
        if (fd == -1)
                return NULL;
        sock = get_raw_socket(fd);
        if (!IS_ERR(sock))
                return sock;
        sock = get_tap_socket(fd);
        if (!IS_ERR(sock))
                return sock;
        return ERR_PTR(-ENOTSOCK);
}

static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
        struct socket *sock, *oldsock;
        struct vhost_virtqueue *vq;
        int r;

        mutex_lock(&n->dev.mutex);
        r = vhost_dev_check_owner(&n->dev);
        if (r)
                goto err;

        if (index >= VHOST_NET_VQ_MAX) {
                r = -ENOBUFS;
                goto err;
        }
        vq = n->vqs + index;
        mutex_lock(&vq->mutex);

        /* Verify that ring has been setup correctly. */
        if (!vhost_vq_access_ok(vq)) {
                r = -EFAULT;
                goto err_vq;
        }
        sock = get_socket(fd);
        if (IS_ERR(sock)) {
                r = PTR_ERR(sock);
                goto err_vq;
        }

        /* start polling new socket */
        oldsock = rcu_dereference_protected(vq->private_data,
                                            lockdep_is_held(&vq->mutex));
        if (sock != oldsock) {
                vhost_net_disable_vq(n, vq);
                rcu_assign_pointer(vq->private_data, sock);
                vhost_net_enable_vq(n, vq);
        }

        mutex_unlock(&vq->mutex);

        if (oldsock) {
                vhost_net_flush_vq(n, index);
                fput(oldsock->file);
        }

        mutex_unlock(&n->dev.mutex);
        return 0;

err_vq:
        mutex_unlock(&vq->mutex);
err:
        mutex_unlock(&n->dev.mutex);
        return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
        struct socket *tx_sock = NULL;
        struct socket *rx_sock = NULL;
        long err;

        mutex_lock(&n->dev.mutex);
        err = vhost_dev_check_owner(&n->dev);
        if (err)
                goto done;
        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        err = vhost_dev_reset_owner(&n->dev);
done:
        mutex_unlock(&n->dev.mutex);
        if (tx_sock)
                fput(tx_sock->file);
        if (rx_sock)
                fput(rx_sock->file);
        return err;
}

static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
        size_t vhost_hlen, sock_hlen, hdr_len;
        int i;

        hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
                        sizeof(struct virtio_net_hdr_mrg_rxbuf) :
                        sizeof(struct virtio_net_hdr);
        if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
                /* vhost provides vnet_hdr */
                vhost_hlen = hdr_len;
                sock_hlen = 0;
        } else {
                /* socket provides vnet_hdr */
                vhost_hlen = 0;
                sock_hlen = hdr_len;
        }
        mutex_lock(&n->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&n->dev)) {
                mutex_unlock(&n->dev.mutex);
                return -EFAULT;
        }
        n->dev.acked_features = features;
        smp_wmb();
        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
                mutex_lock(&n->vqs[i].mutex);
                n->vqs[i].vhost_hlen = vhost_hlen;
                n->vqs[i].sock_hlen = sock_hlen;
                mutex_unlock(&n->vqs[i].mutex);
        }
        vhost_net_flush(n);
        mutex_unlock(&n->dev.mutex);
        return 0;
}

static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
                            unsigned long arg)
{
        struct vhost_net *n = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 __user *featurep = argp;
        struct vhost_vring_file backend;
        u64 features;
        int r;

        switch (ioctl) {
        case VHOST_NET_SET_BACKEND:
                if (copy_from_user(&backend, argp, sizeof backend))
                        return -EFAULT;
                return vhost_net_set_backend(n, backend.index, backend.fd);
        case VHOST_GET_FEATURES:
                features = VHOST_FEATURES;
                if (copy_to_user(featurep, &features, sizeof features))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, featurep, sizeof features))
                        return -EFAULT;
                if (features & ~VHOST_FEATURES)
                        return -EOPNOTSUPP;
                return vhost_net_set_features(n, features);
        case VHOST_RESET_OWNER:
                return vhost_net_reset_owner(n);
        default:
                mutex_lock(&n->dev.mutex);
                r = vhost_dev_ioctl(&n->dev, ioctl, arg);
                vhost_net_flush(n);
                mutex_unlock(&n->dev.mutex);
                return r;
        }
}

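/* Rough sketch of how userspace typically drives this device; the exact
 * setup sequence (memory table, vring sizes/addresses, eventfds) depends
 * on the VMM and is only outlined here as an assumption:
 *
 *      int vhost = open("/dev/vhost-net", O_RDWR);
 *      ioctl(vhost, VHOST_SET_OWNER, NULL);
 *      ioctl(vhost, VHOST_GET_FEATURES, &features);
 *      ioctl(vhost, VHOST_SET_FEATURES, &features);
 *      ... VHOST_SET_MEM_TABLE, VHOST_SET_VRING_NUM/ADDR/BASE/KICK/CALL ...
 *      struct vhost_vring_file backend = { .index = VHOST_NET_VQ_RX,
 *                                          .fd = tap_fd };
 *      ioctl(vhost, VHOST_NET_SET_BACKEND, &backend);
 *      (and again with .index = VHOST_NET_VQ_TX)
 *
 * Passing fd == -1 to VHOST_NET_SET_BACKEND disables the backend for
 * that queue, as handled by get_socket() above. */
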
#ifdef CONFIG_COMPAT
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
                                   unsigned long arg)
{
        return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_net_fops = {
        .owner          = THIS_MODULE,
        .release        = vhost_net_release,
        .unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = vhost_net_compat_ioctl,
#endif
        .open           = vhost_net_open,
        .llseek         = noop_llseek,
};

static struct miscdevice vhost_net_misc = {
        .minor = MISC_DYNAMIC_MINOR,
        .name  = "vhost-net",
        .fops  = &vhost_net_fops,
};

static int vhost_net_init(void)
{
        return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);

static void vhost_net_exit(void)
{
        misc_deregister(&vhost_net_misc);
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");