linux/drivers/vhost/vdpa.c
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2018-2020 Intel Corporation.
 * Copyright (C) 2020 Red Hat, Inc.
 *
 * Author: Tiwei Bie <tiwei.bie@intel.com>
 *         Jason Wang <jasowang@redhat.com>
 *
 * Thanks to Michael S. Tsirkin for the valuable comments and
 * suggestions, and thanks to Cunming Liang and Zhihong Wang for all
 * their support.
 */

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/mm.h>
#include <linux/iommu.h>
#include <linux/uuid.h>
#include <linux/vdpa.h>
#include <linux/nospec.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>

#include "vhost.h"

enum {
        VHOST_VDPA_BACKEND_FEATURES =
        (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
        (1ULL << VHOST_BACKEND_F_IOTLB_BATCH),
};

#define VHOST_VDPA_DEV_MAX (1U << MINORBITS)

struct vhost_vdpa {
        struct vhost_dev vdev;
        struct iommu_domain *domain;
        struct vhost_virtqueue *vqs;
        struct completion completion;
        struct vdpa_device *vdpa;
        struct device dev;
        struct cdev cdev;
        atomic_t opened;
        int nvqs;
        int virtio_id;
        int minor;
        struct eventfd_ctx *config_ctx;
        int in_batch;
        struct vdpa_iova_range range;
};

static DEFINE_IDA(vhost_vdpa_ida);

static dev_t vhost_vdpa_major;

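/* Work item run by the vhost worker: forward a virtqueue kick to the vDPA device. */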
static void handle_vq_kick(struct vhost_work *work)
{
        struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
                                                  poll.work);
        struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
        const struct vdpa_config_ops *ops = v->vdpa->config;

        ops->kick_vq(v->vdpa, vq - v->vqs);
}

static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
{
        struct vhost_virtqueue *vq = private;
        struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;

        if (call_ctx)
                eventfd_signal(call_ctx, 1);

        return IRQ_HANDLED;
}

static irqreturn_t vhost_vdpa_config_cb(void *private)
{
        struct vhost_vdpa *v = private;
        struct eventfd_ctx *config_ctx = v->config_ctx;

        if (config_ctx)
                eventfd_signal(config_ctx, 1);

        return IRQ_HANDLED;
}

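/*
 * (Re)register the vq's call eventfd as an irq bypass producer for the
 * interrupt reported by the device, so a matching consumer (e.g. a KVM
 * irqfd) can take delivery without bouncing through the host eventfd.
 * Registration failure is only reported, never fatal.
 */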
static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
{
        struct vhost_virtqueue *vq = &v->vqs[qid];
        const struct vdpa_config_ops *ops = v->vdpa->config;
        struct vdpa_device *vdpa = v->vdpa;
        int ret, irq;

        if (!ops->get_vq_irq)
                return;

        irq = ops->get_vq_irq(vdpa, qid);
        irq_bypass_unregister_producer(&vq->call_ctx.producer);
        if (!vq->call_ctx.ctx || irq < 0)
                return;

        vq->call_ctx.producer.token = vq->call_ctx.ctx;
        vq->call_ctx.producer.irq = irq;
        ret = irq_bypass_register_producer(&vq->call_ctx.producer);
        if (unlikely(ret))
                dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration failed, ret = %d\n",
                         qid, vq->call_ctx.producer.token, ret);
}

static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
{
        struct vhost_virtqueue *vq = &v->vqs[qid];

        irq_bypass_unregister_producer(&vq->call_ctx.producer);
}

static void vhost_vdpa_reset(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;

        vdpa_reset(vdpa);
        v->in_batch = 0;
}

static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u32 device_id;

        device_id = ops->get_device_id(vdpa);

        if (copy_to_user(argp, &device_id, sizeof(device_id)))
                return -EFAULT;

        return 0;
}

static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u8 status;

        status = ops->get_status(vdpa);

        if (copy_to_user(statusp, &status, sizeof(status)))
                return -EFAULT;

        return 0;
}

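/*
 * VHOST_VDPA_SET_STATUS: validate the requested status transition, push it
 * to the device, and register/unregister the vq irq bypass producers when
 * DRIVER_OK is set or cleared.
 */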
static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u8 status, status_old;
        int nvqs = v->nvqs;
        u16 i;

        if (copy_from_user(&status, statusp, sizeof(status)))
                return -EFAULT;

        status_old = ops->get_status(vdpa);

        /*
         * Userspace shouldn't remove status bits unless it resets the
         * status to 0.
         */
        if (status != 0 && (ops->get_status(vdpa) & ~status) != 0)
                return -EINVAL;

        ops->set_status(vdpa, status);

        if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
                for (i = 0; i < nvqs; i++)
                        vhost_vdpa_setup_vq_irq(v, i);

        if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
                for (i = 0; i < nvqs; i++)
                        vhost_vdpa_unsetup_vq_irq(v, i);

        return 0;
}

static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
                                      struct vhost_vdpa_config *c)
{
        long size = 0;

        switch (v->virtio_id) {
        case VIRTIO_ID_NET:
                size = sizeof(struct virtio_net_config);
                break;
        }

        if (c->len == 0)
                return -EINVAL;

        if (c->len > size - c->off)
                return -E2BIG;

        return 0;
}

static long vhost_vdpa_get_config(struct vhost_vdpa *v,
                                  struct vhost_vdpa_config __user *c)
{
        struct vdpa_device *vdpa = v->vdpa;
        struct vhost_vdpa_config config;
        unsigned long size = offsetof(struct vhost_vdpa_config, buf);
        u8 *buf;

        if (copy_from_user(&config, c, size))
                return -EFAULT;
        if (vhost_vdpa_config_validate(v, &config))
                return -EINVAL;
        buf = kvzalloc(config.len, GFP_KERNEL);
        if (!buf)
                return -ENOMEM;

        vdpa_get_config(vdpa, config.off, buf, config.len);

        if (copy_to_user(c->buf, buf, config.len)) {
                kvfree(buf);
                return -EFAULT;
        }

        kvfree(buf);
        return 0;
}

static long vhost_vdpa_set_config(struct vhost_vdpa *v,
                                  struct vhost_vdpa_config __user *c)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vhost_vdpa_config config;
        unsigned long size = offsetof(struct vhost_vdpa_config, buf);
        u8 *buf;

        if (copy_from_user(&config, c, size))
                return -EFAULT;
        if (vhost_vdpa_config_validate(v, &config))
                return -EINVAL;

        buf = vmemdup_user(c->buf, config.len);
        if (IS_ERR(buf))
                return PTR_ERR(buf);

        ops->set_config(vdpa, config.off, buf, config.len);

        kvfree(buf);
        return 0;
}

static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u64 features;

        features = ops->get_features(vdpa);

        if (copy_to_user(featurep, &features, sizeof(features)))
                return -EFAULT;

        return 0;
}

static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u64 features;

        /*
         * It's not allowed to change the features after they have
         * been negotiated.
         */
        if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
                return -EBUSY;

        if (copy_from_user(&features, featurep, sizeof(features)))
                return -EFAULT;

        if (vdpa_set_features(vdpa, features))
                return -EINVAL;

        return 0;
}

static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        u16 num;

        num = ops->get_vq_num_max(vdpa);

        if (copy_to_user(argp, &num, sizeof(num)))
                return -EFAULT;

        return 0;
}

static void vhost_vdpa_config_put(struct vhost_vdpa *v)
{
        if (v->config_ctx)
                eventfd_ctx_put(v->config_ctx);
}

static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
{
        struct vdpa_callback cb;
        int fd;
        struct eventfd_ctx *ctx;

        cb.callback = vhost_vdpa_config_cb;
        cb.private = v->vdpa;
        if (copy_from_user(&fd, argp, sizeof(fd)))
                return -EFAULT;

        ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
        swap(ctx, v->config_ctx);

        if (!IS_ERR_OR_NULL(ctx))
                eventfd_ctx_put(ctx);

        if (IS_ERR(v->config_ctx))
                return PTR_ERR(v->config_ctx);

        v->vdpa->config->set_config_cb(v->vdpa, &cb);

        return 0;
}

static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
{
        struct vhost_vdpa_iova_range range = {
                .first = v->range.first,
                .last = v->range.last,
        };

        if (copy_to_user(argp, &range, sizeof(range)))
                return -EFAULT;
        return 0;
}

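/*
 * Per-virtqueue ioctls. Commands that only touch device state
 * (VHOST_VDPA_SET_VRING_ENABLE, reading the avail index for
 * VHOST_GET_VRING_BASE) are handled before the generic vhost vring
 * ioctl; address, base, call and num changes are propagated to the
 * vDPA device afterwards.
 */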
static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
                                   void __user *argp)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vdpa_vq_state vq_state;
        struct vdpa_callback cb;
        struct vhost_virtqueue *vq;
        struct vhost_vring_state s;
        u32 idx;
        long r;

        r = get_user(idx, (u32 __user *)argp);
        if (r < 0)
                return r;

        if (idx >= v->nvqs)
                return -ENOBUFS;

        idx = array_index_nospec(idx, v->nvqs);
        vq = &v->vqs[idx];

        switch (cmd) {
        case VHOST_VDPA_SET_VRING_ENABLE:
                if (copy_from_user(&s, argp, sizeof(s)))
                        return -EFAULT;
                ops->set_vq_ready(vdpa, idx, s.num);
                return 0;
        case VHOST_GET_VRING_BASE:
                r = ops->get_vq_state(v->vdpa, idx, &vq_state);
                if (r)
                        return r;

                vq->last_avail_idx = vq_state.avail_index;
                break;
        }

        r = vhost_vring_ioctl(&v->vdev, cmd, argp);
        if (r)
                return r;

        switch (cmd) {
        case VHOST_SET_VRING_ADDR:
                if (ops->set_vq_address(vdpa, idx,
                                        (u64)(uintptr_t)vq->desc,
                                        (u64)(uintptr_t)vq->avail,
                                        (u64)(uintptr_t)vq->used))
                        r = -EINVAL;
                break;

        case VHOST_SET_VRING_BASE:
                vq_state.avail_index = vq->last_avail_idx;
                if (ops->set_vq_state(vdpa, idx, &vq_state))
                        r = -EINVAL;
                break;

        case VHOST_SET_VRING_CALL:
                if (vq->call_ctx.ctx) {
                        cb.callback = vhost_vdpa_virtqueue_cb;
                        cb.private = vq;
                } else {
                        cb.callback = NULL;
                        cb.private = NULL;
                }
                ops->set_vq_cb(vdpa, idx, &cb);
                vhost_vdpa_setup_vq_irq(v, idx);
                break;

        case VHOST_SET_VRING_NUM:
                ops->set_vq_num(vdpa, idx, vq->num);
                break;
        }

        return r;
}

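/*
 * Top-level ioctl dispatcher for the vhost-vdpa character device.
 * VHOST_SET_BACKEND_FEATURES is handled without taking the device
 * mutex; everything else runs under d->mutex and falls back to the
 * generic vhost ioctls (and then to the vring ioctls) when it is not
 * handled here.
 *
 * A rough sketch of a typical userspace sequence (the exact flow
 * depends on the application):
 *
 *      fd = open("/dev/vhost-vdpa-0", O_RDWR);
 *      ioctl(fd, VHOST_SET_OWNER);
 *      ioctl(fd, VHOST_GET_FEATURES, &features);
 *      ioctl(fd, VHOST_SET_FEATURES, &features);
 *      ... configure vrings and the IOTLB ...
 *      status |= VIRTIO_CONFIG_S_DRIVER_OK;
 *      ioctl(fd, VHOST_VDPA_SET_STATUS, &status);
 */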
static long vhost_vdpa_unlocked_ioctl(struct file *filep,
                                      unsigned int cmd, unsigned long arg)
{
        struct vhost_vdpa *v = filep->private_data;
        struct vhost_dev *d = &v->vdev;
        void __user *argp = (void __user *)arg;
        u64 __user *featurep = argp;
        u64 features;
        long r = 0;

        if (cmd == VHOST_SET_BACKEND_FEATURES) {
                if (copy_from_user(&features, featurep, sizeof(features)))
                        return -EFAULT;
                if (features & ~VHOST_VDPA_BACKEND_FEATURES)
                        return -EOPNOTSUPP;
                vhost_set_backend_features(&v->vdev, features);
                return 0;
        }

        mutex_lock(&d->mutex);

        switch (cmd) {
        case VHOST_VDPA_GET_DEVICE_ID:
                r = vhost_vdpa_get_device_id(v, argp);
                break;
        case VHOST_VDPA_GET_STATUS:
                r = vhost_vdpa_get_status(v, argp);
                break;
        case VHOST_VDPA_SET_STATUS:
                r = vhost_vdpa_set_status(v, argp);
                break;
        case VHOST_VDPA_GET_CONFIG:
                r = vhost_vdpa_get_config(v, argp);
                break;
        case VHOST_VDPA_SET_CONFIG:
                r = vhost_vdpa_set_config(v, argp);
                break;
        case VHOST_GET_FEATURES:
                r = vhost_vdpa_get_features(v, argp);
                break;
        case VHOST_SET_FEATURES:
                r = vhost_vdpa_set_features(v, argp);
                break;
        case VHOST_VDPA_GET_VRING_NUM:
                r = vhost_vdpa_get_vring_num(v, argp);
                break;
        case VHOST_SET_LOG_BASE:
        case VHOST_SET_LOG_FD:
                r = -ENOIOCTLCMD;
                break;
        case VHOST_VDPA_SET_CONFIG_CALL:
                r = vhost_vdpa_set_config_call(v, argp);
                break;
        case VHOST_GET_BACKEND_FEATURES:
                features = VHOST_VDPA_BACKEND_FEATURES;
                if (copy_to_user(featurep, &features, sizeof(features)))
                        r = -EFAULT;
                break;
        case VHOST_VDPA_GET_IOVA_RANGE:
                r = vhost_vdpa_get_iova_range(v, argp);
                break;
        default:
                r = vhost_dev_ioctl(&v->vdev, cmd, argp);
                if (r == -ENOIOCTLCMD)
                        r = vhost_vdpa_vring_ioctl(v, cmd, argp);
                break;
        }

        mutex_unlock(&d->mutex);
        return r;
}

static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
{
        struct vhost_dev *dev = &v->vdev;
        struct vhost_iotlb *iotlb = dev->iotlb;
        struct vhost_iotlb_map *map;
        struct page *page;
        unsigned long pfn, pinned;

        while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
                pinned = map->size >> PAGE_SHIFT;
                for (pfn = map->addr >> PAGE_SHIFT;
                     pinned > 0; pfn++, pinned--) {
                        page = pfn_to_page(pfn);
                        if (map->perm & VHOST_ACCESS_WO)
                                set_page_dirty_lock(page);
                        unpin_user_page(page);
                }
                atomic64_sub(map->size >> PAGE_SHIFT, &dev->mm->pinned_vm);
                vhost_iotlb_map_free(iotlb, map);
        }
}

static void vhost_vdpa_iotlb_free(struct vhost_vdpa *v)
{
        struct vhost_dev *dev = &v->vdev;

        vhost_vdpa_iotlb_unmap(v, 0ULL, 0ULL - 1);
        kfree(dev->iotlb);
        dev->iotlb = NULL;
}

static int perm_to_iommu_flags(u32 perm)
{
        int flags = 0;

        switch (perm) {
        case VHOST_ACCESS_WO:
                flags |= IOMMU_WRITE;
                break;
        case VHOST_ACCESS_RO:
                flags |= IOMMU_READ;
                break;
        case VHOST_ACCESS_RW:
                flags |= (IOMMU_WRITE | IOMMU_READ);
                break;
        default:
                WARN(1, "invalid vhost IOTLB permission\n");
                break;
        }

        return flags | IOMMU_CACHE;
}

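/*
 * Add a mapping to the software IOTLB and propagate it to the device:
 * either through the device's own dma_map op, through set_map (deferred
 * while an IOTLB batch is in flight), or through the platform IOMMU
 * domain when the device relies on it.
 */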
static int vhost_vdpa_map(struct vhost_vdpa *v,
                          u64 iova, u64 size, u64 pa, u32 perm)
{
        struct vhost_dev *dev = &v->vdev;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        int r = 0;

        r = vhost_iotlb_add_range(dev->iotlb, iova, iova + size - 1,
                                  pa, perm);
        if (r)
                return r;

        if (ops->dma_map) {
                r = ops->dma_map(vdpa, iova, size, pa, perm);
        } else if (ops->set_map) {
                if (!v->in_batch)
                        r = ops->set_map(vdpa, dev->iotlb);
        } else {
                r = iommu_map(v->domain, iova, pa, size,
                              perm_to_iommu_flags(perm));
        }

        if (r)
                vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
        else
                atomic64_add(size >> PAGE_SHIFT, &dev->mm->pinned_vm);

        return r;
}

static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
{
        struct vhost_dev *dev = &v->vdev;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        vhost_vdpa_iotlb_unmap(v, iova, iova + size - 1);

        if (ops->dma_map) {
                ops->dma_unmap(vdpa, iova, size);
        } else if (ops->set_map) {
                if (!v->in_batch)
                        ops->set_map(vdpa, dev->iotlb);
        } else {
                iommu_unmap(v->domain, iova, size);
        }
}

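/*
 * Handle a VHOST_IOTLB_UPDATE message: validate the IOVA range, pin the
 * userspace pages with pin_user_pages(), and map each physically
 * contiguous run of pages as a single chunk via vhost_vdpa_map().
 */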
static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
                                           struct vhost_iotlb_msg *msg)
{
        struct vhost_dev *dev = &v->vdev;
        struct vhost_iotlb *iotlb = dev->iotlb;
        struct page **page_list;
        unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
        unsigned int gup_flags = FOLL_LONGTERM;
        unsigned long npages, cur_base, map_pfn, last_pfn = 0;
        unsigned long lock_limit, sz2pin, nchunks, i;
        u64 iova = msg->iova;
        long pinned;
        int ret = 0;

        if (msg->iova < v->range.first ||
            msg->iova + msg->size - 1 > v->range.last)
                return -EINVAL;

        if (vhost_iotlb_itree_first(iotlb, msg->iova,
                                    msg->iova + msg->size - 1))
                return -EEXIST;

        /* Limit the use of memory for bookkeeping */
        page_list = (struct page **) __get_free_page(GFP_KERNEL);
        if (!page_list)
                return -ENOMEM;

        if (msg->perm & VHOST_ACCESS_WO)
                gup_flags |= FOLL_WRITE;

        npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT;
        if (!npages) {
                ret = -EINVAL;
                goto free;
        }

        mmap_read_lock(dev->mm);

        lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
        if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
                ret = -ENOMEM;
                goto unlock;
        }

        cur_base = msg->uaddr & PAGE_MASK;
        iova &= PAGE_MASK;
        nchunks = 0;

        while (npages) {
                sz2pin = min_t(unsigned long, npages, list_size);
                pinned = pin_user_pages(cur_base, sz2pin,
                                        gup_flags, page_list, NULL);
                if (sz2pin != pinned) {
                        if (pinned < 0) {
                                ret = pinned;
                        } else {
                                unpin_user_pages(page_list, pinned);
                                ret = -ENOMEM;
                        }
                        goto out;
                }
                nchunks++;

                if (!last_pfn)
                        map_pfn = page_to_pfn(page_list[0]);

                for (i = 0; i < pinned; i++) {
                        unsigned long this_pfn = page_to_pfn(page_list[i]);
                        u64 csize;

                        if (last_pfn && (this_pfn != last_pfn + 1)) {
                                /* Map the contiguous chunk pinned so far */
                                csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT;
                                ret = vhost_vdpa_map(v, iova, csize,
                                                     map_pfn << PAGE_SHIFT,
                                                     msg->perm);
                                if (ret) {
                                        /*
                                         * Unpin the pages from this point on
                                         * in the current page_list that were
                                         * left unmapped. The remaining
                                         * outstanding ones, which may stride
                                         * across several chunks, are covered
                                         * by the common error path below.
                                         */
                                        unpin_user_pages(&page_list[i],
                                                         pinned - i);
                                        goto out;
                                }

                                map_pfn = this_pfn;
                                iova += csize;
                                nchunks = 0;
                        }

                        last_pfn = this_pfn;
                }

                cur_base += pinned << PAGE_SHIFT;
                npages -= pinned;
        }

        /* Map the remaining pinned chunk */
        ret = vhost_vdpa_map(v, iova, (last_pfn - map_pfn + 1) << PAGE_SHIFT,
                             map_pfn << PAGE_SHIFT, msg->perm);
out:
        if (ret) {
                if (nchunks) {
                        unsigned long pfn;

                        /*
                         * Unpin the outstanding pages that were meant to be
                         * mapped but weren't, due to a vdpa_map() or
                         * pin_user_pages() failure.
                         *
                         * Pages that were mapped are accounted in vdpa_map(),
                         * so the corresponding unpinning is handled by
                         * vdpa_unmap().
                         */
                        WARN_ON(!last_pfn);
                        for (pfn = map_pfn; pfn <= last_pfn; pfn++)
                                unpin_user_page(pfn_to_page(pfn));
                }
                vhost_vdpa_unmap(v, msg->iova, msg->size);
        }
unlock:
        mmap_read_unlock(dev->mm);
free:
        free_page((unsigned long)page_list);
        return ret;
}

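/*
 * Backend IOTLB message handler plugged into the vhost core: dispatches
 * UPDATE/INVALIDATE requests and the BATCH_BEGIN/BATCH_END markers that
 * allow set_map() to be deferred until the end of a batch of updates.
 */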
static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev,
                                        struct vhost_iotlb_msg *msg)
{
        struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        int r = 0;

        r = vhost_dev_check_owner(dev);
        if (r)
                return r;

        switch (msg->type) {
        case VHOST_IOTLB_UPDATE:
                r = vhost_vdpa_process_iotlb_update(v, msg);
                break;
        case VHOST_IOTLB_INVALIDATE:
                vhost_vdpa_unmap(v, msg->iova, msg->size);
                break;
        case VHOST_IOTLB_BATCH_BEGIN:
                v->in_batch = true;
                break;
        case VHOST_IOTLB_BATCH_END:
                if (v->in_batch && ops->set_map)
                        ops->set_map(vdpa, dev->iotlb);
                v->in_batch = false;
                break;
        default:
                r = -EINVAL;
                break;
        }

        return r;
}

static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
                                         struct iov_iter *from)
{
        struct file *file = iocb->ki_filp;
        struct vhost_vdpa *v = file->private_data;
        struct vhost_dev *dev = &v->vdev;

        return vhost_chr_write_iter(dev, from);
}

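/*
 * Allocate and attach an IOMMU domain for devices that do not provide
 * their own dma_map/set_map ops and therefore rely on the platform IOMMU
 * for DMA isolation.
 */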
static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct device *dma_dev = vdpa_get_dma_dev(vdpa);
        struct bus_type *bus;
        int ret;

        /* Device wants to do DMA by itself */
        if (ops->set_map || ops->dma_map)
                return 0;

        bus = dma_dev->bus;
        if (!bus)
                return -EFAULT;

        if (!iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
                return -ENOTSUPP;

        v->domain = iommu_domain_alloc(bus);
        if (!v->domain)
                return -EIO;

        ret = iommu_attach_device(v->domain, dma_dev);
        if (ret)
                goto err_attach;

        return 0;

err_attach:
        iommu_domain_free(v->domain);
        return ret;
}

static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
{
        struct vdpa_device *vdpa = v->vdpa;
        struct device *dma_dev = vdpa_get_dma_dev(vdpa);

        if (v->domain) {
                iommu_detach_device(v->domain, dma_dev);
                iommu_domain_free(v->domain);
        }

        v->domain = NULL;
}

static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
{
        struct vdpa_iova_range *range = &v->range;
        struct iommu_domain_geometry geo;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;

        if (ops->get_iova_range) {
                *range = ops->get_iova_range(vdpa);
        } else if (v->domain &&
                   !iommu_domain_get_attr(v->domain,
                   DOMAIN_ATTR_GEOMETRY, &geo) &&
                   geo.force_aperture) {
                range->first = geo.aperture_start;
                range->last = geo.aperture_end;
        } else {
                range->first = 0;
                range->last = ULLONG_MAX;
        }
}

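/*
 * Character device open: only a single opener is allowed. Resets the
 * device, initializes the vhost device with its virtqueues, allocates
 * the software IOTLB and (if needed) an IOMMU domain, and records the
 * usable IOVA range.
 */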
static int vhost_vdpa_open(struct inode *inode, struct file *filep)
{
        struct vhost_vdpa *v;
        struct vhost_dev *dev;
        struct vhost_virtqueue **vqs;
        int nvqs, i, r, opened;

        v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);

        opened = atomic_cmpxchg(&v->opened, 0, 1);
        if (opened)
                return -EBUSY;

        nvqs = v->nvqs;
        vhost_vdpa_reset(v);

        vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
        if (!vqs) {
                r = -ENOMEM;
                goto err;
        }

        dev = &v->vdev;
        for (i = 0; i < nvqs; i++) {
                vqs[i] = &v->vqs[i];
                vqs[i]->handle_kick = handle_vq_kick;
        }
        vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
                       vhost_vdpa_process_iotlb_msg);

        dev->iotlb = vhost_iotlb_alloc(0, 0);
        if (!dev->iotlb) {
                r = -ENOMEM;
                goto err_init_iotlb;
        }

        r = vhost_vdpa_alloc_domain(v);
        if (r)
                goto err_init_iotlb;

        vhost_vdpa_set_iova_range(v);

        filep->private_data = v;

        return 0;

err_init_iotlb:
        vhost_dev_cleanup(&v->vdev);
        kfree(vqs);
err:
        atomic_dec(&v->opened);
        return r;
}

static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
{
        struct vhost_virtqueue *vq;
        int i;

        for (i = 0; i < v->nvqs; i++) {
                vq = &v->vqs[i];
                if (vq->call_ctx.producer.irq)
                        irq_bypass_unregister_producer(&vq->call_ctx.producer);
        }
}

static int vhost_vdpa_release(struct inode *inode, struct file *filep)
{
        struct vhost_vdpa *v = filep->private_data;
        struct vhost_dev *d = &v->vdev;

        mutex_lock(&d->mutex);
        filep->private_data = NULL;
        vhost_vdpa_reset(v);
        vhost_dev_stop(&v->vdev);
        vhost_vdpa_iotlb_free(v);
        vhost_vdpa_free_domain(v);
        vhost_vdpa_config_put(v);
        vhost_vdpa_clean_irq(v);
        vhost_dev_cleanup(&v->vdev);
        kfree(v->vdev.vqs);
        mutex_unlock(&d->mutex);

        atomic_dec(&v->opened);
        complete(&v->completion);

        return 0;
}

#ifdef CONFIG_MMU
static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
{
        struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vdpa_notification_area notify;
        struct vm_area_struct *vma = vmf->vma;
        u16 index = vma->vm_pgoff;

        notify = ops->get_vq_notification(vdpa, index);

        vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
        if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
                            notify.addr >> PAGE_SHIFT, PAGE_SIZE,
                            vma->vm_page_prot))
                return VM_FAULT_SIGBUS;

        return VM_FAULT_NOPAGE;
}

static const struct vm_operations_struct vhost_vdpa_vm_ops = {
        .fault = vhost_vdpa_fault,
};

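/*
 * mmap() support for mapping a virtqueue doorbell (notification area)
 * directly into userspace. The vq index is passed via vm_pgoff; the
 * actual PFN is installed lazily by the fault handler above.
 */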
static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
{
        struct vhost_vdpa *v = vma->vm_file->private_data;
        struct vdpa_device *vdpa = v->vdpa;
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vdpa_notification_area notify;
        unsigned long index = vma->vm_pgoff;

        if (vma->vm_end - vma->vm_start != PAGE_SIZE)
                return -EINVAL;
        if ((vma->vm_flags & VM_SHARED) == 0)
                return -EINVAL;
        if (vma->vm_flags & VM_READ)
                return -EINVAL;
        if (index > 65535)
                return -EINVAL;
        if (!ops->get_vq_notification)
                return -ENOTSUPP;

        /*
         * To be safe and easily modelled by userspace, we only support
         * a doorbell that sits on a page boundary and does not share
         * its page with other registers.
         */
        notify = ops->get_vq_notification(vdpa, index);
        if (notify.addr & (PAGE_SIZE - 1))
                return -EINVAL;
        if (vma->vm_end - vma->vm_start != notify.size)
                return -ENOTSUPP;

        vma->vm_ops = &vhost_vdpa_vm_ops;
        return 0;
}
#endif /* CONFIG_MMU */

static const struct file_operations vhost_vdpa_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vdpa_open,
        .release        = vhost_vdpa_release,
        .write_iter     = vhost_vdpa_chr_write_iter,
        .unlocked_ioctl = vhost_vdpa_unlocked_ioctl,
#ifdef CONFIG_MMU
        .mmap           = vhost_vdpa_mmap,
#endif /* CONFIG_MMU */
        .compat_ioctl   = compat_ptr_ioctl,
};

static void vhost_vdpa_release_dev(struct device *device)
{
        struct vhost_vdpa *v =
               container_of(device, struct vhost_vdpa, dev);

        ida_simple_remove(&vhost_vdpa_ida, v->minor);
        kfree(v->vqs);
        kfree(v);
}

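/*
 * vDPA bus probe: allocate the vhost_vdpa instance, reserve a minor,
 * allocate the virtqueue array and register the vhost-vdpa-<minor>
 * character device.
 */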
static int vhost_vdpa_probe(struct vdpa_device *vdpa)
{
        const struct vdpa_config_ops *ops = vdpa->config;
        struct vhost_vdpa *v;
        int minor;
        int r;

        /* Currently, we only accept network devices. */
        if (ops->get_device_id(vdpa) != VIRTIO_ID_NET)
                return -ENOTSUPP;

        v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
        if (!v)
                return -ENOMEM;

        minor = ida_simple_get(&vhost_vdpa_ida, 0,
                               VHOST_VDPA_DEV_MAX, GFP_KERNEL);
        if (minor < 0) {
                kfree(v);
                return minor;
        }

        atomic_set(&v->opened, 0);
        v->minor = minor;
        v->vdpa = vdpa;
        v->nvqs = vdpa->nvqs;
        v->virtio_id = ops->get_device_id(vdpa);

        device_initialize(&v->dev);
        v->dev.release = vhost_vdpa_release_dev;
        v->dev.parent = &vdpa->dev;
        v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
        v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
                               GFP_KERNEL);
        if (!v->vqs) {
                r = -ENOMEM;
                goto err;
        }

        r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
        if (r)
                goto err;

        cdev_init(&v->cdev, &vhost_vdpa_fops);
        v->cdev.owner = THIS_MODULE;

        r = cdev_device_add(&v->cdev, &v->dev);
        if (r)
                goto err;

        init_completion(&v->completion);
        vdpa_set_drvdata(vdpa, v);

        return 0;

err:
        put_device(&v->dev);
        return r;
}

static void vhost_vdpa_remove(struct vdpa_device *vdpa)
{
        struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
        int opened;

        cdev_device_del(&v->cdev, &v->dev);

        do {
                opened = atomic_cmpxchg(&v->opened, 0, 1);
                if (!opened)
                        break;
                wait_for_completion(&v->completion);
        } while (1);

        put_device(&v->dev);
}

static struct vdpa_driver vhost_vdpa_driver = {
        .driver = {
                .name   = "vhost_vdpa",
        },
        .probe  = vhost_vdpa_probe,
        .remove = vhost_vdpa_remove,
};

static int __init vhost_vdpa_init(void)
{
        int r;

        r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
                                "vhost-vdpa");
        if (r)
                goto err_alloc_chrdev;

        r = vdpa_register_driver(&vhost_vdpa_driver);
        if (r)
                goto err_vdpa_register_driver;

        return 0;

err_vdpa_register_driver:
        unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
err_alloc_chrdev:
        return r;
}
module_init(vhost_vdpa_init);

static void __exit vhost_vdpa_exit(void)
{
        vdpa_unregister_driver(&vhost_vdpa_driver);
        unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
}
module_exit(vhost_vdpa_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Intel Corporation");
MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");