linux/drivers/vhost/net.c
/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>

#include <net/sock.h>

#include "vhost.h"

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

enum {
        VHOST_NET_VQ_RX = 0,
        VHOST_NET_VQ_TX = 1,
        VHOST_NET_VQ_MAX = 2,
};

enum vhost_net_poll_state {
        VHOST_NET_POLL_DISABLED = 0,
        VHOST_NET_POLL_STARTED = 1,
        VHOST_NET_POLL_STOPPED = 2,
};

struct vhost_net {
        struct vhost_dev dev;
        struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
        struct vhost_poll poll[VHOST_NET_VQ_MAX];
        /* Tells us whether we are polling a socket for TX.
         * We only do this when socket buffer fills up.
         * Protected by tx vq lock. */
        enum vhost_net_poll_state tx_poll_state;
};

/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
                          size_t len, int iov_count)
{
        int seg = 0;
        size_t size;
        while (len && seg < iov_count) {
                size = min(from->iov_len, len);
                to->iov_base = from->iov_base;
                to->iov_len = size;
                from->iov_len -= size;
                from->iov_base += size;
                len -= size;
                ++from;
                ++to;
                ++seg;
        }
        return seg;
}
/* Copy iovec entries covering the first len bytes from 'from' to 'to',
 * without consuming them from the source. */
static void copy_iovec_hdr(const struct iovec *from, struct iovec *to,
                           size_t len, int iovcount)
{
        int seg = 0;
        size_t size;
        while (len && seg < iovcount) {
                size = min(from->iov_len, len);
                to->iov_base = from->iov_base;
                to->iov_len = size;
                len -= size;
                ++from;
                ++to;
                ++seg;
        }
}

/* Caller must have TX VQ lock */
static void tx_poll_stop(struct vhost_net *net)
{
        if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
                return;
        vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
        net->tx_poll_state = VHOST_NET_POLL_STOPPED;
}

/* Caller must have TX VQ lock */
static void tx_poll_start(struct vhost_net *net, struct socket *sock)
{
        if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
                return;
        vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
        net->tx_poll_state = VHOST_NET_POLL_STARTED;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
        unsigned out, in, s;
        int head;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL,
                .msg_controllen = 0,
                .msg_iov = vq->iov,
                .msg_flags = MSG_DONTWAIT,
        };
        size_t len, total_len = 0;
        int err, wmem;
        size_t hdr_size;
        struct socket *sock;

        /* TODO: check that we are running from vhost_worker? */
        sock = rcu_dereference_check(vq->private_data, 1);
        if (!sock)
                return;

        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
        if (wmem >= sock->sk->sk_sndbuf) {
                mutex_lock(&vq->mutex);
                tx_poll_start(net, sock);
                mutex_unlock(&vq->mutex);
                return;
        }

        mutex_lock(&vq->mutex);
        vhost_disable_notify(vq);

        if (wmem < sock->sk->sk_sndbuf / 2)
                tx_poll_stop(net);
        hdr_size = vq->vhost_hlen;

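        /* Main TX loop: fetch available descriptors from the TX virtqueue,
         * set the virtio-net header aside into vq->hdr and hand the payload
         * to the backend socket via sendmsg(MSG_DONTWAIT). */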
        for (;;) {
                head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
                                         ARRAY_SIZE(vq->iov),
                                         &out, &in,
                                         NULL, NULL);
                /* On error, stop handling until the next kick. */
                if (unlikely(head < 0))
                        break;
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
                if (head == vq->num) {
                        wmem = atomic_read(&sock->sk->sk_wmem_alloc);
                        if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
                                tx_poll_start(net, sock);
                                set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
                                break;
                        }
                        if (unlikely(vhost_enable_notify(vq))) {
                                vhost_disable_notify(vq);
                                continue;
                        }
                        break;
                }
                if (in) {
                        vq_err(vq, "Unexpected descriptor format for TX: "
                               "out %d, in %d\n", out, in);
                        break;
                }
                /* Skip header. TODO: support TSO. */
                s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
                msg.msg_iovlen = out;
                len = iov_length(vq->iov, out);
                /* Sanity check */
                if (!len) {
                        vq_err(vq, "Unexpected header len for TX: "
                               "%zd expected %zd\n",
                               iov_length(vq->hdr, s), hdr_size);
                        break;
                }
                /* TODO: Check specific error and bomb out unless ENOBUFS? */
                err = sock->ops->sendmsg(NULL, sock, &msg, len);
                if (unlikely(err < 0)) {
                        vhost_discard_vq_desc(vq, 1);
                        tx_poll_start(net, sock);
                        break;
                }
                if (err != len)
                        pr_debug("Truncated TX packet: "
                                 " len %d != %zd\n", err, len);
                vhost_add_used_and_signal(&net->dev, vq, head, 0);
                total_len += len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }

        mutex_unlock(&vq->mutex);
}

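/* Peek at the length of the first packet queued on the socket's receive
 * queue without dequeueing it; returns 0 if the queue is empty. */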
static int peek_head_len(struct sock *sk)
{
        struct sk_buff *head;
        int len = 0;

        lock_sock(sk);
        head = skb_peek(&sk->sk_receive_queue);
        if (head)
                len = head->len;
        release_sock(sk);
        return len;
}

/* This is a multi-buffer version of vhost_get_vq_desc, that works if
 *      vq has read descriptors only.
 * @vq          - the relevant virtqueue
 * @datalen     - data length we'll be reading
 * @iovcount    - returned count of io vectors we fill
 * @log         - vhost log
 * @log_num     - log offset
 *      returns number of buffer heads allocated, negative on error
 */
static int get_rx_bufs(struct vhost_virtqueue *vq,
                       struct vring_used_elem *heads,
                       int datalen,
                       unsigned *iovcount,
                       struct vhost_log *log,
                       unsigned *log_num)
{
        unsigned int out, in;
        int seg = 0;
        int headcount = 0;
        unsigned d;
        int r, nlogs = 0;

        while (datalen > 0) {
                if (unlikely(seg >= UIO_MAXIOV)) {
                        r = -ENOBUFS;
                        goto err;
                }
                d = vhost_get_vq_desc(vq->dev, vq, vq->iov + seg,
                                      ARRAY_SIZE(vq->iov) - seg, &out,
                                      &in, log, log_num);
                if (d == vq->num) {
                        r = 0;
                        goto err;
                }
                if (unlikely(out || in <= 0)) {
                        vq_err(vq, "unexpected descriptor format for RX: "
                                "out %d, in %d\n", out, in);
                        r = -EINVAL;
                        goto err;
                }
                if (unlikely(log)) {
                        nlogs += *log_num;
                        log += *log_num;
                }
                heads[headcount].id = d;
                heads[headcount].len = iov_length(vq->iov + seg, in);
                datalen -= heads[headcount].len;
                ++headcount;
                seg += in;
        }
        heads[headcount - 1].len += datalen;
        *iovcount = seg;
        if (unlikely(log))
                *log_num = nlogs;
        return headcount;
err:
        vhost_discard_vq_desc(vq, headcount);
        return r;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx_big(struct vhost_net *net)
{
        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
        unsigned out, in, log, s;
        int head;
        struct vhost_log *vq_log;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL, /* FIXME: get and handle RX aux data. */
                .msg_controllen = 0,
                .msg_iov = vq->iov,
                .msg_flags = MSG_DONTWAIT,
        };

        struct virtio_net_hdr hdr = {
                .flags = 0,
                .gso_type = VIRTIO_NET_HDR_GSO_NONE
        };

        size_t len, total_len = 0;
        int err;
        size_t hdr_size;
        /* TODO: check that we are running from vhost_worker? */
        struct socket *sock = rcu_dereference_check(vq->private_data, 1);
        if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
                return;

        mutex_lock(&vq->mutex);
        vhost_disable_notify(vq);
        hdr_size = vq->vhost_hlen;

        vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
                vq->log : NULL;

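        /* Single-buffer RX loop: each incoming packet is received into one
         * guest descriptor chain and prefixed with a zeroed virtio_net_hdr. */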
        for (;;) {
                head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
                                         ARRAY_SIZE(vq->iov),
                                         &out, &in,
                                         vq_log, &log);
                /* On error, stop handling until the next kick. */
                if (unlikely(head < 0))
                        break;
                /* OK, now we need to know about added descriptors. */
                if (head == vq->num) {
                        if (unlikely(vhost_enable_notify(vq))) {
                                /* They have slipped one in as we were
                                 * doing that: check again. */
                                vhost_disable_notify(vq);
                                continue;
                        }
                        /* Nothing new?  Wait for eventfd to tell us
                         * they refilled. */
                        break;
                }
                /* We don't need to be notified again. */
                if (out) {
                        vq_err(vq, "Unexpected descriptor format for RX: "
                               "out %d, in %d\n",
                               out, in);
                        break;
                }
                /* Skip header. TODO: support TSO/mergeable rx buffers. */
                s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, in);
                msg.msg_iovlen = in;
                len = iov_length(vq->iov, in);
                /* Sanity check */
                if (!len) {
                        vq_err(vq, "Unexpected header len for RX: "
                               "%zd expected %zd\n",
                               iov_length(vq->hdr, s), hdr_size);
                        break;
                }
                err = sock->ops->recvmsg(NULL, sock, &msg,
                                         len, MSG_DONTWAIT | MSG_TRUNC);
                /* TODO: Check specific error and bomb out unless EAGAIN? */
                if (err < 0) {
                        vhost_discard_vq_desc(vq, 1);
                        break;
                }
                /* TODO: Should check and handle checksum. */
                if (err > len) {
                        pr_debug("Discarded truncated rx packet: "
                                 " len %d > %zd\n", err, len);
                        vhost_discard_vq_desc(vq, 1);
                        continue;
                }
                len = err;
                err = memcpy_toiovec(vq->hdr, (unsigned char *)&hdr, hdr_size);
                if (err) {
                        vq_err(vq, "Unable to write vnet_hdr at addr %p: %d\n",
                               vq->iov->iov_base, err);
                        break;
                }
                len += hdr_size;
                vhost_add_used_and_signal(&net->dev, vq, head, len);
                if (unlikely(vq_log))
                        vhost_log_write(vq, vq_log, log, len);
                total_len += len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }

        mutex_unlock(&vq->mutex);
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx_mergeable(struct vhost_net *net)
{
        struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
        unsigned uninitialized_var(in), log;
        struct vhost_log *vq_log;
        struct msghdr msg = {
                .msg_name = NULL,
                .msg_namelen = 0,
                .msg_control = NULL, /* FIXME: get and handle RX aux data. */
                .msg_controllen = 0,
                .msg_iov = vq->iov,
                .msg_flags = MSG_DONTWAIT,
        };

        struct virtio_net_hdr_mrg_rxbuf hdr = {
                .hdr.flags = 0,
                .hdr.gso_type = VIRTIO_NET_HDR_GSO_NONE
        };

        size_t total_len = 0;
        int err, headcount;
        size_t vhost_hlen, sock_hlen;
        size_t vhost_len, sock_len;
        /* TODO: check that we are running from vhost_worker? */
        struct socket *sock = rcu_dereference_check(vq->private_data, 1);
        if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
                return;

        mutex_lock(&vq->mutex);
        vhost_disable_notify(vq);
        vhost_hlen = vq->vhost_hlen;
        sock_hlen = vq->sock_hlen;

        vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
                vq->log : NULL;

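        /* Mergeable RX loop: for each pending packet, gather enough guest
         * buffers to hold it (get_rx_bufs), receive the packet into them and
         * then patch the used buffer count into hdr.num_buffers. */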
        while ((sock_len = peek_head_len(sock->sk))) {
                sock_len += sock_hlen;
                vhost_len = sock_len + vhost_hlen;
                headcount = get_rx_bufs(vq, vq->heads, vhost_len,
                                        &in, vq_log, &log);
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        break;
                /* OK, now we need to know about added descriptors. */
                if (!headcount) {
                        if (unlikely(vhost_enable_notify(vq))) {
                                /* They have slipped one in as we were
                                 * doing that: check again. */
                                vhost_disable_notify(vq);
                                continue;
                        }
                        /* Nothing new?  Wait for eventfd to tell us
                         * they refilled. */
                        break;
                }
                /* We don't need to be notified again. */
                if (unlikely(vhost_hlen))
                        /* Skip header. TODO: support TSO. */
                        move_iovec_hdr(vq->iov, vq->hdr, vhost_hlen, in);
                else
                        /* Copy the header for use in VIRTIO_NET_F_MRG_RXBUF:
                         * needed because recvmsg can modify msg_iov. */
                        copy_iovec_hdr(vq->iov, vq->hdr, sock_hlen, in);
                msg.msg_iovlen = in;
                err = sock->ops->recvmsg(NULL, sock, &msg,
                                         sock_len, MSG_DONTWAIT | MSG_TRUNC);
                /* Userspace might have consumed the packet meanwhile:
                 * it's not supposed to do this usually, but might be hard
                 * to prevent. Discard data we got (if any) and keep going. */
                if (unlikely(err != sock_len)) {
                        pr_debug("Discarded rx packet: "
                                 " len %d, expected %zd\n", err, sock_len);
                        vhost_discard_vq_desc(vq, headcount);
                        continue;
                }
                if (unlikely(vhost_hlen) &&
                    memcpy_toiovecend(vq->hdr, (unsigned char *)&hdr, 0,
                                      vhost_hlen)) {
                        vq_err(vq, "Unable to write vnet_hdr at addr %p\n",
                               vq->iov->iov_base);
                        break;
                }
                /* TODO: Should check and handle checksum. */
                if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF) &&
                    memcpy_toiovecend(vq->hdr, (unsigned char *)&headcount,
                                      offsetof(typeof(hdr), num_buffers),
                                      sizeof hdr.num_buffers)) {
                        vq_err(vq, "Failed num_buffers write");
                        vhost_discard_vq_desc(vq, headcount);
                        break;
                }
                vhost_add_used_and_signal_n(&net->dev, vq, vq->heads,
                                            headcount);
                if (unlikely(vq_log))
                        vhost_log_write(vq, vq_log, log, vhost_len);
                total_len += vhost_len;
                if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }

        mutex_unlock(&vq->mutex);
}

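/* RX entry point: use the mergeable-buffer path only when the guest has
 * negotiated VIRTIO_NET_F_MRG_RXBUF, otherwise fall back to the
 * single-buffer path. */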
static void handle_rx(struct vhost_net *net)
{
        if (vhost_has_feature(&net->dev, VIRTIO_NET_F_MRG_RXBUF))
                handle_rx_mergeable(net);
        else
                handle_rx_big(net);
}

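/* The *_kick handlers below run when the guest signals a virtqueue's kick
 * eventfd; the *_net handlers run when the backend socket's poll wakes us
 * up.  Both kinds are queued and executed as vhost work. */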
static void handle_tx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

        handle_tx(net);
}

static void handle_rx_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_net *net = container_of(vq->dev, struct vhost_net, dev);

        handle_rx(net);
}

static void handle_tx_net(struct vhost_work *work)
{
        struct vhost_net *net = container_of(work, struct vhost_net,
                                             poll[VHOST_NET_VQ_TX].work);
        handle_tx(net);
}

static void handle_rx_net(struct vhost_work *work)
{
        struct vhost_net *net = container_of(work, struct vhost_net,
                                             poll[VHOST_NET_VQ_RX].work);
        handle_rx(net);
}

static int vhost_net_open(struct inode *inode, struct file *f)
{
        struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
        struct vhost_dev *dev;
        int r;

        if (!n)
                return -ENOMEM;

        dev = &n->dev;
        n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
        n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
        r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
        if (r < 0) {
                kfree(n);
                return r;
        }

        vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT, dev);
        vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN, dev);
        n->tx_poll_state = VHOST_NET_POLL_DISABLED;

        f->private_data = n;

        return 0;
}

static void vhost_net_disable_vq(struct vhost_net *n,
                                 struct vhost_virtqueue *vq)
{
        if (!vq->private_data)
                return;
        if (vq == n->vqs + VHOST_NET_VQ_TX) {
                tx_poll_stop(n);
                n->tx_poll_state = VHOST_NET_POLL_DISABLED;
        } else
                vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
}

static void vhost_net_enable_vq(struct vhost_net *n,
                                struct vhost_virtqueue *vq)
{
        struct socket *sock;

        sock = rcu_dereference_protected(vq->private_data,
                                         lockdep_is_held(&vq->mutex));
        if (!sock)
                return;
        if (vq == n->vqs + VHOST_NET_VQ_TX) {
                n->tx_poll_state = VHOST_NET_POLL_STOPPED;
                tx_poll_start(n, sock);
        } else
                vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
}

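/* Detach a virtqueue's backend socket and return it.  The reference that was
 * held through vq->private_data is handed over to the caller, which is
 * responsible for the final fput(). */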
static struct socket *vhost_net_stop_vq(struct vhost_net *n,
                                        struct vhost_virtqueue *vq)
{
        struct socket *sock;

        mutex_lock(&vq->mutex);
        sock = rcu_dereference_protected(vq->private_data,
                                         lockdep_is_held(&vq->mutex));
        vhost_net_disable_vq(n, vq);
        rcu_assign_pointer(vq->private_data, NULL);
        mutex_unlock(&vq->mutex);
        return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
                           struct socket **rx_sock)
{
        *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
        *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
        vhost_poll_flush(n->poll + index);
        vhost_poll_flush(&n->dev.vqs[index].poll);
}

static void vhost_net_flush(struct vhost_net *n)
{
        vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
        vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
        struct vhost_net *n = f->private_data;
        struct socket *tx_sock;
        struct socket *rx_sock;

        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        vhost_dev_cleanup(&n->dev);
        if (tx_sock)
                fput(tx_sock->file);
        if (rx_sock)
                fput(rx_sock->file);
        /* We do an extra flush before freeing memory,
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);
        kfree(n);
        return 0;
}

static struct socket *get_raw_socket(int fd)
{
        struct {
                struct sockaddr_ll sa;
                char  buf[MAX_ADDR_LEN];
        } uaddr;
        int uaddr_len = sizeof uaddr, r;
        struct socket *sock = sockfd_lookup(fd, &r);
        if (!sock)
                return ERR_PTR(-ENOTSOCK);

        /* Parameter checking */
        if (sock->sk->sk_type != SOCK_RAW) {
                r = -ESOCKTNOSUPPORT;
                goto err;
        }

        r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
                               &uaddr_len, 0);
        if (r)
                goto err;

        if (uaddr.sa.sll_family != AF_PACKET) {
                r = -EPFNOSUPPORT;
                goto err;
        }
        return sock;
err:
        fput(sock->file);
        return ERR_PTR(r);
}

static struct socket *get_tap_socket(int fd)
{
        struct file *file = fget(fd);
        struct socket *sock;
        if (!file)
                return ERR_PTR(-EBADF);
        sock = tun_get_socket(file);
        if (!IS_ERR(sock))
                return sock;
        sock = macvtap_get_socket(file);
        if (IS_ERR(sock))
                fput(file);
        return sock;
}

static struct socket *get_socket(int fd)
{
        struct socket *sock;
        /* special case to disable backend */
        if (fd == -1)
                return NULL;
        sock = get_raw_socket(fd);
        if (!IS_ERR(sock))
                return sock;
        sock = get_tap_socket(fd);
        if (!IS_ERR(sock))
                return sock;
        return ERR_PTR(-ENOTSOCK);
}

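/* Attach the socket behind fd as the backend of virtqueue 'index'
 * (fd == -1 detaches the current backend).  Serializes against other
 * operations via the device and virtqueue mutexes; reached from the
 * VHOST_NET_SET_BACKEND ioctl. */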
static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
        struct socket *sock, *oldsock;
        struct vhost_virtqueue *vq;
        int r;

        mutex_lock(&n->dev.mutex);
        r = vhost_dev_check_owner(&n->dev);
        if (r)
                goto err;

        if (index >= VHOST_NET_VQ_MAX) {
                r = -ENOBUFS;
                goto err;
        }
        vq = n->vqs + index;
        mutex_lock(&vq->mutex);

        /* Verify that ring has been setup correctly. */
        if (!vhost_vq_access_ok(vq)) {
                r = -EFAULT;
                goto err_vq;
        }
        sock = get_socket(fd);
        if (IS_ERR(sock)) {
                r = PTR_ERR(sock);
                goto err_vq;
        }

        /* start polling new socket */
        oldsock = rcu_dereference_protected(vq->private_data,
                                            lockdep_is_held(&vq->mutex));
        if (sock != oldsock) {
                vhost_net_disable_vq(n, vq);
                rcu_assign_pointer(vq->private_data, sock);
                vhost_net_enable_vq(n, vq);
        }

        mutex_unlock(&vq->mutex);

        if (oldsock) {
                vhost_net_flush_vq(n, index);
                fput(oldsock->file);
        }

        mutex_unlock(&n->dev.mutex);
        return 0;

err_vq:
        mutex_unlock(&vq->mutex);
err:
        mutex_unlock(&n->dev.mutex);
        return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
        struct socket *tx_sock = NULL;
        struct socket *rx_sock = NULL;
        long err;
        mutex_lock(&n->dev.mutex);
        err = vhost_dev_check_owner(&n->dev);
        if (err)
                goto done;
        vhost_net_stop(n, &tx_sock, &rx_sock);
        vhost_net_flush(n);
        err = vhost_dev_reset_owner(&n->dev);
done:
        mutex_unlock(&n->dev.mutex);
        if (tx_sock)
                fput(tx_sock->file);
        if (rx_sock)
                fput(rx_sock->file);
        return err;
}

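/* Apply a new feature set.  The virtio-net header may be supplied either by
 * vhost itself (VHOST_NET_F_VIRTIO_NET_HDR) or by the backend socket, and it
 * grows to struct virtio_net_hdr_mrg_rxbuf when VIRTIO_NET_F_MRG_RXBUF is
 * negotiated; vhost_hlen/sock_hlen record who provides how many bytes. */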
static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
        size_t vhost_hlen, sock_hlen, hdr_len;
        int i;

        hdr_len = (features & (1 << VIRTIO_NET_F_MRG_RXBUF)) ?
                        sizeof(struct virtio_net_hdr_mrg_rxbuf) :
                        sizeof(struct virtio_net_hdr);
        if (features & (1 << VHOST_NET_F_VIRTIO_NET_HDR)) {
                /* vhost provides vnet_hdr */
                vhost_hlen = hdr_len;
                sock_hlen = 0;
        } else {
                /* socket provides vnet_hdr */
                vhost_hlen = 0;
                sock_hlen = hdr_len;
        }
        mutex_lock(&n->dev.mutex);
        if ((features & (1 << VHOST_F_LOG_ALL)) &&
            !vhost_log_access_ok(&n->dev)) {
                mutex_unlock(&n->dev.mutex);
                return -EFAULT;
        }
        n->dev.acked_features = features;
        smp_wmb();
        for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
                mutex_lock(&n->vqs[i].mutex);
                n->vqs[i].vhost_hlen = vhost_hlen;
                n->vqs[i].sock_hlen = sock_hlen;
                mutex_unlock(&n->vqs[i].mutex);
        }
        vhost_net_flush(n);
        mutex_unlock(&n->dev.mutex);
        return 0;
}

static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
                            unsigned long arg)
{
        struct vhost_net *n = f->private_data;
        void __user *argp = (void __user *)arg;
        u64 __user *featurep = argp;
        struct vhost_vring_file backend;
        u64 features;
        int r;
        switch (ioctl) {
        case VHOST_NET_SET_BACKEND:
                if (copy_from_user(&backend, argp, sizeof backend))
                        return -EFAULT;
                return vhost_net_set_backend(n, backend.index, backend.fd);
        case VHOST_GET_FEATURES:
                features = VHOST_FEATURES;
                if (copy_to_user(featurep, &features, sizeof features))
                        return -EFAULT;
                return 0;
        case VHOST_SET_FEATURES:
                if (copy_from_user(&features, featurep, sizeof features))
                        return -EFAULT;
                if (features & ~VHOST_FEATURES)
                        return -EOPNOTSUPP;
                return vhost_net_set_features(n, features);
        case VHOST_RESET_OWNER:
                return vhost_net_reset_owner(n);
        default:
                mutex_lock(&n->dev.mutex);
                r = vhost_dev_ioctl(&n->dev, ioctl, arg);
                vhost_net_flush(n);
                mutex_unlock(&n->dev.mutex);
                return r;
        }
}

#ifdef CONFIG_COMPAT
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
                                   unsigned long arg)
{
        return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_net_fops = {
        .owner          = THIS_MODULE,
        .release        = vhost_net_release,
        .unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
        .compat_ioctl   = vhost_net_compat_ioctl,
#endif
        .open           = vhost_net_open,
        .llseek         = noop_llseek,
};

static struct miscdevice vhost_net_misc = {
        MISC_DYNAMIC_MINOR,
        "vhost-net",
        &vhost_net_fops,
};

static int vhost_net_init(void)
{
        return misc_register(&vhost_net_misc);
}
module_init(vhost_net_init);

static void vhost_net_exit(void)
{
        misc_deregister(&vhost_net_misc);
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");