linux/drivers/vfio/vfio_iommu_type1.c
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * VFIO: IOMMU DMA mapping support for Type1 IOMMU
   4 *
   5 * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
   6 *     Author: Alex Williamson <alex.williamson@redhat.com>
   7 *
   8 * Derived from original vfio:
   9 * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
  10 * Author: Tom Lyon, pugs@cisco.com
  11 *
  12 * We arbitrarily define a Type1 IOMMU as one matching the below code.
  13 * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
  14 * VT-d, but that makes it harder to re-use as theoretically anyone
  15 * implementing a similar IOMMU could make use of this.  We expect the
  16 * IOMMU to support the IOMMU API and have few to no restrictions around
  17 * the IOVA range that can be mapped.  The Type1 IOMMU is currently
  18 * optimized for relatively static mappings of a userspace process with
  19 * userspace pages pinned into memory.  We also assume devices and IOMMU
  20 * domains are PCI based as the IOMMU API is still centered around a
  21 * device/bus interface rather than a group interface.
  22 */
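/*
 * For orientation, userspace reaches this driver roughly as follows (an
 * illustrative sketch of the standard VFIO uAPI flow; the group number and
 * the error handling are assumptions/omitted):
 *
 *	int container = open("/dev/vfio/vfio", O_RDWR);
 *	int group = open("/dev/vfio/26", O_RDWR);   // "26" is assumed
 *
 *	if (ioctl(container, VFIO_GET_API_VERSION) != VFIO_API_VERSION ||
 *	    !ioctl(container, VFIO_CHECK_EXTENSION, VFIO_TYPE1v2_IOMMU))
 *		exit(1);
 *
 *	ioctl(group, VFIO_GROUP_SET_CONTAINER, &container);
 *	ioctl(container, VFIO_SET_IOMMU, VFIO_TYPE1v2_IOMMU);
 *
 * After VFIO_SET_IOMMU selects this backend, VFIO_IOMMU_MAP_DMA,
 * VFIO_IOMMU_UNMAP_DMA and the other ioctls handled below operate on the
 * container.
 */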
  23
  24#include <linux/compat.h>
  25#include <linux/device.h>
  26#include <linux/fs.h>
  27#include <linux/highmem.h>
  28#include <linux/iommu.h>
  29#include <linux/module.h>
  30#include <linux/mm.h>
  31#include <linux/kthread.h>
  32#include <linux/rbtree.h>
  33#include <linux/sched/signal.h>
  34#include <linux/sched/mm.h>
  35#include <linux/slab.h>
  36#include <linux/uaccess.h>
  37#include <linux/vfio.h>
  38#include <linux/workqueue.h>
  39#include <linux/notifier.h>
  40#include <linux/dma-iommu.h>
  41#include <linux/irqdomain.h>
  42#include "vfio.h"
  43
  44#define DRIVER_VERSION  "0.2"
  45#define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
  46#define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
  47
  48static bool allow_unsafe_interrupts;
  49module_param_named(allow_unsafe_interrupts,
  50                   allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
  51MODULE_PARM_DESC(allow_unsafe_interrupts,
  52                 "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
  53
  54static bool disable_hugepages;
  55module_param_named(disable_hugepages,
  56                   disable_hugepages, bool, S_IRUGO | S_IWUSR);
  57MODULE_PARM_DESC(disable_hugepages,
  58                 "Disable VFIO IOMMU support for IOMMU hugepages.");
  59
  60static unsigned int dma_entry_limit __read_mostly = U16_MAX;
  61module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
  62MODULE_PARM_DESC(dma_entry_limit,
  63                 "Maximum number of user DMA mappings per container (65535).");
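/*
 * The parameters above can be set at module load time and, where marked
 * writable (S_IWUSR / 0644), changed at runtime through sysfs.  Illustrative
 * usage (paths follow the standard module_param() convention):
 *
 *	modprobe vfio_iommu_type1 dma_entry_limit=131072
 *	echo 1 > /sys/module/vfio_iommu_type1/parameters/disable_hugepages
 */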
  64
  65struct vfio_iommu {
  66        struct list_head        domain_list;
  67        struct list_head        iova_list;
  68        struct mutex            lock;
  69        struct rb_root          dma_list;
  70        struct blocking_notifier_head notifier;
  71        unsigned int            dma_avail;
  72        unsigned int            vaddr_invalid_count;
  73        uint64_t                pgsize_bitmap;
  74        uint64_t                num_non_pinned_groups;
  75        wait_queue_head_t       vaddr_wait;
  76        bool                    v2;
  77        bool                    nesting;
  78        bool                    dirty_page_tracking;
  79        bool                    container_open;
  80        struct list_head        emulated_iommu_groups;
  81};
  82
  83struct vfio_domain {
  84        struct iommu_domain     *domain;
  85        struct list_head        next;
  86        struct list_head        group_list;
  87        int                     prot;           /* IOMMU_CACHE */
  88        bool                    fgsp;           /* Fine-grained super pages */
  89};
  90
  91struct vfio_dma {
  92        struct rb_node          node;
  93        dma_addr_t              iova;           /* Device address */
  94        unsigned long           vaddr;          /* Process virtual addr */
  95        size_t                  size;           /* Map size (bytes) */
  96        int                     prot;           /* IOMMU_READ/WRITE */
  97        bool                    iommu_mapped;
  98        bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
  99        bool                    vaddr_invalid;
 100        struct task_struct      *task;
 101        struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
 102        unsigned long           *bitmap;
 103};
 104
 105struct vfio_batch {
 106        struct page             **pages;        /* for pin_user_pages_remote */
 107        struct page             *fallback_page; /* if pages alloc fails */
 108        int                     capacity;       /* length of pages array */
 109        int                     size;           /* of batch currently */
 110        int                     offset;         /* of next entry in pages */
 111};
 112
 113struct vfio_iommu_group {
 114        struct iommu_group      *iommu_group;
 115        struct list_head        next;
 116        bool                    pinned_page_dirty_scope;
 117};
 118
 119struct vfio_iova {
 120        struct list_head        list;
 121        dma_addr_t              start;
 122        dma_addr_t              end;
 123};
 124
 125/*
 126 * Guest RAM pinning working set or DMA target
 127 */
 128struct vfio_pfn {
 129        struct rb_node          node;
 130        dma_addr_t              iova;           /* Device address */
 131        unsigned long           pfn;            /* Host pfn */
 132        unsigned int            ref_count;
 133};
 134
 135struct vfio_regions {
 136        struct list_head list;
 137        dma_addr_t iova;
 138        phys_addr_t phys;
 139        size_t len;
 140};
 141
 142#define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
 143
  144/*
  145 * The number-of-bits argument to bitmap_set() is an unsigned int, which is
  146 * cast to a signed int by the unaligned multi-bit helper, __bitmap_set(),
  147 * so at most 2^31 bits can be manipulated.
  148 * The maximum supported bitmap size is therefore 2^31 bits / 2^3 bits per
  149 * byte = 2^28 bytes (256 MB) of bitmap, which covers 2^31 * 2^12 = 2^43
  150 * bytes (8 TB) of memory on a 4K page system.
  151 */
 152#define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
 153#define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
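/*
 * As a worked example of the sizing above: a 1 GB mapping tracked at 4K
 * granularity has 262144 pages, so DIRTY_BITMAP_BYTES(262144) =
 * ALIGN(262144, 64) / 8 = 32768 bytes (32 KB), well under
 * DIRTY_BITMAP_SIZE_MAX.
 */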
 154
 155#define WAITED 1
 156
 157static int put_pfn(unsigned long pfn, int prot);
 158
 159static struct vfio_iommu_group*
 160vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
 161                            struct iommu_group *iommu_group);
 162
 163/*
 164 * This code handles mapping and unmapping of user data buffers
 165 * into DMA'ble space using the IOMMU
 166 */
 167
 168static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 169                                      dma_addr_t start, size_t size)
 170{
 171        struct rb_node *node = iommu->dma_list.rb_node;
 172
 173        while (node) {
 174                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 175
 176                if (start + size <= dma->iova)
 177                        node = node->rb_left;
 178                else if (start >= dma->iova + dma->size)
 179                        node = node->rb_right;
 180                else
 181                        return dma;
 182        }
 183
 184        return NULL;
 185}
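/*
 * Example (illustrative): with a single vfio_dma covering
 * [0x100000, 0x103000), vfio_find_dma(iommu, 0x102000, 0x1000) returns it,
 * while vfio_find_dma(iommu, 0x103000, 0x1000) returns NULL -- the ranges
 * are treated as half-open intervals, so merely touching ends do not
 * overlap.
 */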
 186
 187static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
 188                                                dma_addr_t start, u64 size)
 189{
 190        struct rb_node *res = NULL;
 191        struct rb_node *node = iommu->dma_list.rb_node;
 192        struct vfio_dma *dma_res = NULL;
 193
 194        while (node) {
 195                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 196
 197                if (start < dma->iova + dma->size) {
 198                        res = node;
 199                        dma_res = dma;
 200                        if (start >= dma->iova)
 201                                break;
 202                        node = node->rb_left;
 203                } else {
 204                        node = node->rb_right;
 205                }
 206        }
 207        if (res && size && dma_res->iova >= start + size)
 208                res = NULL;
 209        return res;
 210}
 211
 212static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
 213{
 214        struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
 215        struct vfio_dma *dma;
 216
 217        while (*link) {
 218                parent = *link;
 219                dma = rb_entry(parent, struct vfio_dma, node);
 220
 221                if (new->iova + new->size <= dma->iova)
 222                        link = &(*link)->rb_left;
 223                else
 224                        link = &(*link)->rb_right;
 225        }
 226
 227        rb_link_node(&new->node, parent, link);
 228        rb_insert_color(&new->node, &iommu->dma_list);
 229}
 230
 231static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
 232{
 233        rb_erase(&old->node, &iommu->dma_list);
 234}
 235
 236
 237static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
 238{
 239        uint64_t npages = dma->size / pgsize;
 240
 241        if (npages > DIRTY_BITMAP_PAGES_MAX)
 242                return -EINVAL;
 243
 244        /*
  245         * Allocate an extra 64 bits of headroom so that bitmap_shift_left()
  246         * can shift and combine an unaligned number of pages with those of
  247         * an adjacent vfio_dma range.
 248         */
 249        dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
 250                               GFP_KERNEL);
 251        if (!dma->bitmap)
 252                return -ENOMEM;
 253
 254        return 0;
 255}
 256
 257static void vfio_dma_bitmap_free(struct vfio_dma *dma)
 258{
 259        kfree(dma->bitmap);
 260        dma->bitmap = NULL;
 261}
 262
 263static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
 264{
 265        struct rb_node *p;
 266        unsigned long pgshift = __ffs(pgsize);
 267
 268        for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
 269                struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
 270
 271                bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
 272        }
 273}
 274
 275static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
 276{
 277        struct rb_node *n;
 278        unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
 279
 280        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
 281                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
 282
 283                bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
 284        }
 285}
 286
 287static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
 288{
 289        struct rb_node *n;
 290
 291        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
 292                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
 293                int ret;
 294
 295                ret = vfio_dma_bitmap_alloc(dma, pgsize);
 296                if (ret) {
 297                        struct rb_node *p;
 298
 299                        for (p = rb_prev(n); p; p = rb_prev(p)) {
  300                                struct vfio_dma *dma = rb_entry(p,
 301                                                        struct vfio_dma, node);
 302
 303                                vfio_dma_bitmap_free(dma);
 304                        }
 305                        return ret;
 306                }
 307                vfio_dma_populate_bitmap(dma, pgsize);
 308        }
 309        return 0;
 310}
 311
 312static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
 313{
 314        struct rb_node *n;
 315
 316        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
 317                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
 318
 319                vfio_dma_bitmap_free(dma);
 320        }
 321}
 322
 323/*
 324 * Helper Functions for host iova-pfn list
 325 */
 326static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
 327{
 328        struct vfio_pfn *vpfn;
 329        struct rb_node *node = dma->pfn_list.rb_node;
 330
 331        while (node) {
 332                vpfn = rb_entry(node, struct vfio_pfn, node);
 333
 334                if (iova < vpfn->iova)
 335                        node = node->rb_left;
 336                else if (iova > vpfn->iova)
 337                        node = node->rb_right;
 338                else
 339                        return vpfn;
 340        }
 341        return NULL;
 342}
 343
 344static void vfio_link_pfn(struct vfio_dma *dma,
 345                          struct vfio_pfn *new)
 346{
 347        struct rb_node **link, *parent = NULL;
 348        struct vfio_pfn *vpfn;
 349
 350        link = &dma->pfn_list.rb_node;
 351        while (*link) {
 352                parent = *link;
 353                vpfn = rb_entry(parent, struct vfio_pfn, node);
 354
 355                if (new->iova < vpfn->iova)
 356                        link = &(*link)->rb_left;
 357                else
 358                        link = &(*link)->rb_right;
 359        }
 360
 361        rb_link_node(&new->node, parent, link);
 362        rb_insert_color(&new->node, &dma->pfn_list);
 363}
 364
 365static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
 366{
 367        rb_erase(&old->node, &dma->pfn_list);
 368}
 369
 370static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
 371                                unsigned long pfn)
 372{
 373        struct vfio_pfn *vpfn;
 374
 375        vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
 376        if (!vpfn)
 377                return -ENOMEM;
 378
 379        vpfn->iova = iova;
 380        vpfn->pfn = pfn;
 381        vpfn->ref_count = 1;
 382        vfio_link_pfn(dma, vpfn);
 383        return 0;
 384}
 385
 386static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
 387                                      struct vfio_pfn *vpfn)
 388{
 389        vfio_unlink_pfn(dma, vpfn);
 390        kfree(vpfn);
 391}
 392
 393static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
 394                                               unsigned long iova)
 395{
 396        struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
 397
 398        if (vpfn)
 399                vpfn->ref_count++;
 400        return vpfn;
 401}
 402
 403static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
 404{
 405        int ret = 0;
 406
 407        vpfn->ref_count--;
 408        if (!vpfn->ref_count) {
 409                ret = put_pfn(vpfn->pfn, dma->prot);
 410                vfio_remove_from_pfn_list(dma, vpfn);
 411        }
 412        return ret;
 413}
 414
 415static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
 416{
 417        struct mm_struct *mm;
 418        int ret;
 419
 420        if (!npage)
 421                return 0;
 422
 423        mm = async ? get_task_mm(dma->task) : dma->task->mm;
 424        if (!mm)
 425                return -ESRCH; /* process exited */
 426
 427        ret = mmap_write_lock_killable(mm);
 428        if (!ret) {
 429                ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
 430                                          dma->lock_cap);
 431                mmap_write_unlock(mm);
 432        }
 433
 434        if (async)
 435                mmput(mm);
 436
 437        return ret;
 438}
 439
 440/*
 441 * Some mappings aren't backed by a struct page, for example an mmap'd
 442 * MMIO range for our own or another device.  These use a different
 443 * pfn conversion and shouldn't be tracked as locked pages.
 444 * For compound pages, any driver that sets the reserved bit in head
 445 * page needs to set the reserved bit in all subpages to be safe.
 446 */
 447static bool is_invalid_reserved_pfn(unsigned long pfn)
 448{
 449        if (pfn_valid(pfn))
 450                return PageReserved(pfn_to_page(pfn));
 451
 452        return true;
 453}
 454
 455static int put_pfn(unsigned long pfn, int prot)
 456{
 457        if (!is_invalid_reserved_pfn(pfn)) {
 458                struct page *page = pfn_to_page(pfn);
 459
 460                unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
 461                return 1;
 462        }
 463        return 0;
 464}
 465
 466#define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
 467
 468static void vfio_batch_init(struct vfio_batch *batch)
 469{
 470        batch->size = 0;
 471        batch->offset = 0;
 472
 473        if (unlikely(disable_hugepages))
 474                goto fallback;
 475
 476        batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
 477        if (!batch->pages)
 478                goto fallback;
 479
 480        batch->capacity = VFIO_BATCH_MAX_CAPACITY;
 481        return;
 482
 483fallback:
 484        batch->pages = &batch->fallback_page;
 485        batch->capacity = 1;
 486}
 487
 488static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
 489{
 490        while (batch->size) {
 491                unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
 492
 493                put_pfn(pfn, dma->prot);
 494                batch->offset++;
 495                batch->size--;
 496        }
 497}
 498
 499static void vfio_batch_fini(struct vfio_batch *batch)
 500{
 501        if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
 502                free_page((unsigned long)batch->pages);
 503}
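/*
 * Sketch of the intended batch lifecycle (mirrors vfio_pin_map_dma() below;
 * variables and error handling are omitted for brevity):
 *
 *	struct vfio_batch batch;
 *	unsigned long pfn;
 *	long npage;
 *
 *	vfio_batch_init(&batch);
 *	npage = vfio_pin_pages_remote(dma, vaddr, npage_req, &pfn, limit,
 *				      &batch);
 *	// ... map the pinned run, repeat for the remaining pages ...
 *	vfio_batch_fini(&batch);
 */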
 504
 505static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
 506                            unsigned long vaddr, unsigned long *pfn,
 507                            bool write_fault)
 508{
 509        pte_t *ptep;
 510        spinlock_t *ptl;
 511        int ret;
 512
 513        ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
 514        if (ret) {
 515                bool unlocked = false;
 516
 517                ret = fixup_user_fault(mm, vaddr,
 518                                       FAULT_FLAG_REMOTE |
 519                                       (write_fault ?  FAULT_FLAG_WRITE : 0),
 520                                       &unlocked);
 521                if (unlocked)
 522                        return -EAGAIN;
 523
 524                if (ret)
 525                        return ret;
 526
 527                ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
 528                if (ret)
 529                        return ret;
 530        }
 531
 532        if (write_fault && !pte_write(*ptep))
 533                ret = -EFAULT;
 534        else
 535                *pfn = pte_pfn(*ptep);
 536
 537        pte_unmap_unlock(ptep, ptl);
 538        return ret;
 539}
 540
 541/*
 542 * Returns the positive number of pfns successfully obtained or a negative
 543 * error code.
 544 */
 545static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
 546                          long npages, int prot, unsigned long *pfn,
 547                          struct page **pages)
 548{
 549        struct vm_area_struct *vma;
 550        unsigned int flags = 0;
 551        int ret;
 552
 553        if (prot & IOMMU_WRITE)
 554                flags |= FOLL_WRITE;
 555
 556        mmap_read_lock(mm);
 557        ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
 558                                    pages, NULL, NULL);
 559        if (ret > 0) {
 560                *pfn = page_to_pfn(pages[0]);
 561                goto done;
 562        }
 563
 564        vaddr = untagged_addr(vaddr);
 565
 566retry:
 567        vma = vma_lookup(mm, vaddr);
 568
 569        if (vma && vma->vm_flags & VM_PFNMAP) {
 570                ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
 571                if (ret == -EAGAIN)
 572                        goto retry;
 573
 574                if (!ret) {
 575                        if (is_invalid_reserved_pfn(*pfn))
 576                                ret = 1;
 577                        else
 578                                ret = -EFAULT;
 579                }
 580        }
 581done:
 582        mmap_read_unlock(mm);
 583        return ret;
 584}
 585
 586static int vfio_wait(struct vfio_iommu *iommu)
 587{
 588        DEFINE_WAIT(wait);
 589
 590        prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
 591        mutex_unlock(&iommu->lock);
 592        schedule();
 593        mutex_lock(&iommu->lock);
 594        finish_wait(&iommu->vaddr_wait, &wait);
 595        if (kthread_should_stop() || !iommu->container_open ||
 596            fatal_signal_pending(current)) {
 597                return -EFAULT;
 598        }
 599        return WAITED;
 600}
 601
 602/*
 603 * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
 604 * if the task waits, but is re-locked on return.  Return result in *dma_p.
 605 * Return 0 on success with no waiting, WAITED on success if waited, and -errno
 606 * on error.
 607 */
 608static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
 609                               size_t size, struct vfio_dma **dma_p)
 610{
 611        int ret = 0;
 612
 613        do {
 614                *dma_p = vfio_find_dma(iommu, start, size);
 615                if (!*dma_p)
 616                        return -EINVAL;
 617                else if (!(*dma_p)->vaddr_invalid)
 618                        return ret;
 619                else
 620                        ret = vfio_wait(iommu);
 621        } while (ret == WAITED);
 622
 623        return ret;
 624}
 625
 626/*
 627 * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
 628 * if the task waits, but is re-locked on return.  Return 0 on success with no
 629 * waiting, WAITED on success if waited, and -errno on error.
 630 */
 631static int vfio_wait_all_valid(struct vfio_iommu *iommu)
 632{
 633        int ret = 0;
 634
 635        while (iommu->vaddr_invalid_count && ret >= 0)
 636                ret = vfio_wait(iommu);
 637
 638        return ret;
 639}
 640
 641/*
 642 * Attempt to pin pages.  We really don't want to track all the pfns and
 643 * the iommu can only map chunks of consecutive pfns anyway, so get the
 644 * first page and all consecutive pages with the same locking.
 645 */
 646static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 647                                  long npage, unsigned long *pfn_base,
 648                                  unsigned long limit, struct vfio_batch *batch)
 649{
 650        unsigned long pfn;
 651        struct mm_struct *mm = current->mm;
 652        long ret, pinned = 0, lock_acct = 0;
 653        bool rsvd;
 654        dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 655
 656        /* This code path is only user initiated */
 657        if (!mm)
 658                return -ENODEV;
 659
 660        if (batch->size) {
 661                /* Leftover pages in batch from an earlier call. */
 662                *pfn_base = page_to_pfn(batch->pages[batch->offset]);
 663                pfn = *pfn_base;
 664                rsvd = is_invalid_reserved_pfn(*pfn_base);
 665        } else {
 666                *pfn_base = 0;
 667        }
 668
 669        while (npage) {
 670                if (!batch->size) {
 671                        /* Empty batch, so refill it. */
 672                        long req_pages = min_t(long, npage, batch->capacity);
 673
 674                        ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
 675                                             &pfn, batch->pages);
 676                        if (ret < 0)
 677                                goto unpin_out;
 678
 679                        batch->size = ret;
 680                        batch->offset = 0;
 681
 682                        if (!*pfn_base) {
 683                                *pfn_base = pfn;
 684                                rsvd = is_invalid_reserved_pfn(*pfn_base);
 685                        }
 686                }
 687
 688                /*
 689                 * pfn is preset for the first iteration of this inner loop and
 690                 * updated at the end to handle a VM_PFNMAP pfn.  In that case,
 691                 * batch->pages isn't valid (there's no struct page), so allow
 692                 * batch->pages to be touched only when there's more than one
 693                 * pfn to check, which guarantees the pfns are from a
 694                 * !VM_PFNMAP vma.
 695                 */
 696                while (true) {
 697                        if (pfn != *pfn_base + pinned ||
 698                            rsvd != is_invalid_reserved_pfn(pfn))
 699                                goto out;
 700
 701                        /*
 702                         * Reserved pages aren't counted against the user,
 703                         * externally pinned pages are already counted against
 704                         * the user.
 705                         */
 706                        if (!rsvd && !vfio_find_vpfn(dma, iova)) {
 707                                if (!dma->lock_cap &&
 708                                    mm->locked_vm + lock_acct + 1 > limit) {
 709                                        pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
 710                                                __func__, limit << PAGE_SHIFT);
 711                                        ret = -ENOMEM;
 712                                        goto unpin_out;
 713                                }
 714                                lock_acct++;
 715                        }
 716
 717                        pinned++;
 718                        npage--;
 719                        vaddr += PAGE_SIZE;
 720                        iova += PAGE_SIZE;
 721                        batch->offset++;
 722                        batch->size--;
 723
 724                        if (!batch->size)
 725                                break;
 726
 727                        pfn = page_to_pfn(batch->pages[batch->offset]);
 728                }
 729
 730                if (unlikely(disable_hugepages))
 731                        break;
 732        }
 733
 734out:
 735        ret = vfio_lock_acct(dma, lock_acct, false);
 736
 737unpin_out:
 738        if (batch->size == 1 && !batch->offset) {
 739                /* May be a VM_PFNMAP pfn, which the batch can't remember. */
 740                put_pfn(pfn, dma->prot);
 741                batch->size = 0;
 742        }
 743
 744        if (ret < 0) {
 745                if (pinned && !rsvd) {
 746                        for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
 747                                put_pfn(pfn, dma->prot);
 748                }
 749                vfio_batch_unpin(batch, dma);
 750
 751                return ret;
 752        }
 753
 754        return pinned;
 755}
 756
 757static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
 758                                    unsigned long pfn, long npage,
 759                                    bool do_accounting)
 760{
 761        long unlocked = 0, locked = 0;
 762        long i;
 763
 764        for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
 765                if (put_pfn(pfn++, dma->prot)) {
 766                        unlocked++;
 767                        if (vfio_find_vpfn(dma, iova))
 768                                locked++;
 769                }
 770        }
 771
 772        if (do_accounting)
 773                vfio_lock_acct(dma, locked - unlocked, true);
 774
 775        return unlocked;
 776}
 777
 778static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 779                                  unsigned long *pfn_base, bool do_accounting)
 780{
 781        struct page *pages[1];
 782        struct mm_struct *mm;
 783        int ret;
 784
 785        mm = get_task_mm(dma->task);
 786        if (!mm)
 787                return -ENODEV;
 788
 789        ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
 790        if (ret != 1)
 791                goto out;
 792
 793        ret = 0;
 794
 795        if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
 796                ret = vfio_lock_acct(dma, 1, true);
 797                if (ret) {
 798                        put_pfn(*pfn_base, dma->prot);
 799                        if (ret == -ENOMEM)
 800                                pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
 801                                        "(%ld) exceeded\n", __func__,
 802                                        dma->task->comm, task_pid_nr(dma->task),
 803                                        task_rlimit(dma->task, RLIMIT_MEMLOCK));
 804                }
 805        }
 806
 807out:
 808        mmput(mm);
 809        return ret;
 810}
 811
 812static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
 813                                    bool do_accounting)
 814{
 815        int unlocked;
 816        struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
 817
 818        if (!vpfn)
 819                return 0;
 820
 821        unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
 822
 823        if (do_accounting)
 824                vfio_lock_acct(dma, -unlocked, true);
 825
 826        return unlocked;
 827}
 828
 829static int vfio_iommu_type1_pin_pages(void *iommu_data,
 830                                      struct iommu_group *iommu_group,
 831                                      unsigned long *user_pfn,
 832                                      int npage, int prot,
 833                                      unsigned long *phys_pfn)
 834{
 835        struct vfio_iommu *iommu = iommu_data;
 836        struct vfio_iommu_group *group;
 837        int i, j, ret;
 838        unsigned long remote_vaddr;
 839        struct vfio_dma *dma;
 840        bool do_accounting;
 841        dma_addr_t iova;
 842
 843        if (!iommu || !user_pfn || !phys_pfn)
 844                return -EINVAL;
 845
 846        /* Supported for v2 version only */
 847        if (!iommu->v2)
 848                return -EACCES;
 849
 850        mutex_lock(&iommu->lock);
 851
 852        /*
 853         * Wait for all necessary vaddr's to be valid so they can be used in
 854         * the main loop without dropping the lock, to avoid racing vs unmap.
 855         */
 856again:
 857        if (iommu->vaddr_invalid_count) {
 858                for (i = 0; i < npage; i++) {
 859                        iova = user_pfn[i] << PAGE_SHIFT;
 860                        ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
 861                        if (ret < 0)
 862                                goto pin_done;
 863                        if (ret == WAITED)
 864                                goto again;
 865                }
 866        }
 867
 868        /* Fail if notifier list is empty */
 869        if (!iommu->notifier.head) {
 870                ret = -EINVAL;
 871                goto pin_done;
 872        }
 873
 874        /*
  875         * If an IOMMU-capable domain exists in the container then all pages
  876         * are already pinned and accounted for.  Accounting is only needed
  877         * when there is no IOMMU-capable domain in the container.
 878         */
 879        do_accounting = list_empty(&iommu->domain_list);
 880
 881        for (i = 0; i < npage; i++) {
 882                struct vfio_pfn *vpfn;
 883
 884                iova = user_pfn[i] << PAGE_SHIFT;
 885                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 886                if (!dma) {
 887                        ret = -EINVAL;
 888                        goto pin_unwind;
 889                }
 890
 891                if ((dma->prot & prot) != prot) {
 892                        ret = -EPERM;
 893                        goto pin_unwind;
 894                }
 895
 896                vpfn = vfio_iova_get_vfio_pfn(dma, iova);
 897                if (vpfn) {
 898                        phys_pfn[i] = vpfn->pfn;
 899                        continue;
 900                }
 901
 902                remote_vaddr = dma->vaddr + (iova - dma->iova);
 903                ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
 904                                             do_accounting);
 905                if (ret)
 906                        goto pin_unwind;
 907
 908                ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
 909                if (ret) {
 910                        if (put_pfn(phys_pfn[i], dma->prot) && do_accounting)
 911                                vfio_lock_acct(dma, -1, true);
 912                        goto pin_unwind;
 913                }
 914
 915                if (iommu->dirty_page_tracking) {
 916                        unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
 917
 918                        /*
 919                         * Bitmap populated with the smallest supported page
 920                         * size
 921                         */
 922                        bitmap_set(dma->bitmap,
 923                                   (iova - dma->iova) >> pgshift, 1);
 924                }
 925        }
 926        ret = i;
 927
 928        group = vfio_iommu_find_iommu_group(iommu, iommu_group);
 929        if (!group->pinned_page_dirty_scope) {
 930                group->pinned_page_dirty_scope = true;
 931                iommu->num_non_pinned_groups--;
 932        }
 933
 934        goto pin_done;
 935
 936pin_unwind:
 937        phys_pfn[i] = 0;
 938        for (j = 0; j < i; j++) {
 939                dma_addr_t iova;
 940
 941                iova = user_pfn[j] << PAGE_SHIFT;
 942                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 943                vfio_unpin_page_external(dma, iova, do_accounting);
 944                phys_pfn[j] = 0;
 945        }
 946pin_done:
 947        mutex_unlock(&iommu->lock);
 948        return ret;
 949}
 950
 951static int vfio_iommu_type1_unpin_pages(void *iommu_data,
 952                                        unsigned long *user_pfn,
 953                                        int npage)
 954{
 955        struct vfio_iommu *iommu = iommu_data;
 956        bool do_accounting;
 957        int i;
 958
 959        if (!iommu || !user_pfn || npage <= 0)
 960                return -EINVAL;
 961
 962        /* Supported for v2 version only */
 963        if (!iommu->v2)
 964                return -EACCES;
 965
 966        mutex_lock(&iommu->lock);
 967
 968        do_accounting = list_empty(&iommu->domain_list);
 969        for (i = 0; i < npage; i++) {
 970                struct vfio_dma *dma;
 971                dma_addr_t iova;
 972
 973                iova = user_pfn[i] << PAGE_SHIFT;
 974                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 975                if (!dma)
 976                        break;
 977
 978                vfio_unpin_page_external(dma, iova, do_accounting);
 979        }
 980
 981        mutex_unlock(&iommu->lock);
 982        return i > 0 ? i : -EINVAL;
 983}
 984
 985static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
 986                            struct list_head *regions,
 987                            struct iommu_iotlb_gather *iotlb_gather)
 988{
 989        long unlocked = 0;
 990        struct vfio_regions *entry, *next;
 991
 992        iommu_iotlb_sync(domain->domain, iotlb_gather);
 993
 994        list_for_each_entry_safe(entry, next, regions, list) {
 995                unlocked += vfio_unpin_pages_remote(dma,
 996                                                    entry->iova,
 997                                                    entry->phys >> PAGE_SHIFT,
 998                                                    entry->len >> PAGE_SHIFT,
 999                                                    false);
1000                list_del(&entry->list);
1001                kfree(entry);
1002        }
1003
1004        cond_resched();
1005
1006        return unlocked;
1007}
1008
1009/*
1010 * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
 1011 * Therefore, when using the IOTLB flush sync interface, VFIO needs to keep
 1012 * track of these regions (currently using a list).
 1013 *
 1014 * This value specifies the maximum number of regions for each IOTLB flush sync.
1015 */
1016#define VFIO_IOMMU_TLB_SYNC_MAX         512
1017
1018static size_t unmap_unpin_fast(struct vfio_domain *domain,
1019                               struct vfio_dma *dma, dma_addr_t *iova,
1020                               size_t len, phys_addr_t phys, long *unlocked,
1021                               struct list_head *unmapped_list,
1022                               int *unmapped_cnt,
1023                               struct iommu_iotlb_gather *iotlb_gather)
1024{
1025        size_t unmapped = 0;
1026        struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
1027
1028        if (entry) {
1029                unmapped = iommu_unmap_fast(domain->domain, *iova, len,
1030                                            iotlb_gather);
1031
1032                if (!unmapped) {
1033                        kfree(entry);
1034                } else {
1035                        entry->iova = *iova;
1036                        entry->phys = phys;
1037                        entry->len  = unmapped;
1038                        list_add_tail(&entry->list, unmapped_list);
1039
1040                        *iova += unmapped;
1041                        (*unmapped_cnt)++;
1042                }
1043        }
1044
1045        /*
1046         * Sync if the number of fast-unmap regions hits the limit
1047         * or in case of errors.
1048         */
1049        if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
1050                *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
1051                                             iotlb_gather);
1052                *unmapped_cnt = 0;
1053        }
1054
1055        return unmapped;
1056}
1057
1058static size_t unmap_unpin_slow(struct vfio_domain *domain,
1059                               struct vfio_dma *dma, dma_addr_t *iova,
1060                               size_t len, phys_addr_t phys,
1061                               long *unlocked)
1062{
1063        size_t unmapped = iommu_unmap(domain->domain, *iova, len);
1064
1065        if (unmapped) {
1066                *unlocked += vfio_unpin_pages_remote(dma, *iova,
1067                                                     phys >> PAGE_SHIFT,
1068                                                     unmapped >> PAGE_SHIFT,
1069                                                     false);
1070                *iova += unmapped;
1071                cond_resched();
1072        }
1073        return unmapped;
1074}
1075
1076static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
1077                             bool do_accounting)
1078{
1079        dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
1080        struct vfio_domain *domain, *d;
1081        LIST_HEAD(unmapped_region_list);
1082        struct iommu_iotlb_gather iotlb_gather;
1083        int unmapped_region_cnt = 0;
1084        long unlocked = 0;
1085
1086        if (!dma->size)
1087                return 0;
1088
1089        if (list_empty(&iommu->domain_list))
1090                return 0;
1091
1092        /*
1093         * We use the IOMMU to track the physical addresses, otherwise we'd
1094         * need a much more complicated tracking system.  Unfortunately that
1095         * means we need to use one of the iommu domains to figure out the
1096         * pfns to unpin.  The rest need to be unmapped in advance so we have
1097         * no iommu translations remaining when the pages are unpinned.
1098         */
1099        domain = d = list_first_entry(&iommu->domain_list,
1100                                      struct vfio_domain, next);
1101
1102        list_for_each_entry_continue(d, &iommu->domain_list, next) {
1103                iommu_unmap(d->domain, dma->iova, dma->size);
1104                cond_resched();
1105        }
1106
1107        iommu_iotlb_gather_init(&iotlb_gather);
1108        while (iova < end) {
1109                size_t unmapped, len;
1110                phys_addr_t phys, next;
1111
1112                phys = iommu_iova_to_phys(domain->domain, iova);
1113                if (WARN_ON(!phys)) {
1114                        iova += PAGE_SIZE;
1115                        continue;
1116                }
1117
1118                /*
1119                 * To optimize for fewer iommu_unmap() calls, each of which
1120                 * may require hardware cache flushing, try to find the
1121                 * largest contiguous physical memory chunk to unmap.
1122                 */
1123                for (len = PAGE_SIZE;
1124                     !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
1125                        next = iommu_iova_to_phys(domain->domain, iova + len);
1126                        if (next != phys + len)
1127                                break;
1128                }
1129
1130                /*
1131                 * First, try to use fast unmap/unpin. In case of failure,
1132                 * switch to slow unmap/unpin path.
1133                 */
1134                unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
1135                                            &unlocked, &unmapped_region_list,
1136                                            &unmapped_region_cnt,
1137                                            &iotlb_gather);
1138                if (!unmapped) {
1139                        unmapped = unmap_unpin_slow(domain, dma, &iova, len,
1140                                                    phys, &unlocked);
1141                        if (WARN_ON(!unmapped))
1142                                break;
1143                }
1144        }
1145
1146        dma->iommu_mapped = false;
1147
1148        if (unmapped_region_cnt) {
1149                unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
1150                                            &iotlb_gather);
1151        }
1152
1153        if (do_accounting) {
1154                vfio_lock_acct(dma, -unlocked, true);
1155                return 0;
1156        }
1157        return unlocked;
1158}
1159
1160static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
1161{
1162        WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
1163        vfio_unmap_unpin(iommu, dma, true);
1164        vfio_unlink_dma(iommu, dma);
1165        put_task_struct(dma->task);
1166        vfio_dma_bitmap_free(dma);
1167        if (dma->vaddr_invalid) {
1168                iommu->vaddr_invalid_count--;
1169                wake_up_all(&iommu->vaddr_wait);
1170        }
1171        kfree(dma);
1172        iommu->dma_avail++;
1173}
1174
1175static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
1176{
1177        struct vfio_domain *domain;
1178
1179        iommu->pgsize_bitmap = ULONG_MAX;
1180
1181        list_for_each_entry(domain, &iommu->domain_list, next)
1182                iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
1183
1184        /*
1185         * In case the IOMMU supports page sizes smaller than PAGE_SIZE
1186         * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
1187         * That way the user will be able to map/unmap buffers whose size/
 1188         * start address is aligned with PAGE_SIZE.  The pinning code uses
 1189         * that granularity, while the IOMMU driver may use the sub-PAGE_SIZE
 1190         * size to map the buffer.
1191         */
1192        if (iommu->pgsize_bitmap & ~PAGE_MASK) {
1193                iommu->pgsize_bitmap &= PAGE_MASK;
1194                iommu->pgsize_bitmap |= PAGE_SIZE;
1195        }
1196}
1197
1198static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1199                              struct vfio_dma *dma, dma_addr_t base_iova,
1200                              size_t pgsize)
1201{
1202        unsigned long pgshift = __ffs(pgsize);
1203        unsigned long nbits = dma->size >> pgshift;
1204        unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1205        unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1206        unsigned long shift = bit_offset % BITS_PER_LONG;
1207        unsigned long leftover;
1208
1209        /*
1210         * mark all pages dirty if any IOMMU capable device is not able
1211         * to report dirty pages and all pages are pinned and mapped.
1212         */
1213        if (iommu->num_non_pinned_groups && dma->iommu_mapped)
1214                bitmap_set(dma->bitmap, 0, nbits);
1215
1216        if (shift) {
1217                bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1218                                  nbits + shift);
1219
1220                if (copy_from_user(&leftover,
1221                                   (void __user *)(bitmap + copy_offset),
1222                                   sizeof(leftover)))
1223                        return -EFAULT;
1224
1225                bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1226        }
1227
1228        if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1229                         DIRTY_BITMAP_BYTES(nbits + shift)))
1230                return -EFAULT;
1231
1232        return 0;
1233}
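/*
 * Worked example for the copy above (illustrative): with pgsize = 4K,
 * base_iova = 0 and a vfio_dma at iova 0x5000 of size 0x3000, nbits = 3 and
 * bit_offset = 5, so copy_offset = 0 and shift = 5.  The local bitmap is
 * shifted left by five bits, OR'd with the low five bits of the user's
 * first word, and copied back, leaving the three dirty bits at user bitmap
 * positions 5-7.
 */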
1234
1235static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1236                                  dma_addr_t iova, size_t size, size_t pgsize)
1237{
1238        struct vfio_dma *dma;
1239        struct rb_node *n;
1240        unsigned long pgshift = __ffs(pgsize);
1241        int ret;
1242
1243        /*
1244         * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
 1245         * vfio_dma mappings may be combined by specifying large ranges, but
1246         * there must not be any previous mappings bisected by the range.
1247         * An error will be returned if these conditions are not met.
1248         */
1249        dma = vfio_find_dma(iommu, iova, 1);
1250        if (dma && dma->iova != iova)
1251                return -EINVAL;
1252
1253        dma = vfio_find_dma(iommu, iova + size - 1, 0);
1254        if (dma && dma->iova + dma->size != iova + size)
1255                return -EINVAL;
1256
1257        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1258                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1259
1260                if (dma->iova < iova)
1261                        continue;
1262
1263                if (dma->iova > iova + size - 1)
1264                        break;
1265
1266                ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1267                if (ret)
1268                        return ret;
1269
1270                /*
1271                 * Re-populate bitmap to include all pinned pages which are
1272                 * considered as dirty but exclude pages which are unpinned and
1273                 * pages which are marked dirty by vfio_dma_rw()
1274                 */
1275                bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1276                vfio_dma_populate_bitmap(dma, pgsize);
1277        }
1278        return 0;
1279}
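/*
 * Dirty tracking is driven from userspace through VFIO_IOMMU_DIRTY_PAGES.
 * For instance, tracking is enabled with (illustrative sketch; the container
 * fd is assumed):
 *
 *	struct vfio_iommu_type1_dirty_bitmap dirty = {
 *		.argsz = sizeof(dirty),
 *		.flags = VFIO_IOMMU_DIRTY_PAGES_FLAG_START,
 *	};
 *
 *	ioctl(container_fd, VFIO_IOMMU_DIRTY_PAGES, &dirty);
 *
 * The per-range bitmap is later retrieved with the
 * VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP flag plus an appended
 * struct vfio_iommu_type1_dirty_bitmap_get.
 */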
1280
1281static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1282{
1283        if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1284            (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1285                return -EINVAL;
1286
1287        return 0;
1288}
1289
1290static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1291                             struct vfio_iommu_type1_dma_unmap *unmap,
1292                             struct vfio_bitmap *bitmap)
1293{
1294        struct vfio_dma *dma, *dma_last = NULL;
1295        size_t unmapped = 0, pgsize;
1296        int ret = -EINVAL, retries = 0;
1297        unsigned long pgshift;
1298        dma_addr_t iova = unmap->iova;
1299        u64 size = unmap->size;
1300        bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
1301        bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
1302        struct rb_node *n, *first_n;
1303
1304        mutex_lock(&iommu->lock);
1305
1306        pgshift = __ffs(iommu->pgsize_bitmap);
1307        pgsize = (size_t)1 << pgshift;
1308
1309        if (iova & (pgsize - 1))
1310                goto unlock;
1311
1312        if (unmap_all) {
1313                if (iova || size)
1314                        goto unlock;
1315                size = U64_MAX;
1316        } else if (!size || size & (pgsize - 1) ||
1317                   iova + size - 1 < iova || size > SIZE_MAX) {
1318                goto unlock;
1319        }
1320
1321        /* When dirty tracking is enabled, allow only min supported pgsize */
1322        if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1323            (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1324                goto unlock;
1325        }
1326
1327        WARN_ON((pgsize - 1) & PAGE_MASK);
1328again:
1329        /*
1330         * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1331         * avoid tracking individual mappings.  This means that the granularity
1332         * of the original mapping was lost and the user was allowed to attempt
1333         * to unmap any range.  Depending on the contiguousness of physical
1334         * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1335         * or may not have worked.  We only guaranteed unmap granularity
1336         * matching the original mapping; even though it was untracked here,
1337         * the original mappings are reflected in IOMMU mappings.  This
1338         * resulted in a couple unusual behaviors.  First, if a range is not
1339         * able to be unmapped, ex. a set of 4k pages that was mapped as a
1340         * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1341         * a zero sized unmap.  Also, if an unmap request overlaps the first
1342         * address of a hugepage, the IOMMU will unmap the entire hugepage.
1343         * This also returns success and the returned unmap size reflects the
1344         * actual size unmapped.
1345         *
1346         * We attempt to maintain compatibility with this "v1" interface, but
1347         * we take control out of the hands of the IOMMU.  Therefore, an unmap
1348         * request offset from the beginning of the original mapping will
1349         * return success with zero sized unmap.  And an unmap request covering
1350         * the first iova of mapping will unmap the entire range.
1351         *
1352         * The v2 version of this interface intends to be more deterministic.
1353         * Unmap requests must fully cover previous mappings.  Multiple
 1354         * mappings may still be unmapped by specifying large ranges, but there
1355         * must not be any previous mappings bisected by the range.  An error
1356         * will be returned if these conditions are not met.  The v2 interface
1357         * will only return success and a size of zero if there were no
1358         * mappings within the range.
1359         */
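        /*
         * Illustrative v2 usage from userspace (a sketch; the fd and
         * addresses are assumptions): an unmap must cover whole prior
         * mappings, e.g. to undo a 0x200000-byte mapping at iova 0x100000:
         *
         *	struct vfio_iommu_type1_dma_unmap unmap = {
         *		.argsz = sizeof(unmap),
         *		.iova  = 0x100000,
         *		.size  = 0x200000,
         *	};
         *	ioctl(container_fd, VFIO_IOMMU_UNMAP_DMA, &unmap);
         *
         * Unmapping only half of that range would bisect the mapping and
         * fail with -EINVAL here.
         */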
1360        if (iommu->v2 && !unmap_all) {
1361                dma = vfio_find_dma(iommu, iova, 1);
1362                if (dma && dma->iova != iova)
1363                        goto unlock;
1364
1365                dma = vfio_find_dma(iommu, iova + size - 1, 0);
1366                if (dma && dma->iova + dma->size != iova + size)
1367                        goto unlock;
1368        }
1369
1370        ret = 0;
1371        n = first_n = vfio_find_dma_first_node(iommu, iova, size);
1372
1373        while (n) {
1374                dma = rb_entry(n, struct vfio_dma, node);
1375                if (dma->iova >= iova + size)
1376                        break;
1377
1378                if (!iommu->v2 && iova > dma->iova)
1379                        break;
1380                /*
 1381                 * Only a task with the same address space as the one that
 1382                 * mapped this iova range is allowed to unmap it.
1383                 */
1384                if (dma->task->mm != current->mm)
1385                        break;
1386
1387                if (invalidate_vaddr) {
1388                        if (dma->vaddr_invalid) {
1389                                struct rb_node *last_n = n;
1390
1391                                for (n = first_n; n != last_n; n = rb_next(n)) {
1392                                        dma = rb_entry(n,
1393                                                       struct vfio_dma, node);
1394                                        dma->vaddr_invalid = false;
1395                                        iommu->vaddr_invalid_count--;
1396                                }
1397                                ret = -EINVAL;
1398                                unmapped = 0;
1399                                break;
1400                        }
1401                        dma->vaddr_invalid = true;
1402                        iommu->vaddr_invalid_count++;
1403                        unmapped += dma->size;
1404                        n = rb_next(n);
1405                        continue;
1406                }
1407
1408                if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1409                        struct vfio_iommu_type1_dma_unmap nb_unmap;
1410
1411                        if (dma_last == dma) {
1412                                BUG_ON(++retries > 10);
1413                        } else {
1414                                dma_last = dma;
1415                                retries = 0;
1416                        }
1417
1418                        nb_unmap.iova = dma->iova;
1419                        nb_unmap.size = dma->size;
1420
1421                        /*
1422                         * Notify anyone (mdev vendor drivers) to invalidate and
1423                         * unmap iovas within the range we're about to unmap.
1424                         * Vendor drivers MUST unpin pages in response to an
1425                         * invalidation.
1426                         */
1427                        mutex_unlock(&iommu->lock);
1428                        blocking_notifier_call_chain(&iommu->notifier,
1429                                                    VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1430                                                    &nb_unmap);
1431                        mutex_lock(&iommu->lock);
1432                        goto again;
1433                }
1434
1435                if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1436                        ret = update_user_bitmap(bitmap->data, iommu, dma,
1437                                                 iova, pgsize);
1438                        if (ret)
1439                                break;
1440                }
1441
1442                unmapped += dma->size;
1443                n = rb_next(n);
1444                vfio_remove_dma(iommu, dma);
1445        }
1446
1447unlock:
1448        mutex_unlock(&iommu->lock);
1449
1450        /* Report how much was unmapped */
1451        unmap->size = unmapped;
1452
1453        return ret;
1454}
1455
1456static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1457                          unsigned long pfn, long npage, int prot)
1458{
1459        struct vfio_domain *d;
1460        int ret;
1461
1462        list_for_each_entry(d, &iommu->domain_list, next) {
1463                ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1464                                npage << PAGE_SHIFT, prot | d->prot);
1465                if (ret)
1466                        goto unwind;
1467
1468                cond_resched();
1469        }
1470
1471        return 0;
1472
1473unwind:
1474        list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1475                iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1476                cond_resched();
1477        }
1478
1479        return ret;
1480}
1481
1482static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1483                            size_t map_size)
1484{
1485        dma_addr_t iova = dma->iova;
1486        unsigned long vaddr = dma->vaddr;
1487        struct vfio_batch batch;
1488        size_t size = map_size;
1489        long npage;
1490        unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1491        int ret = 0;
1492
1493        vfio_batch_init(&batch);
1494
1495        while (size) {
1496                /* Pin a contiguous chunk of memory */
1497                npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1498                                              size >> PAGE_SHIFT, &pfn, limit,
1499                                              &batch);
1500                if (npage <= 0) {
1501                        WARN_ON(!npage);
1502                        ret = (int)npage;
1503                        break;
1504                }
1505
1506                /* Map it! */
1507                ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1508                                     dma->prot);
1509                if (ret) {
1510                        vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1511                                                npage, true);
1512                        vfio_batch_unpin(&batch, dma);
1513                        break;
1514                }
1515
1516                size -= npage << PAGE_SHIFT;
1517                dma->size += npage << PAGE_SHIFT;
1518        }
1519
1520        vfio_batch_fini(&batch);
1521        dma->iommu_mapped = true;
1522
1523        if (ret)
1524                vfio_remove_dma(iommu, dma);
1525
1526        return ret;
1527}
1528
1529/*
1530 * Check that the dma map request falls within a valid iova range
1531 */
1532static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1533                                      dma_addr_t start, dma_addr_t end)
1534{
1535        struct list_head *iova = &iommu->iova_list;
1536        struct vfio_iova *node;
1537
1538        list_for_each_entry(node, iova, list) {
1539                if (start >= node->start && end <= node->end)
1540                        return true;
1541        }
1542
1543        /*
1544         * Check for list_empty() as well since a container with
1545         * a single mdev device will have an empty list.
1546         */
1547        return list_empty(iova);
1548}
1549
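/*
 * Handle VFIO_IOMMU_MAP_DMA.  With VFIO_DMA_MAP_FLAG_VADDR this only
 * re-validates the vaddr of an existing mapping that was previously
 * invalidated; otherwise it validates the request, allocates a vfio_dma
 * and pins/maps it unless the container has no IOMMU-backed domain.
 */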
1550static int vfio_dma_do_map(struct vfio_iommu *iommu,
1551                           struct vfio_iommu_type1_dma_map *map)
1552{
1553        bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
1554        dma_addr_t iova = map->iova;
1555        unsigned long vaddr = map->vaddr;
1556        size_t size = map->size;
1557        int ret = 0, prot = 0;
1558        size_t pgsize;
1559        struct vfio_dma *dma;
1560
1561        /* Verify that none of our __u64 fields overflow */
1562        if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1563                return -EINVAL;
1564
1565        /* READ/WRITE from device perspective */
1566        if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1567                prot |= IOMMU_WRITE;
1568        if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1569                prot |= IOMMU_READ;
1570
1571        if ((prot && set_vaddr) || (!prot && !set_vaddr))
1572                return -EINVAL;
1573
1574        mutex_lock(&iommu->lock);
1575
1576        pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1577
1578        WARN_ON((pgsize - 1) & PAGE_MASK);
1579
1580        if (!size || (size | iova | vaddr) & (pgsize - 1)) {
1581                ret = -EINVAL;
1582                goto out_unlock;
1583        }
1584
1585        /* Don't allow IOVA or virtual address wrap */
1586        if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1587                ret = -EINVAL;
1588                goto out_unlock;
1589        }
1590
1591        dma = vfio_find_dma(iommu, iova, size);
1592        if (set_vaddr) {
1593                if (!dma) {
1594                        ret = -ENOENT;
1595                } else if (!dma->vaddr_invalid || dma->iova != iova ||
1596                           dma->size != size) {
1597                        ret = -EINVAL;
1598                } else {
1599                        dma->vaddr = vaddr;
1600                        dma->vaddr_invalid = false;
1601                        iommu->vaddr_invalid_count--;
1602                        wake_up_all(&iommu->vaddr_wait);
1603                }
1604                goto out_unlock;
1605        } else if (dma) {
1606                ret = -EEXIST;
1607                goto out_unlock;
1608        }
1609
1610        if (!iommu->dma_avail) {
1611                ret = -ENOSPC;
1612                goto out_unlock;
1613        }
1614
1615        if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1616                ret = -EINVAL;
1617                goto out_unlock;
1618        }
1619
1620        dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1621        if (!dma) {
1622                ret = -ENOMEM;
1623                goto out_unlock;
1624        }
1625
1626        iommu->dma_avail--;
1627        dma->iova = iova;
1628        dma->vaddr = vaddr;
1629        dma->prot = prot;
1630
1631        /*
1632         * We need to be able to both add to a task's locked memory and test
1633         * against the locked memory limit and we need to be able to do both
1634         * outside of this call path as pinning can be asynchronous via the
1635         * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1636         * task_struct and VM locked pages requires an mm_struct, however
1637         * holding an indefinite mm reference is not recommended, therefore we
1638         * only hold a reference to a task.  We could hold a reference to
1639         * current, however QEMU uses this call path through vCPU threads,
1640         * which can be killed resulting in a NULL mm and failure in the unmap
1641         * path when called via a different thread.  Avoid this problem by
1642         * using the group_leader as threads within the same group require
1643         * both CLONE_THREAD and CLONE_VM and will therefore use the same
1644         * mm_struct.
1645         *
1646         * Previously we also used the task for testing CAP_IPC_LOCK at the
1647         * time of pinning and accounting, however has_capability() makes use
1648         * of real_cred, a copy-on-write field, so we can't guarantee that it
1649         * matches group_leader, or in fact that it might not change by the
1650         * time it's evaluated.  If a process were to call MAP_DMA with
1651         * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1652         * possibly see different results for an iommu_mapped vfio_dma vs
1653         * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1654         * time of calling MAP_DMA.
1655         */
1656        get_task_struct(current->group_leader);
1657        dma->task = current->group_leader;
1658        dma->lock_cap = capable(CAP_IPC_LOCK);
1659
1660        dma->pfn_list = RB_ROOT;
1661
1662        /* Insert zero-sized and grow as we map chunks of it */
1663        vfio_link_dma(iommu, dma);
1664
1665        /* Don't pin and map if the container doesn't contain an IOMMU-capable domain */
1666        if (list_empty(&iommu->domain_list))
1667                dma->size = size;
1668        else
1669                ret = vfio_pin_map_dma(iommu, dma, size);
1670
1671        if (!ret && iommu->dirty_page_tracking) {
1672                ret = vfio_dma_bitmap_alloc(dma, pgsize);
1673                if (ret)
1674                        vfio_remove_dma(iommu, dma);
1675        }
1676
1677out_unlock:
1678        mutex_unlock(&iommu->lock);
1679        return ret;
1680}
1681
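/*
 * iommu_group_for_each_dev() callback: record the bus_type of the group's
 * devices and fail if they do not all share the same one.
 */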
1682static int vfio_bus_type(struct device *dev, void *data)
1683{
1684        struct bus_type **bus = data;
1685
1686        if (*bus && *bus != dev->bus)
1687                return -EINVAL;
1688
1689        *bus = dev->bus;
1690
1691        return 0;
1692}
1693
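/*
 * Replay all existing DMA mappings into a newly added domain.  Ranges that
 * are already iommu_mapped are walked with iova-to-phys lookups against an
 * existing domain; vaddr-only ranges are pinned here.  On failure everything
 * mapped into the new domain is unwound.
 */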
1694static int vfio_iommu_replay(struct vfio_iommu *iommu,
1695                             struct vfio_domain *domain)
1696{
1697        struct vfio_batch batch;
1698        struct vfio_domain *d = NULL;
1699        struct rb_node *n;
1700        unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1701        int ret;
1702
1703        ret = vfio_wait_all_valid(iommu);
1704        if (ret < 0)
1705                return ret;
1706
1707        /* Arbitrarily pick the first domain in the list for lookups */
1708        if (!list_empty(&iommu->domain_list))
1709                d = list_first_entry(&iommu->domain_list,
1710                                     struct vfio_domain, next);
1711
1712        vfio_batch_init(&batch);
1713
1714        n = rb_first(&iommu->dma_list);
1715
1716        for (; n; n = rb_next(n)) {
1717                struct vfio_dma *dma;
1718                dma_addr_t iova;
1719
1720                dma = rb_entry(n, struct vfio_dma, node);
1721                iova = dma->iova;
1722
1723                while (iova < dma->iova + dma->size) {
1724                        phys_addr_t phys;
1725                        size_t size;
1726
1727                        if (dma->iommu_mapped) {
1728                                phys_addr_t p;
1729                                dma_addr_t i;
1730
1731                                if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1732                                        ret = -EINVAL;
1733                                        goto unwind;
1734                                }
1735
1736                                phys = iommu_iova_to_phys(d->domain, iova);
1737
1738                                if (WARN_ON(!phys)) {
1739                                        iova += PAGE_SIZE;
1740                                        continue;
1741                                }
1742
1743                                size = PAGE_SIZE;
1744                                p = phys + size;
1745                                i = iova + size;
1746                                while (i < dma->iova + dma->size &&
1747                                       p == iommu_iova_to_phys(d->domain, i)) {
1748                                        size += PAGE_SIZE;
1749                                        p += PAGE_SIZE;
1750                                        i += PAGE_SIZE;
1751                                }
1752                        } else {
1753                                unsigned long pfn;
1754                                unsigned long vaddr = dma->vaddr +
1755                                                     (iova - dma->iova);
1756                                size_t n = dma->iova + dma->size - iova;
1757                                long npage;
1758
1759                                npage = vfio_pin_pages_remote(dma, vaddr,
1760                                                              n >> PAGE_SHIFT,
1761                                                              &pfn, limit,
1762                                                              &batch);
1763                                if (npage <= 0) {
1764                                        WARN_ON(!npage);
1765                                        ret = (int)npage;
1766                                        goto unwind;
1767                                }
1768
1769                                phys = pfn << PAGE_SHIFT;
1770                                size = npage << PAGE_SHIFT;
1771                        }
1772
1773                        ret = iommu_map(domain->domain, iova, phys,
1774                                        size, dma->prot | domain->prot);
1775                        if (ret) {
1776                                if (!dma->iommu_mapped) {
1777                                        vfio_unpin_pages_remote(dma, iova,
1778                                                        phys >> PAGE_SHIFT,
1779                                                        size >> PAGE_SHIFT,
1780                                                        true);
1781                                        vfio_batch_unpin(&batch, dma);
1782                                }
1783                                goto unwind;
1784                        }
1785
1786                        iova += size;
1787                }
1788        }
1789
1790        /* All dmas are now mapped, defer to second tree walk for unwind */
1791        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1792                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1793
1794                dma->iommu_mapped = true;
1795        }
1796
1797        vfio_batch_fini(&batch);
1798        return 0;
1799
1800unwind:
1801        for (; n; n = rb_prev(n)) {
1802                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1803                dma_addr_t iova;
1804
1805                if (dma->iommu_mapped) {
1806                        iommu_unmap(domain->domain, dma->iova, dma->size);
1807                        continue;
1808                }
1809
1810                iova = dma->iova;
1811                while (iova < dma->iova + dma->size) {
1812                        phys_addr_t phys, p;
1813                        size_t size;
1814                        dma_addr_t i;
1815
1816                        phys = iommu_iova_to_phys(domain->domain, iova);
1817                        if (!phys) {
1818                                iova += PAGE_SIZE;
1819                                continue;
1820                        }
1821
1822                        size = PAGE_SIZE;
1823                        p = phys + size;
1824                        i = iova + size;
1825                        while (i < dma->iova + dma->size &&
1826                               p == iommu_iova_to_phys(domain->domain, i)) {
1827                                size += PAGE_SIZE;
1828                                p += PAGE_SIZE;
1829                                i += PAGE_SIZE;
1830                        }
1831
1832                        iommu_unmap(domain->domain, iova, size);
1833                        vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1834                                                size >> PAGE_SHIFT, true);
1835                }
1836        }
1837
1838        vfio_batch_fini(&batch);
1839        return ret;
1840}
1841
1842/*
1843 * We change our unmap behavior slightly depending on whether the IOMMU
1844 * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1845 * for practically any contiguous power-of-two mapping we give it.  This means
1846 * we don't need to look for contiguous chunks ourselves to make unmapping
1847 * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1848 * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1849 * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1850 * hugetlbfs is in use.
1851 */
1852static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1853{
1854        struct page *pages;
1855        int ret, order = get_order(PAGE_SIZE * 2);
1856
1857        pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1858        if (!pages)
1859                return;
1860
1861        ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1862                        IOMMU_READ | IOMMU_WRITE | domain->prot);
1863        if (!ret) {
1864                size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1865
1866                if (unmapped == PAGE_SIZE)
1867                        iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1868                else
1869                        domain->fgsp = true;
1870        }
1871
1872        __free_pages(pages, order);
1873}
1874
1875static struct vfio_iommu_group *find_iommu_group(struct vfio_domain *domain,
1876                                                 struct iommu_group *iommu_group)
1877{
1878        struct vfio_iommu_group *g;
1879
1880        list_for_each_entry(g, &domain->group_list, next) {
1881                if (g->iommu_group == iommu_group)
1882                        return g;
1883        }
1884
1885        return NULL;
1886}
1887
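/*
 * Look up an iommu_group across all IOMMU-backed domains as well as the
 * emulated IOMMU group list.
 */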
1888static struct vfio_iommu_group*
1889vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1890                            struct iommu_group *iommu_group)
1891{
1892        struct vfio_iommu_group *group;
1893        struct vfio_domain *domain;
1894
1895        list_for_each_entry(domain, &iommu->domain_list, next) {
1896                group = find_iommu_group(domain, iommu_group);
1897                if (group)
1898                        return group;
1899        }
1900
1901        list_for_each_entry(group, &iommu->emulated_iommu_groups, next)
1902                if (group->iommu_group == iommu_group)
1903                        return group;
1904        return NULL;
1905}
1906
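/*
 * Return true and report its base address if the group advertises a
 * software-managed MSI reserved region and no hardware MSI region.
 */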
1907static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1908                                  phys_addr_t *base)
1909{
1910        struct iommu_resv_region *region;
1911        bool ret = false;
1912
1913        list_for_each_entry(region, group_resv_regions, list) {
1914                /*
1915                 * The presence of any 'real' MSI regions should take
1916                 * precedence over the software-managed one if the
1917                 * IOMMU driver happens to advertise both types.
1918                 */
1919                if (region->type == IOMMU_RESV_MSI) {
1920                        ret = false;
1921                        break;
1922                }
1923
1924                if (region->type == IOMMU_RESV_SW_MSI) {
1925                        *base = region->start;
1926                        ret = true;
1927                }
1928        }
1929
1930        return ret;
1931}
1932
1933/*
1934 * This is a helper function to insert an address range into the iova list.
1935 * The list is initially created with a single entry corresponding to
1936 * the IOMMU domain geometry to which the device group is attached.
1937 * The list aperture gets modified when a new domain is added to the
1938 * container if the new aperture doesn't conflict with the current one
1939 * or with any existing dma mappings. The list is also modified to
1940 * exclude any reserved regions associated with the device group.
1941 */
1942static int vfio_iommu_iova_insert(struct list_head *head,
1943                                  dma_addr_t start, dma_addr_t end)
1944{
1945        struct vfio_iova *region;
1946
1947        region = kmalloc(sizeof(*region), GFP_KERNEL);
1948        if (!region)
1949                return -ENOMEM;
1950
1951        INIT_LIST_HEAD(&region->list);
1952        region->start = start;
1953        region->end = end;
1954
1955        list_add_tail(&region->list, head);
1956        return 0;
1957}
1958
1959/*
1960 * Check whether the new iommu aperture conflicts with the existing
1961 * aperture or with any existing dma mappings.
1962 */
1963static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
1964                                     dma_addr_t start, dma_addr_t end)
1965{
1966        struct vfio_iova *first, *last;
1967        struct list_head *iova = &iommu->iova_list;
1968
1969        if (list_empty(iova))
1970                return false;
1971
1972        /* Disjoint sets, return conflict */
1973        first = list_first_entry(iova, struct vfio_iova, list);
1974        last = list_last_entry(iova, struct vfio_iova, list);
1975        if (start > last->end || end < first->start)
1976                return true;
1977
1978        /* Check for any existing dma mappings below the new start */
1979        if (start > first->start) {
1980                if (vfio_find_dma(iommu, first->start, start - first->start))
1981                        return true;
1982        }
1983
1984        /* Check for any existing dma mappings beyond the new end */
1985        if (end < last->end) {
1986                if (vfio_find_dma(iommu, end + 1, last->end - end))
1987                        return true;
1988        }
1989
1990        return false;
1991}
1992
1993/*
1994 * Resize iommu iova aperture window. This is called only if the new
1995 * aperture has no conflict with existing aperture and dma mappings.
1996 */
1997static int vfio_iommu_aper_resize(struct list_head *iova,
1998                                  dma_addr_t start, dma_addr_t end)
1999{
2000        struct vfio_iova *node, *next;
2001
2002        if (list_empty(iova))
2003                return vfio_iommu_iova_insert(iova, start, end);
2004
2005        /* Adjust iova list start */
2006        list_for_each_entry_safe(node, next, iova, list) {
2007                if (start < node->start)
2008                        break;
2009                if (start >= node->start && start < node->end) {
2010                        node->start = start;
2011                        break;
2012                }
2013                /* Delete nodes before new start */
2014                list_del(&node->list);
2015                kfree(node);
2016        }
2017
2018        /* Adjust iova list end */
2019        list_for_each_entry_safe(node, next, iova, list) {
2020                if (end > node->end)
2021                        continue;
2022                if (end > node->start && end <= node->end) {
2023                        node->end = end;
2024                        continue;
2025                }
2026                /* Delete nodes after new end */
2027                list_del(&node->list);
2028                kfree(node);
2029        }
2030
2031        return 0;
2032}
2033
2034/*
2035 * Check whether any reserved region conflicts with existing dma mappings
2036 */
2037static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
2038                                     struct list_head *resv_regions)
2039{
2040        struct iommu_resv_region *region;
2041
2042        /* Check for conflict with existing dma mappings */
2043        list_for_each_entry(region, resv_regions, list) {
2044                if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
2045                        continue;
2046
2047                if (vfio_find_dma(iommu, region->start, region->length))
2048                        return true;
2049        }
2050
2051        return false;
2052}
2053
2054/*
2055 * Check for iova regions that overlap with reserved regions and
2056 * exclude them from the iommu iova range
2057 */
2058static int vfio_iommu_resv_exclude(struct list_head *iova,
2059                                   struct list_head *resv_regions)
2060{
2061        struct iommu_resv_region *resv;
2062        struct vfio_iova *n, *next;
2063
2064        list_for_each_entry(resv, resv_regions, list) {
2065                phys_addr_t start, end;
2066
2067                if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
2068                        continue;
2069
2070                start = resv->start;
2071                end = resv->start + resv->length - 1;
2072
2073                list_for_each_entry_safe(n, next, iova, list) {
2074                        int ret = 0;
2075
2076                        /* No overlap */
2077                        if (start > n->end || end < n->start)
2078                                continue;
2079                        /*
2080                         * If the current node overlaps the reserved region, insert
2081                         * new node(s) for the non-overlapping part(s) so the reserved
2082                         * range is excluded from the valid iova list.  Note that new
2083                         * nodes are inserted before the current node and the current
2084                         * node is then deleted, keeping the list updated and sorted.
2085                         */
2086                        if (start > n->start)
2087                                ret = vfio_iommu_iova_insert(&n->list, n->start,
2088                                                             start - 1);
2089                        if (!ret && end < n->end)
2090                                ret = vfio_iommu_iova_insert(&n->list, end + 1,
2091                                                             n->end);
2092                        if (ret)
2093                                return ret;
2094
2095                        list_del(&n->list);
2096                        kfree(n);
2097                }
2098        }
2099
2100        if (list_empty(iova))
2101                return -EINVAL;
2102
2103        return 0;
2104}
2105
2106static void vfio_iommu_resv_free(struct list_head *resv_regions)
2107{
2108        struct iommu_resv_region *n, *next;
2109
2110        list_for_each_entry_safe(n, next, resv_regions, list) {
2111                list_del(&n->list);
2112                kfree(n);
2113        }
2114}
2115
2116static void vfio_iommu_iova_free(struct list_head *iova)
2117{
2118        struct vfio_iova *n, *next;
2119
2120        list_for_each_entry_safe(n, next, iova, list) {
2121                list_del(&n->list);
2122                kfree(n);
2123        }
2124}
2125
2126static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
2127                                    struct list_head *iova_copy)
2128{
2129        struct list_head *iova = &iommu->iova_list;
2130        struct vfio_iova *n;
2131        int ret;
2132
2133        list_for_each_entry(n, iova, list) {
2134                ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2135                if (ret)
2136                        goto out_free;
2137        }
2138
2139        return 0;
2140
2141out_free:
2142        vfio_iommu_iova_free(iova_copy);
2143        return ret;
2144}
2145
2146static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2147                                        struct list_head *iova_copy)
2148{
2149        struct list_head *iova = &iommu->iova_list;
2150
2151        vfio_iommu_iova_free(iova);
2152
2153        list_splice_tail(iova_copy, iova);
2154}
2155
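/*
 * Attach an iommu_group to the container.  Emulated IOMMU groups are only
 * tracked on a separate list; IOMMU-backed groups get a domain allocated
 * (or are folded into a compatible existing domain), the valid iova list is
 * recomputed against the new aperture and reserved regions, and existing
 * mappings are replayed into any newly created domain.
 */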
2156static int vfio_iommu_type1_attach_group(void *iommu_data,
2157                struct iommu_group *iommu_group, enum vfio_group_type type)
2158{
2159        struct vfio_iommu *iommu = iommu_data;
2160        struct vfio_iommu_group *group;
2161        struct vfio_domain *domain, *d;
2162        struct bus_type *bus = NULL;
2163        bool resv_msi, msi_remap;
2164        phys_addr_t resv_msi_base = 0;
2165        struct iommu_domain_geometry *geo;
2166        LIST_HEAD(iova_copy);
2167        LIST_HEAD(group_resv_regions);
2168        int ret = -EINVAL;
2169
2170        mutex_lock(&iommu->lock);
2171
2172        /* Check for duplicates */
2173        if (vfio_iommu_find_iommu_group(iommu, iommu_group))
2174                goto out_unlock;
2175
2176        ret = -ENOMEM;
2177        group = kzalloc(sizeof(*group), GFP_KERNEL);
2178        if (!group)
2179                goto out_unlock;
2180        group->iommu_group = iommu_group;
2181
2182        if (type == VFIO_EMULATED_IOMMU) {
2183                list_add(&group->next, &iommu->emulated_iommu_groups);
2184                /*
2185                 * An emulated IOMMU group cannot dirty memory directly; it can
2186                 * only use interfaces that provide dirty tracking.
2187                 * The iommu scope can only be promoted with the addition of a
2188                 * dirty tracking group.
2189                 */
2190                group->pinned_page_dirty_scope = true;
2191                ret = 0;
2192                goto out_unlock;
2193        }
2194
2195        /* Determine bus_type in order to allocate a domain */
2196        ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
2197        if (ret)
2198                goto out_free_group;
2199
2200        ret = -ENOMEM;
2201        domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2202        if (!domain)
2203                goto out_free_group;
2204
2205        ret = -EIO;
2206        domain->domain = iommu_domain_alloc(bus);
2207        if (!domain->domain)
2208                goto out_free_domain;
2209
2210        if (iommu->nesting) {
2211                ret = iommu_enable_nesting(domain->domain);
2212                if (ret)
2213                        goto out_domain;
2214        }
2215
2216        ret = iommu_attach_group(domain->domain, group->iommu_group);
2217        if (ret)
2218                goto out_domain;
2219
2220        /* Get aperture info */
2221        geo = &domain->domain->geometry;
2222        if (vfio_iommu_aper_conflict(iommu, geo->aperture_start,
2223                                     geo->aperture_end)) {
2224                ret = -EINVAL;
2225                goto out_detach;
2226        }
2227
2228        ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2229        if (ret)
2230                goto out_detach;
2231
2232        if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2233                ret = -EINVAL;
2234                goto out_detach;
2235        }
2236
2237        /*
2238         * We don't want to work on the original iova list as the list
2239         * gets modified and in case of failure we have to retain the
2240         * original list. Get a copy here.
2241         */
2242        ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2243        if (ret)
2244                goto out_detach;
2245
2246        ret = vfio_iommu_aper_resize(&iova_copy, geo->aperture_start,
2247                                     geo->aperture_end);
2248        if (ret)
2249                goto out_detach;
2250
2251        ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2252        if (ret)
2253                goto out_detach;
2254
2255        resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2256
2257        INIT_LIST_HEAD(&domain->group_list);
2258        list_add(&group->next, &domain->group_list);
2259
2260        msi_remap = irq_domain_check_msi_remap() ||
2261                    iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2262
2263        if (!allow_unsafe_interrupts && !msi_remap) {
2264                pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2265                       __func__);
2266                ret = -EPERM;
2267                goto out_detach;
2268        }
2269
2270        if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2271                domain->prot |= IOMMU_CACHE;
2272
2273        /*
2274         * Try to match an existing compatible domain.  We don't want to
2275         * preclude an IOMMU driver supporting multiple bus_types and being
2276         * able to include different bus_types in the same IOMMU domain, so
2277         * we test whether the domains use the same iommu_ops rather than
2278         * testing if they're on the same bus_type.
2279         */
2280        list_for_each_entry(d, &iommu->domain_list, next) {
2281                if (d->domain->ops == domain->domain->ops &&
2282                    d->prot == domain->prot) {
2283                        iommu_detach_group(domain->domain, group->iommu_group);
2284                        if (!iommu_attach_group(d->domain,
2285                                                group->iommu_group)) {
2286                                list_add(&group->next, &d->group_list);
2287                                iommu_domain_free(domain->domain);
2288                                kfree(domain);
2289                                goto done;
2290                        }
2291
2292                        ret = iommu_attach_group(domain->domain,
2293                                                 group->iommu_group);
2294                        if (ret)
2295                                goto out_domain;
2296                }
2297        }
2298
2299        vfio_test_domain_fgsp(domain);
2300
2301        /* replay mappings on new domains */
2302        ret = vfio_iommu_replay(iommu, domain);
2303        if (ret)
2304                goto out_detach;
2305
2306        if (resv_msi) {
2307                ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2308                if (ret && ret != -ENODEV)
2309                        goto out_detach;
2310        }
2311
2312        list_add(&domain->next, &iommu->domain_list);
2313        vfio_update_pgsize_bitmap(iommu);
2314done:
2315        /* Delete the old iova list and insert the new one */
2316        vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2317
2318        /*
2319         * An iommu backed group can dirty memory directly and therefore
2320         * demotes the iommu scope until it declares itself dirty tracking
2321         * capable via the page pinning interface.
2322         */
2323        iommu->num_non_pinned_groups++;
2324        mutex_unlock(&iommu->lock);
2325        vfio_iommu_resv_free(&group_resv_regions);
2326
2327        return 0;
2328
2329out_detach:
2330        iommu_detach_group(domain->domain, group->iommu_group);
2331out_domain:
2332        iommu_domain_free(domain->domain);
2333        vfio_iommu_iova_free(&iova_copy);
2334        vfio_iommu_resv_free(&group_resv_regions);
2335out_free_domain:
2336        kfree(domain);
2337out_free_group:
2338        kfree(group);
2339out_unlock:
2340        mutex_unlock(&iommu->lock);
2341        return ret;
2342}
2343
2344static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2345{
2346        struct rb_node *node;
2347
2348        while ((node = rb_first(&iommu->dma_list)))
2349                vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2350}
2351
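/*
 * Called when the last IOMMU-backed domain goes away while emulated IOMMU
 * groups remain: tear down the IOMMU mappings but keep the vfio_dma entries,
 * re-accounting locked pages for pfns still pinned through the external
 * pinning interface.
 */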
2352static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2353{
2354        struct rb_node *n, *p;
2355
2356        n = rb_first(&iommu->dma_list);
2357        for (; n; n = rb_next(n)) {
2358                struct vfio_dma *dma;
2359                long locked = 0, unlocked = 0;
2360
2361                dma = rb_entry(n, struct vfio_dma, node);
2362                unlocked += vfio_unmap_unpin(iommu, dma, false);
2363                p = rb_first(&dma->pfn_list);
2364                for (; p; p = rb_next(p)) {
2365                        struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2366                                                         node);
2367
2368                        if (!is_invalid_reserved_pfn(vpfn->pfn))
2369                                locked++;
2370                }
2371                vfio_lock_acct(dma, locked - unlocked, true);
2372        }
2373}
2374
2375/*
2376 * Called when a domain is removed in detach. It is possible that
2377 * the removed domain determined the iova aperture window. Recompute
2378 * the aperture as the smallest window among the remaining domains.
2379 */
2380static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2381                                   struct list_head *iova_copy)
2382{
2383        struct vfio_domain *domain;
2384        struct vfio_iova *node;
2385        dma_addr_t start = 0;
2386        dma_addr_t end = (dma_addr_t)~0;
2387
2388        if (list_empty(iova_copy))
2389                return;
2390
2391        list_for_each_entry(domain, &iommu->domain_list, next) {
2392                struct iommu_domain_geometry *geo = &domain->domain->geometry;
2393
2394                if (geo->aperture_start > start)
2395                        start = geo->aperture_start;
2396                if (geo->aperture_end < end)
2397                        end = geo->aperture_end;
2398        }
2399
2400        /* Modify aperture limits. The new aperture is either the same or bigger */
2401        node = list_first_entry(iova_copy, struct vfio_iova, list);
2402        node->start = start;
2403        node = list_last_entry(iova_copy, struct vfio_iova, list);
2404        node->end = end;
2405}
2406
2407/*
2408 * Called when a group is detached. The reserved regions for that
2409 * group can be part of valid iova now. But since reserved regions
2410 * may be duplicated among groups, populate the iova valid regions
2411 * list again.
2412 */
2413static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2414                                   struct list_head *iova_copy)
2415{
2416        struct vfio_domain *d;
2417        struct vfio_iommu_group *g;
2418        struct vfio_iova *node;
2419        dma_addr_t start, end;
2420        LIST_HEAD(resv_regions);
2421        int ret;
2422
2423        if (list_empty(iova_copy))
2424                return -EINVAL;
2425
2426        list_for_each_entry(d, &iommu->domain_list, next) {
2427                list_for_each_entry(g, &d->group_list, next) {
2428                        ret = iommu_get_group_resv_regions(g->iommu_group,
2429                                                           &resv_regions);
2430                        if (ret)
2431                                goto done;
2432                }
2433        }
2434
2435        node = list_first_entry(iova_copy, struct vfio_iova, list);
2436        start = node->start;
2437        node = list_last_entry(iova_copy, struct vfio_iova, list);
2438        end = node->end;
2439
2440        /* Purge the iova list and create a new one */
2441        vfio_iommu_iova_free(iova_copy);
2442
2443        ret = vfio_iommu_aper_resize(iova_copy, start, end);
2444        if (ret)
2445                goto done;
2446
2447        /* Exclude current reserved regions from iova ranges */
2448        ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2449done:
2450        vfio_iommu_resv_free(&resv_regions);
2451        return ret;
2452}
2453
2454static void vfio_iommu_type1_detach_group(void *iommu_data,
2455                                          struct iommu_group *iommu_group)
2456{
2457        struct vfio_iommu *iommu = iommu_data;
2458        struct vfio_domain *domain;
2459        struct vfio_iommu_group *group;
2460        bool update_dirty_scope = false;
2461        LIST_HEAD(iova_copy);
2462
2463        mutex_lock(&iommu->lock);
2464        list_for_each_entry(group, &iommu->emulated_iommu_groups, next) {
2465                if (group->iommu_group != iommu_group)
2466                        continue;
2467                update_dirty_scope = !group->pinned_page_dirty_scope;
2468                list_del(&group->next);
2469                kfree(group);
2470
2471                if (list_empty(&iommu->emulated_iommu_groups) &&
2472                    list_empty(&iommu->domain_list)) {
2473                        WARN_ON(iommu->notifier.head);
2474                        vfio_iommu_unmap_unpin_all(iommu);
2475                }
2476                goto detach_group_done;
2477        }
2478
2479        /*
2480         * Get a copy of iova list. This will be used to update
2481         * and to replace the current one later. Please note that
2482         * we will leave the original list as it is if update fails.
2483         */
2484        vfio_iommu_iova_get_copy(iommu, &iova_copy);
2485
2486        list_for_each_entry(domain, &iommu->domain_list, next) {
2487                group = find_iommu_group(domain, iommu_group);
2488                if (!group)
2489                        continue;
2490
2491                iommu_detach_group(domain->domain, group->iommu_group);
2492                update_dirty_scope = !group->pinned_page_dirty_scope;
2493                list_del(&group->next);
2494                kfree(group);
2495                /*
2496                 * Group ownership provides privilege; if the group list is
2497                 * empty, the domain goes away. If it's the last IOMMU-backed
2498                 * domain and no emulated IOMMU groups exist, then all the
2499                 * mappings go away too. If it's the last IOMMU-backed domain
2500                 * and emulated IOMMU groups do exist, update accounting.
2501                 */
2502                if (list_empty(&domain->group_list)) {
2503                        if (list_is_singular(&iommu->domain_list)) {
2504                                if (list_empty(&iommu->emulated_iommu_groups)) {
2505                                        WARN_ON(iommu->notifier.head);
2506                                        vfio_iommu_unmap_unpin_all(iommu);
2507                                } else {
2508                                        vfio_iommu_unmap_unpin_reaccount(iommu);
2509                                }
2510                        }
2511                        iommu_domain_free(domain->domain);
2512                        list_del(&domain->next);
2513                        kfree(domain);
2514                        vfio_iommu_aper_expand(iommu, &iova_copy);
2515                        vfio_update_pgsize_bitmap(iommu);
2516                }
2517                break;
2518        }
2519
2520        if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2521                vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2522        else
2523                vfio_iommu_iova_free(&iova_copy);
2524
2525detach_group_done:
2526        /*
2527         * Removal of a group without dirty tracking may allow the iommu scope
2528         * to be promoted.
2529         */
2530        if (update_dirty_scope) {
2531                iommu->num_non_pinned_groups--;
2532                if (iommu->dirty_page_tracking)
2533                        vfio_iommu_populate_bitmap_full(iommu);
2534        }
2535        mutex_unlock(&iommu->lock);
2536}
2537
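/*
 * Allocate and initialize a container for the requested type1 flavor;
 * VFIO_TYPE1_NESTING_IOMMU implies v2 semantics as well.
 */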
2538static void *vfio_iommu_type1_open(unsigned long arg)
2539{
2540        struct vfio_iommu *iommu;
2541
2542        iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2543        if (!iommu)
2544                return ERR_PTR(-ENOMEM);
2545
2546        switch (arg) {
2547        case VFIO_TYPE1_IOMMU:
2548                break;
2549        case VFIO_TYPE1_NESTING_IOMMU:
2550                iommu->nesting = true;
2551                fallthrough;
2552        case VFIO_TYPE1v2_IOMMU:
2553                iommu->v2 = true;
2554                break;
2555        default:
2556                kfree(iommu);
2557                return ERR_PTR(-EINVAL);
2558        }
2559
2560        INIT_LIST_HEAD(&iommu->domain_list);
2561        INIT_LIST_HEAD(&iommu->iova_list);
2562        iommu->dma_list = RB_ROOT;
2563        iommu->dma_avail = dma_entry_limit;
2564        iommu->container_open = true;
2565        mutex_init(&iommu->lock);
2566        BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2567        init_waitqueue_head(&iommu->vaddr_wait);
2568        iommu->pgsize_bitmap = PAGE_MASK;
2569        INIT_LIST_HEAD(&iommu->emulated_iommu_groups);
2570
2571        return iommu;
2572}
2573
2574static void vfio_release_domain(struct vfio_domain *domain)
2575{
2576        struct vfio_iommu_group *group, *group_tmp;
2577
2578        list_for_each_entry_safe(group, group_tmp,
2579                                 &domain->group_list, next) {
2580                iommu_detach_group(domain->domain, group->iommu_group);
2581                list_del(&group->next);
2582                kfree(group);
2583        }
2584
2585        iommu_domain_free(domain->domain);
2586}
2587
2588static void vfio_iommu_type1_release(void *iommu_data)
2589{
2590        struct vfio_iommu *iommu = iommu_data;
2591        struct vfio_domain *domain, *domain_tmp;
2592        struct vfio_iommu_group *group, *next_group;
2593
2594        list_for_each_entry_safe(group, next_group,
2595                        &iommu->emulated_iommu_groups, next) {
2596                list_del(&group->next);
2597                kfree(group);
2598        }
2599
2600        vfio_iommu_unmap_unpin_all(iommu);
2601
2602        list_for_each_entry_safe(domain, domain_tmp,
2603                                 &iommu->domain_list, next) {
2604                vfio_release_domain(domain);
2605                list_del(&domain->next);
2606                kfree(domain);
2607        }
2608
2609        vfio_iommu_iova_free(&iommu->iova_list);
2610
2611        kfree(iommu);
2612}
2613
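/*
 * Report whether every domain in the container enforces cache coherency;
 * this backs the VFIO_DMA_CC_IOMMU extension check.
 */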
2614static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2615{
2616        struct vfio_domain *domain;
2617        int ret = 1;
2618
2619        mutex_lock(&iommu->lock);
2620        list_for_each_entry(domain, &iommu->domain_list, next) {
2621                if (!(domain->prot & IOMMU_CACHE)) {
2622                        ret = 0;
2623                        break;
2624                }
2625        }
2626        mutex_unlock(&iommu->lock);
2627
2628        return ret;
2629}
2630
2631static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2632                                            unsigned long arg)
2633{
2634        switch (arg) {
2635        case VFIO_TYPE1_IOMMU:
2636        case VFIO_TYPE1v2_IOMMU:
2637        case VFIO_TYPE1_NESTING_IOMMU:
2638        case VFIO_UNMAP_ALL:
2639        case VFIO_UPDATE_VADDR:
2640                return 1;
2641        case VFIO_DMA_CC_IOMMU:
2642                if (!iommu)
2643                        return 0;
2644                return vfio_domains_have_iommu_cache(iommu);
2645        default:
2646                return 0;
2647        }
2648}
2649
2650static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2651                 struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2652                 size_t size)
2653{
2654        struct vfio_info_cap_header *header;
2655        struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2656
2657        header = vfio_info_cap_add(caps, size,
2658                                   VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2659        if (IS_ERR(header))
2660                return PTR_ERR(header);
2661
2662        iova_cap = container_of(header,
2663                                struct vfio_iommu_type1_info_cap_iova_range,
2664                                header);
2665        iova_cap->nr_iovas = cap_iovas->nr_iovas;
2666        memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2667               cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2668        return 0;
2669}
2670
2671static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2672                                      struct vfio_info_cap *caps)
2673{
2674        struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2675        struct vfio_iova *iova;
2676        size_t size;
2677        int iovas = 0, i = 0, ret;
2678
2679        list_for_each_entry(iova, &iommu->iova_list, list)
2680                iovas++;
2681
2682        if (!iovas) {
2683                /*
2684                 * Return 0 as a container with a single mdev device
2685                 * will have an empty list
2686                 */
2687                return 0;
2688        }
2689
2690        size = struct_size(cap_iovas, iova_ranges, iovas);
2691
2692        cap_iovas = kzalloc(size, GFP_KERNEL);
2693        if (!cap_iovas)
2694                return -ENOMEM;
2695
2696        cap_iovas->nr_iovas = iovas;
2697
2698        list_for_each_entry(iova, &iommu->iova_list, list) {
2699                cap_iovas->iova_ranges[i].start = iova->start;
2700                cap_iovas->iova_ranges[i].end = iova->end;
2701                i++;
2702        }
2703
2704        ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2705
2706        kfree(cap_iovas);
2707        return ret;
2708}
2709
2710static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2711                                           struct vfio_info_cap *caps)
2712{
2713        struct vfio_iommu_type1_info_cap_migration cap_mig;
2714
2715        cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2716        cap_mig.header.version = 1;
2717
2718        cap_mig.flags = 0;
2719        /* support minimum pgsize */
2720        cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2721        cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2722
2723        return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2724}
2725
2726static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2727                                           struct vfio_info_cap *caps)
2728{
2729        struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2730
2731        cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2732        cap_dma_avail.header.version = 1;
2733
2734        cap_dma_avail.avail = iommu->dma_avail;
2735
2736        return vfio_info_add_capability(caps, &cap_dma_avail.header,
2737                                        sizeof(cap_dma_avail));
2738}
2739
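/*
 * VFIO_IOMMU_GET_INFO: report the supported IOMMU page sizes and, when the
 * user buffer is large enough, a capability chain carrying migration,
 * dma-avail and iova-range information.
 */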
2740static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2741                                     unsigned long arg)
2742{
2743        struct vfio_iommu_type1_info info;
2744        unsigned long minsz;
2745        struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2746        unsigned long capsz;
2747        int ret;
2748
2749        minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2750
2751        /* For backward compatibility, cannot require this */
2752        capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2753
2754        if (copy_from_user(&info, (void __user *)arg, minsz))
2755                return -EFAULT;
2756
2757        if (info.argsz < minsz)
2758                return -EINVAL;
2759
2760        if (info.argsz >= capsz) {
2761                minsz = capsz;
2762                info.cap_offset = 0; /* output, no-recopy necessary */
2763        }
2764
2765        mutex_lock(&iommu->lock);
2766        info.flags = VFIO_IOMMU_INFO_PGSIZES;
2767
2768        info.iova_pgsizes = iommu->pgsize_bitmap;
2769
2770        ret = vfio_iommu_migration_build_caps(iommu, &caps);
2771
2772        if (!ret)
2773                ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2774
2775        if (!ret)
2776                ret = vfio_iommu_iova_build_caps(iommu, &caps);
2777
2778        mutex_unlock(&iommu->lock);
2779
2780        if (ret)
2781                return ret;
2782
2783        if (caps.size) {
2784                info.flags |= VFIO_IOMMU_INFO_CAPS;
2785
2786                if (info.argsz < sizeof(info) + caps.size) {
2787                        info.argsz = sizeof(info) + caps.size;
2788                } else {
2789                        vfio_info_cap_shift(&caps, sizeof(info));
2790                        if (copy_to_user((void __user *)arg +
2791                                        sizeof(info), caps.buf,
2792                                        caps.size)) {
2793                                kfree(caps.buf);
2794                                return -EFAULT;
2795                        }
2796                        info.cap_offset = sizeof(info);
2797                }
2798
2799                kfree(caps.buf);
2800        }
2801
2802        return copy_to_user((void __user *)arg, &info, minsz) ?
2803                        -EFAULT : 0;
2804}
2805
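/*
 * VFIO_IOMMU_MAP_DMA: copy in the user's vfio_iommu_type1_dma_map and hand
 * it to vfio_dma_do_map().  Illustrative userspace sketch, not part of this
 * driver; container_fd, buf and map_len are placeholders and iova/size must
 * be aligned to the IOMMU page size:
 *
 *	struct vfio_iommu_type1_dma_map map = {
 *		.argsz = sizeof(map),
 *		.flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE,
 *		.vaddr = (uintptr_t)buf,
 *		.iova  = 0x100000,
 *		.size  = map_len,
 *	};
 *	ioctl(container_fd, VFIO_IOMMU_MAP_DMA, &map);
 */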
2806static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2807                                    unsigned long arg)
2808{
2809        struct vfio_iommu_type1_dma_map map;
2810        unsigned long minsz;
2811        uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
2812                        VFIO_DMA_MAP_FLAG_VADDR;
2813
2814        minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2815
2816        if (copy_from_user(&map, (void __user *)arg, minsz))
2817                return -EFAULT;
2818
2819        if (map.argsz < minsz || map.flags & ~mask)
2820                return -EINVAL;
2821
2822        return vfio_dma_do_map(iommu, &map);
2823}
2824
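/*
 * VFIO_IOMMU_UNMAP_DMA: validate the flag combination and the optional
 * dirty-bitmap argument, perform the unmap, and copy the number of bytes
 * actually unmapped back to userspace in unmap.size.
 */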
2825static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2826                                      unsigned long arg)
2827{
2828        struct vfio_iommu_type1_dma_unmap unmap;
2829        struct vfio_bitmap bitmap = { 0 };
2830        uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
2831                        VFIO_DMA_UNMAP_FLAG_VADDR |
2832                        VFIO_DMA_UNMAP_FLAG_ALL;
2833        unsigned long minsz;
2834        int ret;
2835
2836        minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2837
2838        if (copy_from_user(&unmap, (void __user *)arg, minsz))
2839                return -EFAULT;
2840
2841        if (unmap.argsz < minsz || unmap.flags & ~mask)
2842                return -EINVAL;
2843
2844        if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
2845            (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
2846                            VFIO_DMA_UNMAP_FLAG_VADDR)))
2847                return -EINVAL;
2848
2849        if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2850                unsigned long pgshift;
2851
2852                if (unmap.argsz < (minsz + sizeof(bitmap)))
2853                        return -EINVAL;
2854
2855                if (copy_from_user(&bitmap,
2856                                   (void __user *)(arg + minsz),
2857                                   sizeof(bitmap)))
2858                        return -EFAULT;
2859
2860                if (!access_ok((void __user *)bitmap.data, bitmap.size))
2861                        return -EINVAL;
2862
2863                pgshift = __ffs(bitmap.pgsize);
2864                ret = verify_bitmap_size(unmap.size >> pgshift,
2865                                         bitmap.size);
2866                if (ret)
2867                        return ret;
2868        }
2869
2870        ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2871        if (ret)
2872                return ret;
2873
2874        return copy_to_user((void __user *)arg, &unmap, minsz) ?
2875                        -EFAULT : 0;
2876}
2877
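/*
 * VFIO_IOMMU_DIRTY_PAGES: start or stop dirty page tracking for the
 * container, or, while tracking is enabled, report the dirty bitmap for a
 * range using the smallest supported IOMMU page size.
 */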
2878static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2879                                        unsigned long arg)
2880{
2881        struct vfio_iommu_type1_dirty_bitmap dirty;
2882        uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2883                        VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2884                        VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2885        unsigned long minsz;
2886        int ret = 0;
2887
2888        if (!iommu->v2)
2889                return -EACCES;
2890
2891        minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
2892
2893        if (copy_from_user(&dirty, (void __user *)arg, minsz))
2894                return -EFAULT;
2895
2896        if (dirty.argsz < minsz || dirty.flags & ~mask)
2897                return -EINVAL;
2898
2899        /* only one flag should be set at a time */
2900        if (__ffs(dirty.flags) != __fls(dirty.flags))
2901                return -EINVAL;
2902
2903        if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
2904                size_t pgsize;
2905
2906                mutex_lock(&iommu->lock);
2907                pgsize = 1 << __ffs(iommu->pgsize_bitmap);
2908                if (!iommu->dirty_page_tracking) {
2909                        ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
2910                        if (!ret)
2911                                iommu->dirty_page_tracking = true;
2912                }
2913                mutex_unlock(&iommu->lock);
2914                return ret;
2915        } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
2916                mutex_lock(&iommu->lock);
2917                if (iommu->dirty_page_tracking) {
2918                        iommu->dirty_page_tracking = false;
2919                        vfio_dma_bitmap_free_all(iommu);
2920                }
2921                mutex_unlock(&iommu->lock);
2922                return 0;
2923        } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
2924                struct vfio_iommu_type1_dirty_bitmap_get range;
2925                unsigned long pgshift;
2926                size_t data_size = dirty.argsz - minsz;
2927                size_t iommu_pgsize;
2928
2929                if (!data_size || data_size < sizeof(range))
2930                        return -EINVAL;
2931
2932                if (copy_from_user(&range, (void __user *)(arg + minsz),
2933                                   sizeof(range)))
2934                        return -EFAULT;
2935
2936                if (range.iova + range.size < range.iova)
2937                        return -EINVAL;
2938                if (!access_ok((void __user *)range.bitmap.data,
2939                               range.bitmap.size))
2940                        return -EINVAL;
2941
2942                pgshift = __ffs(range.bitmap.pgsize);
2943                ret = verify_bitmap_size(range.size >> pgshift,
2944                                         range.bitmap.size);
2945                if (ret)
2946                        return ret;
2947
2948                mutex_lock(&iommu->lock);
2949
2950                iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2951
2952                /* allow only smallest supported pgsize */
2953                if (range.bitmap.pgsize != iommu_pgsize) {
2954                        ret = -EINVAL;
2955                        goto out_unlock;
2956                }
2957                if (range.iova & (iommu_pgsize - 1)) {
2958                        ret = -EINVAL;
2959                        goto out_unlock;
2960                }
2961                if (!range.size || range.size & (iommu_pgsize - 1)) {
2962                        ret = -EINVAL;
2963                        goto out_unlock;
2964                }
2965
2966                if (iommu->dirty_page_tracking)
2967                        ret = vfio_iova_dirty_bitmap(range.bitmap.data,
2968                                                     iommu, range.iova,
2969                                                     range.size,
2970                                                     range.bitmap.pgsize);
2971                else
2972                        ret = -EINVAL;
2973out_unlock:
2974                mutex_unlock(&iommu->lock);
2975
2976                return ret;
2977        }
2978
2979        return -EINVAL;
2980}
2981
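    /*
     * Top-level ioctl dispatch for the Type1 backend; anything not handled
     * here returns -ENOTTY so userspace can detect unsupported ioctls.
     */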
2982static long vfio_iommu_type1_ioctl(void *iommu_data,
2983                                   unsigned int cmd, unsigned long arg)
2984{
2985        struct vfio_iommu *iommu = iommu_data;
2986
2987        switch (cmd) {
2988        case VFIO_CHECK_EXTENSION:
2989                return vfio_iommu_type1_check_extension(iommu, arg);
2990        case VFIO_IOMMU_GET_INFO:
2991                return vfio_iommu_type1_get_info(iommu, arg);
2992        case VFIO_IOMMU_MAP_DMA:
2993                return vfio_iommu_type1_map_dma(iommu, arg);
2994        case VFIO_IOMMU_UNMAP_DMA:
2995                return vfio_iommu_type1_unmap_dma(iommu, arg);
2996        case VFIO_IOMMU_DIRTY_PAGES:
2997                return vfio_iommu_type1_dirty_pages(iommu, arg);
2998        default:
2999                return -ENOTTY;
3000        }
3001}
3002
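    /*
     * Only VFIO_IOMMU_NOTIFY_DMA_UNMAP is supported; registration is
     * refused if the caller asks for any other event bits.
     */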
3003static int vfio_iommu_type1_register_notifier(void *iommu_data,
3004                                              unsigned long *events,
3005                                              struct notifier_block *nb)
3006{
3007        struct vfio_iommu *iommu = iommu_data;
3008
3009        /* clear known events */
3010        *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
3011
3012        /* refuse to register if any unhandled events remain */
3013        if (*events)
3014                return -EINVAL;
3015
3016        return blocking_notifier_chain_register(&iommu->notifier, nb);
3017}
3018
3019static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
3020                                                struct notifier_block *nb)
3021{
3022        struct vfio_iommu *iommu = iommu_data;
3023
3024        return blocking_notifier_chain_unregister(&iommu->notifier, nb);
3025}
3026
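    /*
     * Copy up to @count bytes between @data and the single vfio_dma
     * mapping covering @user_iova, going through the vaddr of the task
     * that created the mapping.  Kernel threads temporarily adopt that
     * task's mm via kthread_use_mm(); other callers must already be
     * running on it.  Successful writes are reflected in the per-dma
     * dirty bitmap when dirty page tracking is enabled.  @copied returns
     * how much was actually transferred, which may be less than @count
     * if the mapping ends first.
     */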
3027static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
3028                                         dma_addr_t user_iova, void *data,
3029                                         size_t count, bool write,
3030                                         size_t *copied)
3031{
3032        struct mm_struct *mm;
3033        unsigned long vaddr;
3034        struct vfio_dma *dma;
3035        bool kthread = current->mm == NULL;
3036        size_t offset;
3037        int ret;
3038
3039        *copied = 0;
3040
3041        ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
3042        if (ret < 0)
3043                return ret;
3044
3045        if ((write && !(dma->prot & IOMMU_WRITE)) ||
3046                        !(dma->prot & IOMMU_READ))
3047                return -EPERM;
3048
3049        mm = get_task_mm(dma->task);
3050
3051        if (!mm)
3052                return -EPERM;
3053
3054        if (kthread)
3055                kthread_use_mm(mm);
3056        else if (current->mm != mm)
3057                goto out;
3058
3059        offset = user_iova - dma->iova;
3060
3061        if (count > dma->size - offset)
3062                count = dma->size - offset;
3063
3064        vaddr = dma->vaddr + offset;
3065
3066        if (write) {
3067                *copied = copy_to_user((void __user *)vaddr, data,
3068                                         count) ? 0 : count;
3069                if (*copied && iommu->dirty_page_tracking) {
3070                        unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
3071                        /*
3072                         * Bitmap populated with the smallest supported page
3073                         * size
3074                         */
3075                        bitmap_set(dma->bitmap, offset >> pgshift,
3076                                   ((offset + *copied - 1) >> pgshift) -
3077                                   (offset >> pgshift) + 1);
3078                }
3079        } else
3080                *copied = copy_from_user(data, (void __user *)vaddr,
3081                                           count) ? 0 : count;
3082        if (kthread)
3083                kthread_unuse_mm(mm);
3084out:
3085        mmput(mm);
3086        return *copied ? 0 : -EFAULT;
3087}
3088
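    /*
     * dma_rw callback: a single request may span several vfio_dma
     * mappings, so walk it chunk by chunk under iommu->lock and stop at
     * the first error.
     */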
3089static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
3090                                   void *data, size_t count, bool write)
3091{
3092        struct vfio_iommu *iommu = iommu_data;
3093        int ret = 0;
3094        size_t done;
3095
3096        mutex_lock(&iommu->lock);
3097        while (count > 0) {
3098                ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3099                                                    count, write, &done);
3100                if (ret)
3101                        break;
3102
3103                count -= done;
3104                data += done;
3105                user_iova += done;
3106        }
3107
3108        mutex_unlock(&iommu->lock);
3109        return ret;
3110}
3111
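    /*
     * Return the iommu_domain backing @iommu_group, or an ERR_PTR if the
     * group is not attached to this container.
     */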
3112static struct iommu_domain *
3113vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3114                                    struct iommu_group *iommu_group)
3115{
3116        struct iommu_domain *domain = ERR_PTR(-ENODEV);
3117        struct vfio_iommu *iommu = iommu_data;
3118        struct vfio_domain *d;
3119
3120        if (!iommu || !iommu_group)
3121                return ERR_PTR(-EINVAL);
3122
3123        mutex_lock(&iommu->lock);
3124        list_for_each_entry(d, &iommu->domain_list, next) {
3125                if (find_iommu_group(d, iommu_group)) {
3126                        domain = d->domain;
3127                        break;
3128                }
3129        }
3130        mutex_unlock(&iommu->lock);
3131
3132        return domain;
3133}
3134
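    /*
     * Container close notification: clear container_open and wake any
     * waiters blocked on vaddr_wait so they return instead of sleeping
     * on a container that is going away.
     */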
3135static void vfio_iommu_type1_notify(void *iommu_data,
3136                                    enum vfio_iommu_notify_type event)
3137{
3138        struct vfio_iommu *iommu = iommu_data;
3139
3140        if (event != VFIO_IOMMU_CONTAINER_CLOSE)
3141                return;
3142        mutex_lock(&iommu->lock);
3143        iommu->container_open = false;
3144        mutex_unlock(&iommu->lock);
3145        wake_up_all(&iommu->vaddr_wait);
3146}
3147
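    /*
     * Callbacks the VFIO core invokes on this backend; registered with
     * vfio_register_iommu_driver() at module init below.
     */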
3148static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3149        .name                   = "vfio-iommu-type1",
3150        .owner                  = THIS_MODULE,
3151        .open                   = vfio_iommu_type1_open,
3152        .release                = vfio_iommu_type1_release,
3153        .ioctl                  = vfio_iommu_type1_ioctl,
3154        .attach_group           = vfio_iommu_type1_attach_group,
3155        .detach_group           = vfio_iommu_type1_detach_group,
3156        .pin_pages              = vfio_iommu_type1_pin_pages,
3157        .unpin_pages            = vfio_iommu_type1_unpin_pages,
3158        .register_notifier      = vfio_iommu_type1_register_notifier,
3159        .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
3160        .dma_rw                 = vfio_iommu_type1_dma_rw,
3161        .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3162        .notify                 = vfio_iommu_type1_notify,
3163};
3164
3165static int __init vfio_iommu_type1_init(void)
3166{
3167        return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3168}
3169
3170static void __exit vfio_iommu_type1_cleanup(void)
3171{
3172        vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3173}
3174
3175module_init(vfio_iommu_type1_init);
3176module_exit(vfio_iommu_type1_cleanup);
3177
3178MODULE_VERSION(DRIVER_VERSION);
3179MODULE_LICENSE("GPL v2");
3180MODULE_AUTHOR(DRIVER_AUTHOR);
3181MODULE_DESCRIPTION(DRIVER_DESC);
3182