linux/drivers/vfio/vfio_iommu_type1.c
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * VFIO: IOMMU DMA mapping support for Type1 IOMMU
   4 *
   5 * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
   6 *     Author: Alex Williamson <alex.williamson@redhat.com>
   7 *
   8 * Derived from original vfio:
   9 * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
  10 * Author: Tom Lyon, pugs@cisco.com
  11 *
  12 * We arbitrarily define a Type1 IOMMU as one matching the below code.
  13 * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
  14 * VT-d, but that makes it harder to re-use as theoretically anyone
  15 * implementing a similar IOMMU could make use of this.  We expect the
  16 * IOMMU to support the IOMMU API and have few to no restrictions around
  17 * the IOVA range that can be mapped.  The Type1 IOMMU is currently
  18 * optimized for relatively static mappings of a userspace process with
  19 * userspace pages pinned into memory.  We also assume devices and IOMMU
  20 * domains are PCI based as the IOMMU API is still centered around a
  21 * device/bus interface rather than a group interface.
  22 */
  23
  24#include <linux/compat.h>
  25#include <linux/device.h>
  26#include <linux/fs.h>
  27#include <linux/highmem.h>
  28#include <linux/iommu.h>
  29#include <linux/module.h>
  30#include <linux/mm.h>
  31#include <linux/kthread.h>
  32#include <linux/rbtree.h>
  33#include <linux/sched/signal.h>
  34#include <linux/sched/mm.h>
  35#include <linux/slab.h>
  36#include <linux/uaccess.h>
  37#include <linux/vfio.h>
  38#include <linux/workqueue.h>
  39#include <linux/mdev.h>
  40#include <linux/notifier.h>
  41#include <linux/dma-iommu.h>
  42#include <linux/irqdomain.h>
  43
  44#define DRIVER_VERSION  "0.2"
  45#define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
  46#define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
  47
  48static bool allow_unsafe_interrupts;
  49module_param_named(allow_unsafe_interrupts,
  50                   allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
  51MODULE_PARM_DESC(allow_unsafe_interrupts,
  52                 "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
  53
  54static bool disable_hugepages;
  55module_param_named(disable_hugepages,
  56                   disable_hugepages, bool, S_IRUGO | S_IWUSR);
  57MODULE_PARM_DESC(disable_hugepages,
  58                 "Disable VFIO IOMMU support for IOMMU hugepages.");
  59
  60static unsigned int dma_entry_limit __read_mostly = U16_MAX;
  61module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
  62MODULE_PARM_DESC(dma_entry_limit,
  63                 "Maximum number of user DMA mappings per container (65535).");
  64
  65struct vfio_iommu {
  66        struct list_head        domain_list;
  67        struct list_head        iova_list;
  68        struct vfio_domain      *external_domain; /* domain for external user */
  69        struct mutex            lock;
  70        struct rb_root          dma_list;
  71        struct blocking_notifier_head notifier;
  72        unsigned int            dma_avail;
  73        unsigned int            vaddr_invalid_count;
  74        uint64_t                pgsize_bitmap;
  75        uint64_t                num_non_pinned_groups;
  76        wait_queue_head_t       vaddr_wait;
  77        bool                    v2;
  78        bool                    nesting;
  79        bool                    dirty_page_tracking;
  80        bool                    container_open;
  81};
  82
  83struct vfio_domain {
  84        struct iommu_domain     *domain;
  85        struct list_head        next;
  86        struct list_head        group_list;
  87        int                     prot;           /* IOMMU_CACHE */
  88        bool                    fgsp;           /* Fine-grained super pages */
  89};
  90
  91struct vfio_dma {
  92        struct rb_node          node;
  93        dma_addr_t              iova;           /* Device address */
  94        unsigned long           vaddr;          /* Process virtual addr */
  95        size_t                  size;           /* Map size (bytes) */
  96        int                     prot;           /* IOMMU_READ/WRITE */
  97        bool                    iommu_mapped;
  98        bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
  99        bool                    vaddr_invalid;
 100        struct task_struct      *task;
 101        struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
 102        unsigned long           *bitmap;
 103};
 104
 105struct vfio_batch {
 106        struct page             **pages;        /* for pin_user_pages_remote */
 107        struct page             *fallback_page; /* if pages alloc fails */
 108        int                     capacity;       /* length of pages array */
 109        int                     size;           /* of batch currently */
 110        int                     offset;         /* of next entry in pages */
 111};
 112
 113struct vfio_group {
 114        struct iommu_group      *iommu_group;
 115        struct list_head        next;
 116        bool                    mdev_group;     /* An mdev group */
 117        bool                    pinned_page_dirty_scope;
 118};
 119
 120struct vfio_iova {
 121        struct list_head        list;
 122        dma_addr_t              start;
 123        dma_addr_t              end;
 124};
 125
 126/*
 127 * Guest RAM pinning working set or DMA target
 128 */
 129struct vfio_pfn {
 130        struct rb_node          node;
 131        dma_addr_t              iova;           /* Device address */
 132        unsigned long           pfn;            /* Host pfn */
 133        unsigned int            ref_count;
 134};
 135
 136struct vfio_regions {
 137        struct list_head list;
 138        dma_addr_t iova;
 139        phys_addr_t phys;
 140        size_t len;
 141};
 142
 143#define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
 144                                        (!list_empty(&iommu->domain_list))
 145
 146#define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
 147
 148/*
 149 * The number-of-bits argument to bitmap_set() is an unsigned integer, which
 150 * is further cast to a signed integer for the unaligned multi-bit operation,
 151 * __bitmap_set().
 152 * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
 153 * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
 154 * system.
 155 */
 156#define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
 157#define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
 158
 159#define WAITED 1
 160
 161static int put_pfn(unsigned long pfn, int prot);
 162
 163static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
 164                                               struct iommu_group *iommu_group);
 165
 166/*
 167 * This code handles mapping and unmapping of user data buffers
 168 * into DMA'ble space using the IOMMU
 169 */
 170
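    /*
     * Find a vfio_dma whose IOVA range overlaps [start, start + size),
     * or return NULL if no mapping intersects the range.
     */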
 171static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 172                                      dma_addr_t start, size_t size)
 173{
 174        struct rb_node *node = iommu->dma_list.rb_node;
 175
 176        while (node) {
 177                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 178
 179                if (start + size <= dma->iova)
 180                        node = node->rb_left;
 181                else if (start >= dma->iova + dma->size)
 182                        node = node->rb_right;
 183                else
 184                        return dma;
 185        }
 186
 187        return NULL;
 188}
 189
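    /*
     * Return the first (lowest IOVA) vfio_dma node that overlaps
     * [start, start + size), or NULL if none does.
     */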
 190static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
 191                                                dma_addr_t start, u64 size)
 192{
 193        struct rb_node *res = NULL;
 194        struct rb_node *node = iommu->dma_list.rb_node;
 195        struct vfio_dma *dma_res = NULL;
 196
 197        while (node) {
 198                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 199
 200                if (start < dma->iova + dma->size) {
 201                        res = node;
 202                        dma_res = dma;
 203                        if (start >= dma->iova)
 204                                break;
 205                        node = node->rb_left;
 206                } else {
 207                        node = node->rb_right;
 208                }
 209        }
 210        if (res && size && dma_res->iova >= start + size)
 211                res = NULL;
 212        return res;
 213}
 214
 215static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
 216{
 217        struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
 218        struct vfio_dma *dma;
 219
 220        while (*link) {
 221                parent = *link;
 222                dma = rb_entry(parent, struct vfio_dma, node);
 223
 224                if (new->iova + new->size <= dma->iova)
 225                        link = &(*link)->rb_left;
 226                else
 227                        link = &(*link)->rb_right;
 228        }
 229
 230        rb_link_node(&new->node, parent, link);
 231        rb_insert_color(&new->node, &iommu->dma_list);
 232}
 233
 234static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
 235{
 236        rb_erase(&old->node, &iommu->dma_list);
 237}
 238
 239
 240static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
 241{
 242        uint64_t npages = dma->size / pgsize;
 243
 244        if (npages > DIRTY_BITMAP_PAGES_MAX)
 245                return -EINVAL;
 246
 247        /*
 248         * Allocate extra 64 bits that are used to calculate shift required for
 249         * bitmap_shift_left() to manipulate and club unaligned number of pages
 250         * in adjacent vfio_dma ranges.
 251         */
 252        dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
 253                               GFP_KERNEL);
 254        if (!dma->bitmap)
 255                return -ENOMEM;
 256
 257        return 0;
 258}
 259
 260static void vfio_dma_bitmap_free(struct vfio_dma *dma)
 261{
 262        kvfree(dma->bitmap);
 263        dma->bitmap = NULL;
 264}
 265
 266static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
 267{
 268        struct rb_node *p;
 269        unsigned long pgshift = __ffs(pgsize);
 270
 271        for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
 272                struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
 273
 274                bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
 275        }
 276}
 277
 278static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
 279{
 280        struct rb_node *n;
 281        unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
 282
 283        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
 284                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
 285
 286                bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
 287        }
 288}
 289
 290static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
 291{
 292        struct rb_node *n;
 293
 294        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
 295                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
 296                int ret;
 297
 298                ret = vfio_dma_bitmap_alloc(dma, pgsize);
 299                if (ret) {
 300                        struct rb_node *p;
 301
 302                        for (p = rb_prev(n); p; p = rb_prev(p)) {
 303                                struct vfio_dma *dma = rb_entry(p,
 304                                                        struct vfio_dma, node);
 305
 306                                vfio_dma_bitmap_free(dma);
 307                        }
 308                        return ret;
 309                }
 310                vfio_dma_populate_bitmap(dma, pgsize);
 311        }
 312        return 0;
 313}
 314
 315static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
 316{
 317        struct rb_node *n;
 318
 319        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
 320                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
 321
 322                vfio_dma_bitmap_free(dma);
 323        }
 324}
 325
 326/*
 327 * Helper Functions for host iova-pfn list
 328 */
 329static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
 330{
 331        struct vfio_pfn *vpfn;
 332        struct rb_node *node = dma->pfn_list.rb_node;
 333
 334        while (node) {
 335                vpfn = rb_entry(node, struct vfio_pfn, node);
 336
 337                if (iova < vpfn->iova)
 338                        node = node->rb_left;
 339                else if (iova > vpfn->iova)
 340                        node = node->rb_right;
 341                else
 342                        return vpfn;
 343        }
 344        return NULL;
 345}
 346
 347static void vfio_link_pfn(struct vfio_dma *dma,
 348                          struct vfio_pfn *new)
 349{
 350        struct rb_node **link, *parent = NULL;
 351        struct vfio_pfn *vpfn;
 352
 353        link = &dma->pfn_list.rb_node;
 354        while (*link) {
 355                parent = *link;
 356                vpfn = rb_entry(parent, struct vfio_pfn, node);
 357
 358                if (new->iova < vpfn->iova)
 359                        link = &(*link)->rb_left;
 360                else
 361                        link = &(*link)->rb_right;
 362        }
 363
 364        rb_link_node(&new->node, parent, link);
 365        rb_insert_color(&new->node, &dma->pfn_list);
 366}
 367
 368static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
 369{
 370        rb_erase(&old->node, &dma->pfn_list);
 371}
 372
 373static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
 374                                unsigned long pfn)
 375{
 376        struct vfio_pfn *vpfn;
 377
 378        vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
 379        if (!vpfn)
 380                return -ENOMEM;
 381
 382        vpfn->iova = iova;
 383        vpfn->pfn = pfn;
 384        vpfn->ref_count = 1;
 385        vfio_link_pfn(dma, vpfn);
 386        return 0;
 387}
 388
 389static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
 390                                      struct vfio_pfn *vpfn)
 391{
 392        vfio_unlink_pfn(dma, vpfn);
 393        kfree(vpfn);
 394}
 395
 396static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
 397                                               unsigned long iova)
 398{
 399        struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
 400
 401        if (vpfn)
 402                vpfn->ref_count++;
 403        return vpfn;
 404}
 405
 406static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
 407{
 408        int ret = 0;
 409
 410        vpfn->ref_count--;
 411        if (!vpfn->ref_count) {
 412                ret = put_pfn(vpfn->pfn, dma->prot);
 413                vfio_remove_from_pfn_list(dma, vpfn);
 414        }
 415        return ret;
 416}
 417
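    /*
     * Adjust the locked_vm accounting of the task that created the mapping by
     * npage pages (negative values unaccount).  When called asynchronously,
     * take a reference on the task's mm, which may not be current->mm.
     */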
 418static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
 419{
 420        struct mm_struct *mm;
 421        int ret;
 422
 423        if (!npage)
 424                return 0;
 425
 426        mm = async ? get_task_mm(dma->task) : dma->task->mm;
 427        if (!mm)
 428                return -ESRCH; /* process exited */
 429
 430        ret = mmap_write_lock_killable(mm);
 431        if (!ret) {
 432                ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
 433                                          dma->lock_cap);
 434                mmap_write_unlock(mm);
 435        }
 436
 437        if (async)
 438                mmput(mm);
 439
 440        return ret;
 441}
 442
 443/*
 444 * Some mappings aren't backed by a struct page, for example an mmap'd
 445 * MMIO range for our own or another device.  These use a different
 446 * pfn conversion and shouldn't be tracked as locked pages.
 447 * For compound pages, any driver that sets the reserved bit in head
 448 * page needs to set the reserved bit in all subpages to be safe.
 449 */
 450static bool is_invalid_reserved_pfn(unsigned long pfn)
 451{
 452        if (pfn_valid(pfn))
 453                return PageReserved(pfn_to_page(pfn));
 454
 455        return true;
 456}
 457
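    /*
     * Release a pfn obtained via vaddr_get_pfns(), marking the page dirty if
     * it was pinned for write.  Returns 1 if a page was unpinned, 0 for
     * reserved/invalid pfns that were never pinned.
     */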
 458static int put_pfn(unsigned long pfn, int prot)
 459{
 460        if (!is_invalid_reserved_pfn(pfn)) {
 461                struct page *page = pfn_to_page(pfn);
 462
 463                unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
 464                return 1;
 465        }
 466        return 0;
 467}
 468
 469#define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
 470
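    /*
     * Initialize a pin batch.  Use a full page worth of struct page pointers
     * when possible, otherwise fall back to the single-entry fallback_page
     * slot (always used when hugepages are disabled).
     */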
 471static void vfio_batch_init(struct vfio_batch *batch)
 472{
 473        batch->size = 0;
 474        batch->offset = 0;
 475
 476        if (unlikely(disable_hugepages))
 477                goto fallback;
 478
 479        batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
 480        if (!batch->pages)
 481                goto fallback;
 482
 483        batch->capacity = VFIO_BATCH_MAX_CAPACITY;
 484        return;
 485
 486fallback:
 487        batch->pages = &batch->fallback_page;
 488        batch->capacity = 1;
 489}
 490
 491static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
 492{
 493        while (batch->size) {
 494                unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
 495
 496                put_pfn(pfn, dma->prot);
 497                batch->offset++;
 498                batch->size--;
 499        }
 500}
 501
 502static void vfio_batch_fini(struct vfio_batch *batch)
 503{
 504        if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
 505                free_page((unsigned long)batch->pages);
 506}
 507
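    /*
     * Look up the pfn behind a VM_PFNMAP vaddr, faulting the mapping in if the
     * pte is not yet present.  Returns -EAGAIN if the mmap lock was dropped by
     * fixup_user_fault() so the caller must retry the vma lookup.
     */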
 508static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
 509                            unsigned long vaddr, unsigned long *pfn,
 510                            bool write_fault)
 511{
 512        pte_t *ptep;
 513        spinlock_t *ptl;
 514        int ret;
 515
 516        ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
 517        if (ret) {
 518                bool unlocked = false;
 519
 520                ret = fixup_user_fault(mm, vaddr,
 521                                       FAULT_FLAG_REMOTE |
 522                                       (write_fault ?  FAULT_FLAG_WRITE : 0),
 523                                       &unlocked);
 524                if (unlocked)
 525                        return -EAGAIN;
 526
 527                if (ret)
 528                        return ret;
 529
 530                ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
 531                if (ret)
 532                        return ret;
 533        }
 534
 535        if (write_fault && !pte_write(*ptep))
 536                ret = -EFAULT;
 537        else
 538                *pfn = pte_pfn(*ptep);
 539
 540        pte_unmap_unlock(ptep, ptl);
 541        return ret;
 542}
 543
 544/*
 545 * Returns the positive number of pfns successfully obtained or a negative
 546 * error code.
 547 */
 548static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
 549                          long npages, int prot, unsigned long *pfn,
 550                          struct page **pages)
 551{
 552        struct vm_area_struct *vma;
 553        unsigned int flags = 0;
 554        int ret;
 555
 556        if (prot & IOMMU_WRITE)
 557                flags |= FOLL_WRITE;
 558
 559        mmap_read_lock(mm);
 560        ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
 561                                    pages, NULL, NULL);
 562        if (ret > 0) {
 563                *pfn = page_to_pfn(pages[0]);
 564                goto done;
 565        }
 566
 567        vaddr = untagged_addr(vaddr);
 568
 569retry:
 570        vma = find_vma_intersection(mm, vaddr, vaddr + 1);
 571
 572        if (vma && vma->vm_flags & VM_PFNMAP) {
 573                ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
 574                if (ret == -EAGAIN)
 575                        goto retry;
 576
 577                if (!ret) {
 578                        if (is_invalid_reserved_pfn(*pfn))
 579                                ret = 1;
 580                        else
 581                                ret = -EFAULT;
 582                }
 583        }
 584done:
 585        mmap_read_unlock(mm);
 586        return ret;
 587}
 588
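    /*
     * Drop the iommu lock and sleep until woken on vaddr_wait.  Returns WAITED
     * after a normal wakeup, or -EFAULT if the container is being closed, the
     * thread should stop, or a fatal signal is pending.
     */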
 589static int vfio_wait(struct vfio_iommu *iommu)
 590{
 591        DEFINE_WAIT(wait);
 592
 593        prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
 594        mutex_unlock(&iommu->lock);
 595        schedule();
 596        mutex_lock(&iommu->lock);
 597        finish_wait(&iommu->vaddr_wait, &wait);
 598        if (kthread_should_stop() || !iommu->container_open ||
 599            fatal_signal_pending(current)) {
 600                return -EFAULT;
 601        }
 602        return WAITED;
 603}
 604
 605/*
 606 * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
 607 * if the task waits, but is re-locked on return.  Return result in *dma_p.
 608 * Return 0 on success with no waiting, WAITED on success if waited, and -errno
 609 * on error.
 610 */
 611static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
 612                               size_t size, struct vfio_dma **dma_p)
 613{
 614        int ret;
 615
 616        do {
 617                *dma_p = vfio_find_dma(iommu, start, size);
 618                if (!*dma_p)
 619                        ret = -EINVAL;
 620                else if (!(*dma_p)->vaddr_invalid)
 621                        ret = 0;
 622                else
 623                        ret = vfio_wait(iommu);
 624        } while (ret > 0);
 625
 626        return ret;
 627}
 628
 629/*
 630 * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
 631 * if the task waits, but is re-locked on return.  Return 0 on success with no
 632 * waiting, WAITED on success if waited, and -errno on error.
 633 */
 634static int vfio_wait_all_valid(struct vfio_iommu *iommu)
 635{
 636        int ret = 0;
 637
 638        while (iommu->vaddr_invalid_count && ret >= 0)
 639                ret = vfio_wait(iommu);
 640
 641        return ret;
 642}
 643
 644/*
 645 * Attempt to pin pages.  We really don't want to track all the pfns and
 646 * the iommu can only map chunks of consecutive pfns anyway, so get the
 647 * first page and all consecutive pages with the same locking.
 648 */
 649static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 650                                  long npage, unsigned long *pfn_base,
 651                                  unsigned long limit, struct vfio_batch *batch)
 652{
 653        unsigned long pfn;
 654        struct mm_struct *mm = current->mm;
 655        long ret, pinned = 0, lock_acct = 0;
 656        bool rsvd;
 657        dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 658
 659        /* This code path is only user initiated */
 660        if (!mm)
 661                return -ENODEV;
 662
 663        if (batch->size) {
 664                /* Leftover pages in batch from an earlier call. */
 665                *pfn_base = page_to_pfn(batch->pages[batch->offset]);
 666                pfn = *pfn_base;
 667                rsvd = is_invalid_reserved_pfn(*pfn_base);
 668        } else {
 669                *pfn_base = 0;
 670        }
 671
 672        while (npage) {
 673                if (!batch->size) {
 674                        /* Empty batch, so refill it. */
 675                        long req_pages = min_t(long, npage, batch->capacity);
 676
 677                        ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
 678                                             &pfn, batch->pages);
 679                        if (ret < 0)
 680                                goto unpin_out;
 681
 682                        batch->size = ret;
 683                        batch->offset = 0;
 684
 685                        if (!*pfn_base) {
 686                                *pfn_base = pfn;
 687                                rsvd = is_invalid_reserved_pfn(*pfn_base);
 688                        }
 689                }
 690
 691                /*
 692                 * pfn is preset for the first iteration of this inner loop and
 693                 * updated at the end to handle a VM_PFNMAP pfn.  In that case,
 694                 * batch->pages isn't valid (there's no struct page), so allow
 695                 * batch->pages to be touched only when there's more than one
 696                 * pfn to check, which guarantees the pfns are from a
 697                 * !VM_PFNMAP vma.
 698                 */
 699                while (true) {
 700                        if (pfn != *pfn_base + pinned ||
 701                            rsvd != is_invalid_reserved_pfn(pfn))
 702                                goto out;
 703
 704                        /*
 705                         * Reserved pages aren't counted against the user,
 706                         * externally pinned pages are already counted against
 707                         * the user.
 708                         */
 709                        if (!rsvd && !vfio_find_vpfn(dma, iova)) {
 710                                if (!dma->lock_cap &&
 711                                    mm->locked_vm + lock_acct + 1 > limit) {
 712                                        pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
 713                                                __func__, limit << PAGE_SHIFT);
 714                                        ret = -ENOMEM;
 715                                        goto unpin_out;
 716                                }
 717                                lock_acct++;
 718                        }
 719
 720                        pinned++;
 721                        npage--;
 722                        vaddr += PAGE_SIZE;
 723                        iova += PAGE_SIZE;
 724                        batch->offset++;
 725                        batch->size--;
 726
 727                        if (!batch->size)
 728                                break;
 729
 730                        pfn = page_to_pfn(batch->pages[batch->offset]);
 731                }
 732
 733                if (unlikely(disable_hugepages))
 734                        break;
 735        }
 736
 737out:
 738        ret = vfio_lock_acct(dma, lock_acct, false);
 739
 740unpin_out:
 741        if (batch->size == 1 && !batch->offset) {
 742                /* May be a VM_PFNMAP pfn, which the batch can't remember. */
 743                put_pfn(pfn, dma->prot);
 744                batch->size = 0;
 745        }
 746
 747        if (ret < 0) {
 748                if (pinned && !rsvd) {
 749                        for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
 750                                put_pfn(pfn, dma->prot);
 751                }
 752                vfio_batch_unpin(batch, dma);
 753
 754                return ret;
 755        }
 756
 757        return pinned;
 758}
 759
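    /*
     * Unpin npage pfns previously pinned by vfio_pin_pages_remote().  Pages
     * still held on the external pfn_list stay accounted; optionally credit
     * the remainder back to the task's locked memory.
     */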
 760static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
 761                                    unsigned long pfn, long npage,
 762                                    bool do_accounting)
 763{
 764        long unlocked = 0, locked = 0;
 765        long i;
 766
 767        for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
 768                if (put_pfn(pfn++, dma->prot)) {
 769                        unlocked++;
 770                        if (vfio_find_vpfn(dma, iova))
 771                                locked++;
 772                }
 773        }
 774
 775        if (do_accounting)
 776                vfio_lock_acct(dma, locked - unlocked, true);
 777
 778        return unlocked;
 779}
 780
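    /*
     * Pin a single page on behalf of an external user (e.g. an mdev vendor
     * driver), charging it to the owning task's locked memory when accounting
     * is requested and the pfn is not reserved.
     */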
 781static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 782                                  unsigned long *pfn_base, bool do_accounting)
 783{
 784        struct page *pages[1];
 785        struct mm_struct *mm;
 786        int ret;
 787
 788        mm = get_task_mm(dma->task);
 789        if (!mm)
 790                return -ENODEV;
 791
 792        ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
 793        if (ret != 1)
 794                goto out;
 795
 796        ret = 0;
 797
 798        if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
 799                ret = vfio_lock_acct(dma, 1, true);
 800                if (ret) {
 801                        put_pfn(*pfn_base, dma->prot);
 802                        if (ret == -ENOMEM)
 803                                pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
 804                                        "(%ld) exceeded\n", __func__,
 805                                        dma->task->comm, task_pid_nr(dma->task),
 806                                        task_rlimit(dma->task, RLIMIT_MEMLOCK));
 807                }
 808        }
 809
 810out:
 811        mmput(mm);
 812        return ret;
 813}
 814
 815static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
 816                                    bool do_accounting)
 817{
 818        int unlocked;
 819        struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
 820
 821        if (!vpfn)
 822                return 0;
 823
 824        unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
 825
 826        if (do_accounting)
 827                vfio_lock_acct(dma, -unlocked, true);
 828
 829        return unlocked;
 830}
 831
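    /*
     * The .pin_pages backend, reached from vfio_pin_pages(): translate an
     * array of user iovas (passed as page frame numbers) into host pfns,
     * pinning each page and recording it on the owning vfio_dma's pfn_list so
     * it can be reported as dirty later.  On any failure the pages pinned so
     * far are unwound.
     */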
 832static int vfio_iommu_type1_pin_pages(void *iommu_data,
 833                                      struct iommu_group *iommu_group,
 834                                      unsigned long *user_pfn,
 835                                      int npage, int prot,
 836                                      unsigned long *phys_pfn)
 837{
 838        struct vfio_iommu *iommu = iommu_data;
 839        struct vfio_group *group;
 840        int i, j, ret;
 841        unsigned long remote_vaddr;
 842        struct vfio_dma *dma;
 843        bool do_accounting;
 844        dma_addr_t iova;
 845
 846        if (!iommu || !user_pfn || !phys_pfn)
 847                return -EINVAL;
 848
 849        /* Supported for v2 version only */
 850        if (!iommu->v2)
 851                return -EACCES;
 852
 853        mutex_lock(&iommu->lock);
 854
 855        /*
 856         * Wait for all necessary vaddr's to be valid so they can be used in
 857         * the main loop without dropping the lock, to avoid racing vs unmap.
 858         */
 859again:
 860        if (iommu->vaddr_invalid_count) {
 861                for (i = 0; i < npage; i++) {
 862                        iova = user_pfn[i] << PAGE_SHIFT;
 863                        ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
 864                        if (ret < 0)
 865                                goto pin_done;
 866                        if (ret == WAITED)
 867                                goto again;
 868                }
 869        }
 870
 871        /* Fail if notifier list is empty */
 872        if (!iommu->notifier.head) {
 873                ret = -EINVAL;
 874                goto pin_done;
 875        }
 876
 877        /*
 878         * If an IOMMU capable domain exists in the container then all pages are
 879         * already pinned and accounted for.  Accounting should only be done if
 880         * there is no IOMMU capable domain in the container.
 881         */
 882        do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 883
 884        for (i = 0; i < npage; i++) {
 885                struct vfio_pfn *vpfn;
 886
 887                iova = user_pfn[i] << PAGE_SHIFT;
 888                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 889                if (!dma) {
 890                        ret = -EINVAL;
 891                        goto pin_unwind;
 892                }
 893
 894                if ((dma->prot & prot) != prot) {
 895                        ret = -EPERM;
 896                        goto pin_unwind;
 897                }
 898
 899                vpfn = vfio_iova_get_vfio_pfn(dma, iova);
 900                if (vpfn) {
 901                        phys_pfn[i] = vpfn->pfn;
 902                        continue;
 903                }
 904
 905                remote_vaddr = dma->vaddr + (iova - dma->iova);
 906                ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
 907                                             do_accounting);
 908                if (ret)
 909                        goto pin_unwind;
 910
 911                ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
 912                if (ret) {
 913                        if (put_pfn(phys_pfn[i], dma->prot) && do_accounting)
 914                                vfio_lock_acct(dma, -1, true);
 915                        goto pin_unwind;
 916                }
 917
 918                if (iommu->dirty_page_tracking) {
 919                        unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
 920
 921                        /*
 922                         * Bitmap populated with the smallest supported page
 923                         * size
 924                         */
 925                        bitmap_set(dma->bitmap,
 926                                   (iova - dma->iova) >> pgshift, 1);
 927                }
 928        }
 929        ret = i;
 930
 931        group = vfio_iommu_find_iommu_group(iommu, iommu_group);
 932        if (!group->pinned_page_dirty_scope) {
 933                group->pinned_page_dirty_scope = true;
 934                iommu->num_non_pinned_groups--;
 935        }
 936
 937        goto pin_done;
 938
 939pin_unwind:
 940        phys_pfn[i] = 0;
 941        for (j = 0; j < i; j++) {
 942                dma_addr_t iova;
 943
 944                iova = user_pfn[j] << PAGE_SHIFT;
 945                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 946                vfio_unpin_page_external(dma, iova, do_accounting);
 947                phys_pfn[j] = 0;
 948        }
 949pin_done:
 950        mutex_unlock(&iommu->lock);
 951        return ret;
 952}
 953
 954static int vfio_iommu_type1_unpin_pages(void *iommu_data,
 955                                        unsigned long *user_pfn,
 956                                        int npage)
 957{
 958        struct vfio_iommu *iommu = iommu_data;
 959        bool do_accounting;
 960        int i;
 961
 962        if (!iommu || !user_pfn || npage <= 0)
 963                return -EINVAL;
 964
 965        /* Supported for v2 version only */
 966        if (!iommu->v2)
 967                return -EACCES;
 968
 969        mutex_lock(&iommu->lock);
 970
 971        do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 972        for (i = 0; i < npage; i++) {
 973                struct vfio_dma *dma;
 974                dma_addr_t iova;
 975
 976                iova = user_pfn[i] << PAGE_SHIFT;
 977                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 978                if (!dma)
 979                        break;
 980
 981                vfio_unpin_page_external(dma, iova, do_accounting);
 982        }
 983
 984        mutex_unlock(&iommu->lock);
 985        return i > 0 ? i : -EINVAL;
 986}
 987
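    /*
     * Flush the gathered IOTLB invalidations, then unpin and free the regions
     * deferred by unmap_unpin_fast().
     */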
 988static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
 989                            struct list_head *regions,
 990                            struct iommu_iotlb_gather *iotlb_gather)
 991{
 992        long unlocked = 0;
 993        struct vfio_regions *entry, *next;
 994
 995        iommu_iotlb_sync(domain->domain, iotlb_gather);
 996
 997        list_for_each_entry_safe(entry, next, regions, list) {
 998                unlocked += vfio_unpin_pages_remote(dma,
 999                                                    entry->iova,
1000                                                    entry->phys >> PAGE_SHIFT,
1001                                                    entry->len >> PAGE_SHIFT,
1002                                                    false);
1003                list_del(&entry->list);
1004                kfree(entry);
1005        }
1006
1007        cond_resched();
1008
1009        return unlocked;
1010}
1011
1012/*
1013 * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
1014 * Therefore, when using the IOTLB flush sync interface, VFIO needs to keep
1015 * track of these regions (currently using a list).
1016 *
1017 * This value specifies the maximum number of regions for each IOTLB flush sync.
1018 */
1019#define VFIO_IOMMU_TLB_SYNC_MAX         512
1020
1021static size_t unmap_unpin_fast(struct vfio_domain *domain,
1022                               struct vfio_dma *dma, dma_addr_t *iova,
1023                               size_t len, phys_addr_t phys, long *unlocked,
1024                               struct list_head *unmapped_list,
1025                               int *unmapped_cnt,
1026                               struct iommu_iotlb_gather *iotlb_gather)
1027{
1028        size_t unmapped = 0;
1029        struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
1030
1031        if (entry) {
1032                unmapped = iommu_unmap_fast(domain->domain, *iova, len,
1033                                            iotlb_gather);
1034
1035                if (!unmapped) {
1036                        kfree(entry);
1037                } else {
1038                        entry->iova = *iova;
1039                        entry->phys = phys;
1040                        entry->len  = unmapped;
1041                        list_add_tail(&entry->list, unmapped_list);
1042
1043                        *iova += unmapped;
1044                        (*unmapped_cnt)++;
1045                }
1046        }
1047
1048        /*
1049         * Sync if the number of fast-unmap regions hits the limit
1050         * or in case of errors.
1051         */
1052        if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
1053                *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
1054                                             iotlb_gather);
1055                *unmapped_cnt = 0;
1056        }
1057
1058        return unmapped;
1059}
1060
1061static size_t unmap_unpin_slow(struct vfio_domain *domain,
1062                               struct vfio_dma *dma, dma_addr_t *iova,
1063                               size_t len, phys_addr_t phys,
1064                               long *unlocked)
1065{
1066        size_t unmapped = iommu_unmap(domain->domain, *iova, len);
1067
1068        if (unmapped) {
1069                *unlocked += vfio_unpin_pages_remote(dma, *iova,
1070                                                     phys >> PAGE_SHIFT,
1071                                                     unmapped >> PAGE_SHIFT,
1072                                                     false);
1073                *iova += unmapped;
1074                cond_resched();
1075        }
1076        return unmapped;
1077}
1078
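    /*
     * Tear down the IOMMU mappings backing a vfio_dma and unpin its pages.
     * Returns the number of pages whose locked-memory accounting the caller
     * must drop, or 0 if do_accounting already handled it here.
     */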
1079static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
1080                             bool do_accounting)
1081{
1082        dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
1083        struct vfio_domain *domain, *d;
1084        LIST_HEAD(unmapped_region_list);
1085        struct iommu_iotlb_gather iotlb_gather;
1086        int unmapped_region_cnt = 0;
1087        long unlocked = 0;
1088
1089        if (!dma->size)
1090                return 0;
1091
1092        if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1093                return 0;
1094
1095        /*
1096         * We use the IOMMU to track the physical addresses, otherwise we'd
1097         * need a much more complicated tracking system.  Unfortunately that
1098         * means we need to use one of the iommu domains to figure out the
1099         * pfns to unpin.  The rest need to be unmapped in advance so we have
1100         * no iommu translations remaining when the pages are unpinned.
1101         */
1102        domain = d = list_first_entry(&iommu->domain_list,
1103                                      struct vfio_domain, next);
1104
1105        list_for_each_entry_continue(d, &iommu->domain_list, next) {
1106                iommu_unmap(d->domain, dma->iova, dma->size);
1107                cond_resched();
1108        }
1109
1110        iommu_iotlb_gather_init(&iotlb_gather);
1111        while (iova < end) {
1112                size_t unmapped, len;
1113                phys_addr_t phys, next;
1114
1115                phys = iommu_iova_to_phys(domain->domain, iova);
1116                if (WARN_ON(!phys)) {
1117                        iova += PAGE_SIZE;
1118                        continue;
1119                }
1120
1121                /*
1122                 * To optimize for fewer iommu_unmap() calls, each of which
1123                 * may require hardware cache flushing, try to find the
1124                 * largest contiguous physical memory chunk to unmap.
1125                 */
1126                for (len = PAGE_SIZE;
1127                     !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
1128                        next = iommu_iova_to_phys(domain->domain, iova + len);
1129                        if (next != phys + len)
1130                                break;
1131                }
1132
1133                /*
1134                 * First, try to use fast unmap/unpin. In case of failure,
1135                 * switch to slow unmap/unpin path.
1136                 */
1137                unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
1138                                            &unlocked, &unmapped_region_list,
1139                                            &unmapped_region_cnt,
1140                                            &iotlb_gather);
1141                if (!unmapped) {
1142                        unmapped = unmap_unpin_slow(domain, dma, &iova, len,
1143                                                    phys, &unlocked);
1144                        if (WARN_ON(!unmapped))
1145                                break;
1146                }
1147        }
1148
1149        dma->iommu_mapped = false;
1150
1151        if (unmapped_region_cnt) {
1152                unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
1153                                            &iotlb_gather);
1154        }
1155
1156        if (do_accounting) {
1157                vfio_lock_acct(dma, -unlocked, true);
1158                return 0;
1159        }
1160        return unlocked;
1161}
1162
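    /*
     * Remove a vfio_dma completely: unmap and unpin it, unlink it from the
     * rb-tree, free its dirty bitmap, and wake any waiters blocked on its
     * invalid vaddr.
     */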
1163static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
1164{
1165        WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
1166        vfio_unmap_unpin(iommu, dma, true);
1167        vfio_unlink_dma(iommu, dma);
1168        put_task_struct(dma->task);
1169        vfio_dma_bitmap_free(dma);
1170        if (dma->vaddr_invalid) {
1171                iommu->vaddr_invalid_count--;
1172                wake_up_all(&iommu->vaddr_wait);
1173        }
1174        kfree(dma);
1175        iommu->dma_avail++;
1176}
1177
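    /*
     * Recompute the container's supported page-size bitmap as the intersection
     * of all attached domains, hiding sub-PAGE_SIZE granularity from userspace.
     */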
1178static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
1179{
1180        struct vfio_domain *domain;
1181
1182        iommu->pgsize_bitmap = ULONG_MAX;
1183
1184        list_for_each_entry(domain, &iommu->domain_list, next)
1185                iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
1186
1187        /*
1188         * In case the IOMMU supports page sizes smaller than PAGE_SIZE
1189         * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
1190         * That way the user will be able to map/unmap buffers whose size/
1191         * start address is aligned with PAGE_SIZE. Pinning code uses that
1192         * granularity while iommu driver can use the sub-PAGE_SIZE size
1193         * to map the buffer.
1194         */
1195        if (iommu->pgsize_bitmap & ~PAGE_MASK) {
1196                iommu->pgsize_bitmap &= PAGE_MASK;
1197                iommu->pgsize_bitmap |= PAGE_SIZE;
1198        }
1199}
1200
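    /*
     * Copy one vfio_dma's dirty bitmap out to the user-supplied bitmap at the
     * bit offset corresponding to (dma->iova - base_iova).  When that offset
     * is not aligned to an unsigned long, the bitmap is shifted and OR-ed with
     * the user's existing bits in the shared word.
     */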
1201static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1202                              struct vfio_dma *dma, dma_addr_t base_iova,
1203                              size_t pgsize)
1204{
1205        unsigned long pgshift = __ffs(pgsize);
1206        unsigned long nbits = dma->size >> pgshift;
1207        unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1208        unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1209        unsigned long shift = bit_offset % BITS_PER_LONG;
1210        unsigned long leftover;
1211
1212        /*
1213         * mark all pages dirty if any IOMMU capable device is not able
1214         * to report dirty pages and all pages are pinned and mapped.
1215         */
1216        if (iommu->num_non_pinned_groups && dma->iommu_mapped)
1217                bitmap_set(dma->bitmap, 0, nbits);
1218
1219        if (shift) {
1220                bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1221                                  nbits + shift);
1222
1223                if (copy_from_user(&leftover,
1224                                   (void __user *)(bitmap + copy_offset),
1225                                   sizeof(leftover)))
1226                        return -EFAULT;
1227
1228                bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1229        }
1230
1231        if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1232                         DIRTY_BITMAP_BYTES(nbits + shift)))
1233                return -EFAULT;
1234
1235        return 0;
1236}
1237
1238static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1239                                  dma_addr_t iova, size_t size, size_t pgsize)
1240{
1241        struct vfio_dma *dma;
1242        struct rb_node *n;
1243        unsigned long pgshift = __ffs(pgsize);
1244        int ret;
1245
1246        /*
1247         * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1248         * vfio_dma mappings may be clubbed by specifying large ranges, but
1249         * there must not be any previous mappings bisected by the range.
1250         * An error will be returned if these conditions are not met.
1251         */
1252        dma = vfio_find_dma(iommu, iova, 1);
1253        if (dma && dma->iova != iova)
1254                return -EINVAL;
1255
1256        dma = vfio_find_dma(iommu, iova + size - 1, 0);
1257        if (dma && dma->iova + dma->size != iova + size)
1258                return -EINVAL;
1259
1260        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1261                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1262
1263                if (dma->iova < iova)
1264                        continue;
1265
1266                if (dma->iova > iova + size - 1)
1267                        break;
1268
1269                ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1270                if (ret)
1271                        return ret;
1272
1273                /*
1274                 * Re-populate bitmap to include all pinned pages which are
1275                 * considered as dirty but exclude pages which are unpinned and
1276                 * pages which are marked dirty by vfio_dma_rw()
1277                 */
1278                bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1279                vfio_dma_populate_bitmap(dma, pgsize);
1280        }
1281        return 0;
1282}
1283
1284static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1285{
1286        if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1287            (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1288                return -EINVAL;
1289
1290        return 0;
1291}
1292
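    /*
     * Handle VFIO_IOMMU_UNMAP_DMA, including the VFIO_DMA_UNMAP_FLAG_ALL,
     * VFIO_DMA_UNMAP_FLAG_VADDR and GET_DIRTY_BITMAP variants.  Reports the
     * total bytes unmapped back through unmap->size.
     */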
1293static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1294                             struct vfio_iommu_type1_dma_unmap *unmap,
1295                             struct vfio_bitmap *bitmap)
1296{
1297        struct vfio_dma *dma, *dma_last = NULL;
1298        size_t unmapped = 0, pgsize;
1299        int ret = -EINVAL, retries = 0;
1300        unsigned long pgshift;
1301        dma_addr_t iova = unmap->iova;
1302        u64 size = unmap->size;
1303        bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
1304        bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
1305        struct rb_node *n, *first_n;
1306
1307        mutex_lock(&iommu->lock);
1308
1309        pgshift = __ffs(iommu->pgsize_bitmap);
1310        pgsize = (size_t)1 << pgshift;
1311
1312        if (iova & (pgsize - 1))
1313                goto unlock;
1314
1315        if (unmap_all) {
1316                if (iova || size)
1317                        goto unlock;
1318                size = U64_MAX;
1319        } else if (!size || size & (pgsize - 1) ||
1320                   iova + size - 1 < iova || size > SIZE_MAX) {
1321                goto unlock;
1322        }
1323
1324        /* When dirty tracking is enabled, allow only min supported pgsize */
1325        if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1326            (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1327                goto unlock;
1328        }
1329
1330        WARN_ON((pgsize - 1) & PAGE_MASK);
1331again:
1332        /*
1333         * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1334         * avoid tracking individual mappings.  This means that the granularity
1335         * of the original mapping was lost and the user was allowed to attempt
1336         * to unmap any range.  Depending on the contiguousness of physical
1337         * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1338         * or may not have worked.  We only guaranteed unmap granularity
1339         * matching the original mapping; even though it was untracked here,
1340         * the original mappings are reflected in IOMMU mappings.  This
1341         * resulted in a couple unusual behaviors.  First, if a range is not
1342         * able to be unmapped, ex. a set of 4k pages that was mapped as a
1343         * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1344         * a zero sized unmap.  Also, if an unmap request overlaps the first
1345         * address of a hugepage, the IOMMU will unmap the entire hugepage.
1346         * This also returns success and the returned unmap size reflects the
1347         * actual size unmapped.
1348         *
1349         * We attempt to maintain compatibility with this "v1" interface, but
1350         * we take control out of the hands of the IOMMU.  Therefore, an unmap
1351         * request offset from the beginning of the original mapping will
1352         * return success with zero sized unmap.  And an unmap request covering
1353         * the first iova of mapping will unmap the entire range.
1354         *
1355         * The v2 version of this interface intends to be more deterministic.
1356         * Unmap requests must fully cover previous mappings.  Multiple
1357         * mappings may still be unmapped by specifying large ranges, but there
1358         * must not be any previous mappings bisected by the range.  An error
1359         * will be returned if these conditions are not met.  The v2 interface
1360         * will only return success and a size of zero if there were no
1361         * mappings within the range.
1362         */
1363        if (iommu->v2 && !unmap_all) {
1364                dma = vfio_find_dma(iommu, iova, 1);
1365                if (dma && dma->iova != iova)
1366                        goto unlock;
1367
1368                dma = vfio_find_dma(iommu, iova + size - 1, 0);
1369                if (dma && dma->iova + dma->size != iova + size)
1370                        goto unlock;
1371        }
1372
1373        ret = 0;
1374        n = first_n = vfio_find_dma_first_node(iommu, iova, size);
1375
1376        while (n) {
1377                dma = rb_entry(n, struct vfio_dma, node);
1378                if (dma->iova >= iova + size)
1379                        break;
1380
1381                if (!iommu->v2 && iova > dma->iova)
1382                        break;
1383                /*
1384                 * Only a task with the same address space as the one that
1385                 * mapped this iova range is allowed to unmap it.
1386                 */
1387                if (dma->task->mm != current->mm)
1388                        break;
1389
1390                if (invalidate_vaddr) {
1391                        if (dma->vaddr_invalid) {
1392                                struct rb_node *last_n = n;
1393
1394                                for (n = first_n; n != last_n; n = rb_next(n)) {
1395                                        dma = rb_entry(n,
1396                                                       struct vfio_dma, node);
1397                                        dma->vaddr_invalid = false;
1398                                        iommu->vaddr_invalid_count--;
1399                                }
1400                                ret = -EINVAL;
1401                                unmapped = 0;
1402                                break;
1403                        }
1404                        dma->vaddr_invalid = true;
1405                        iommu->vaddr_invalid_count++;
1406                        unmapped += dma->size;
1407                        n = rb_next(n);
1408                        continue;
1409                }
1410
1411                if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1412                        struct vfio_iommu_type1_dma_unmap nb_unmap;
1413
1414                        if (dma_last == dma) {
1415                                BUG_ON(++retries > 10);
1416                        } else {
1417                                dma_last = dma;
1418                                retries = 0;
1419                        }
1420
1421                        nb_unmap.iova = dma->iova;
1422                        nb_unmap.size = dma->size;
1423
1424                        /*
1425                         * Notify anyone (mdev vendor drivers) to invalidate and
1426                         * unmap iovas within the range we're about to unmap.
1427                         * Vendor drivers MUST unpin pages in response to an
1428                         * invalidation.
1429                         */
1430                        mutex_unlock(&iommu->lock);
1431                        blocking_notifier_call_chain(&iommu->notifier,
1432                                                    VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1433                                                    &nb_unmap);
1434                        mutex_lock(&iommu->lock);
1435                        goto again;
1436                }
1437
1438                if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1439                        ret = update_user_bitmap(bitmap->data, iommu, dma,
1440                                                 iova, pgsize);
1441                        if (ret)
1442                                break;
1443                }
1444
1445                unmapped += dma->size;
1446                n = rb_next(n);
1447                vfio_remove_dma(iommu, dma);
1448        }
1449
1450unlock:
1451        mutex_unlock(&iommu->lock);
1452
1453        /* Report how much was unmapped */
1454        unmap->size = unmapped;
1455
1456        return ret;
1457}
1458
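    /*
     * Map a physically contiguous pfn range at iova into every IOMMU domain in
     * the container, unwinding the already-mapped domains on failure.
     */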
1459static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1460                          unsigned long pfn, long npage, int prot)
1461{
1462        struct vfio_domain *d;
1463        int ret;
1464
1465        list_for_each_entry(d, &iommu->domain_list, next) {
1466                ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1467                                npage << PAGE_SHIFT, prot | d->prot);
1468                if (ret)
1469                        goto unwind;
1470
1471                cond_resched();
1472        }
1473
1474        return 0;
1475
1476unwind:
1477        list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1478                iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1479                cond_resched();
1480        }
1481
1482        return ret;
1483}
1484
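    /*
     * Pin the user memory backing a new mapping in contiguous chunks and map
     * each chunk into the IOMMU as it is pinned.  Any failure removes the
     * whole vfio_dma.
     */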
1485static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1486                            size_t map_size)
1487{
1488        dma_addr_t iova = dma->iova;
1489        unsigned long vaddr = dma->vaddr;
1490        struct vfio_batch batch;
1491        size_t size = map_size;
1492        long npage;
1493        unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1494        int ret = 0;
1495
1496        vfio_batch_init(&batch);
1497
1498        while (size) {
1499                /* Pin a contiguous chunk of memory */
1500                npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1501                                              size >> PAGE_SHIFT, &pfn, limit,
1502                                              &batch);
1503                if (npage <= 0) {
1504                        WARN_ON(!npage);
1505                        ret = (int)npage;
1506                        break;
1507                }
1508
1509                /* Map it! */
1510                ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1511                                     dma->prot);
1512                if (ret) {
1513                        vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1514                                                npage, true);
1515                        vfio_batch_unpin(&batch, dma);
1516                        break;
1517                }
1518
1519                size -= npage << PAGE_SHIFT;
1520                dma->size += npage << PAGE_SHIFT;
1521        }
1522
1523        vfio_batch_fini(&batch);
1524        dma->iommu_mapped = true;
1525
1526        if (ret)
1527                vfio_remove_dma(iommu, dma);
1528
1529        return ret;
1530}
1531
1532/*
1533 * Check that a dma map request is within a valid iova range
1534 */
1535static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1536                                      dma_addr_t start, dma_addr_t end)
1537{
1538        struct list_head *iova = &iommu->iova_list;
1539        struct vfio_iova *node;
1540
1541        list_for_each_entry(node, iova, list) {
1542                if (start >= node->start && end <= node->end)
1543                        return true;
1544        }
1545
1546        /*
1547         * Check for list_empty() as well since a container with
1548         * a single mdev device will have an empty list.
1549         */
1550        return list_empty(iova);
1551}
1552
1553static int vfio_dma_do_map(struct vfio_iommu *iommu,
1554                           struct vfio_iommu_type1_dma_map *map)
1555{
1556        bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
1557        dma_addr_t iova = map->iova;
1558        unsigned long vaddr = map->vaddr;
1559        size_t size = map->size;
1560        int ret = 0, prot = 0;
1561        size_t pgsize;
1562        struct vfio_dma *dma;
1563
1564        /* Verify that none of our __u64 fields overflow */
1565        if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1566                return -EINVAL;
1567
1568        /* READ/WRITE from device perspective */
1569        if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1570                prot |= IOMMU_WRITE;
1571        if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1572                prot |= IOMMU_READ;
1573
1574        if ((prot && set_vaddr) || (!prot && !set_vaddr))
1575                return -EINVAL;
1576
1577        mutex_lock(&iommu->lock);
1578
1579        pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1580
1581        WARN_ON((pgsize - 1) & PAGE_MASK);
1582
1583        if (!size || (size | iova | vaddr) & (pgsize - 1)) {
1584                ret = -EINVAL;
1585                goto out_unlock;
1586        }
1587
1588        /* Don't allow IOVA or virtual address wrap */
1589        if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1590                ret = -EINVAL;
1591                goto out_unlock;
1592        }
1593
1594        dma = vfio_find_dma(iommu, iova, size);
1595        if (set_vaddr) {
1596                if (!dma) {
1597                        ret = -ENOENT;
1598                } else if (!dma->vaddr_invalid || dma->iova != iova ||
1599                           dma->size != size) {
1600                        ret = -EINVAL;
1601                } else {
1602                        dma->vaddr = vaddr;
1603                        dma->vaddr_invalid = false;
1604                        iommu->vaddr_invalid_count--;
1605                        wake_up_all(&iommu->vaddr_wait);
1606                }
1607                goto out_unlock;
1608        } else if (dma) {
1609                ret = -EEXIST;
1610                goto out_unlock;
1611        }
1612
1613        if (!iommu->dma_avail) {
1614                ret = -ENOSPC;
1615                goto out_unlock;
1616        }
1617
1618        if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1619                ret = -EINVAL;
1620                goto out_unlock;
1621        }
1622
1623        dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1624        if (!dma) {
1625                ret = -ENOMEM;
1626                goto out_unlock;
1627        }
1628
1629        iommu->dma_avail--;
1630        dma->iova = iova;
1631        dma->vaddr = vaddr;
1632        dma->prot = prot;
1633
1634        /*
1635         * We need to be able to both add to a task's locked memory and test
1636         * against the locked memory limit and we need to be able to do both
1637         * outside of this call path as pinning can be asynchronous via the
1638         * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1639         * task_struct and VM locked pages requires an mm_struct, however
1640         * holding an indefinite mm reference is not recommended, therefore we
1641         * only hold a reference to a task.  We could hold a reference to
1642         * current, however QEMU uses this call path through vCPU threads,
1643         * which can be killed resulting in a NULL mm and failure in the unmap
1644         * path when called via a different thread.  Avoid this problem by
1645         * using the group_leader as threads within the same group require
1646         * both CLONE_THREAD and CLONE_VM and will therefore use the same
1647         * mm_struct.
1648         *
1649         * Previously we also used the task for testing CAP_IPC_LOCK at the
1650         * time of pinning and accounting, however has_capability() makes use
1651         * of real_cred, a copy-on-write field, so we can't guarantee that it
1652         * matches group_leader, or that it won't have changed by the time
1653         * it's evaluated.  If a process were to call MAP_DMA with
1654         * CAP_IPC_LOCK but later drop it, it doesn't make sense for it to
1655         * see different results for an iommu_mapped vfio_dma vs an
1656         * externally mapped one.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1657         * time of calling MAP_DMA.
1658         */
1659        get_task_struct(current->group_leader);
1660        dma->task = current->group_leader;
1661        dma->lock_cap = capable(CAP_IPC_LOCK);
1662
1663        dma->pfn_list = RB_ROOT;
1664
1665        /* Insert the vfio_dma zero-sized and grow it as we map chunks of it */
1666        vfio_link_dma(iommu, dma);
1667
1668        /* Don't pin and map if the container doesn't contain an IOMMU-capable domain */
1669        if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1670                dma->size = size;
1671        else
1672                ret = vfio_pin_map_dma(iommu, dma, size);
1673
1674        if (!ret && iommu->dirty_page_tracking) {
1675                ret = vfio_dma_bitmap_alloc(dma, pgsize);
1676                if (ret)
1677                        vfio_remove_dma(iommu, dma);
1678        }
1679
1680out_unlock:
1681        mutex_unlock(&iommu->lock);
1682        return ret;
1683}
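
/*
 * Illustrative userspace sketch of the MAP_DMA path above (a minimal sketch;
 * container_fd and size are placeholders, 0x100000 is just a hypothetical
 * IOVA inside a valid iova range, and error handling is omitted).  vaddr,
 * iova and size must all be multiples of the minimum IOMMU page size
 * reported by VFIO_IOMMU_GET_INFO:
 *
 *        void *buf = mmap(NULL, size, PROT_READ | PROT_WRITE,
 *                         MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
 *        struct vfio_iommu_type1_dma_map map = {
 *                .argsz = sizeof(map),
 *                .flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE,
 *                .vaddr = (__u64)(uintptr_t)buf,
 *                .iova  = 0x100000,
 *                .size  = size,
 *        };
 *
 *        ioctl(container_fd, VFIO_IOMMU_MAP_DMA, &map);
 */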
1684
1685static int vfio_bus_type(struct device *dev, void *data)
1686{
1687        struct bus_type **bus = data;
1688
1689        if (*bus && *bus != dev->bus)
1690                return -EINVAL;
1691
1692        *bus = dev->bus;
1693
1694        return 0;
1695}
1696
1697static int vfio_iommu_replay(struct vfio_iommu *iommu,
1698                             struct vfio_domain *domain)
1699{
1700        struct vfio_batch batch;
1701        struct vfio_domain *d = NULL;
1702        struct rb_node *n;
1703        unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1704        int ret;
1705
1706        ret = vfio_wait_all_valid(iommu);
1707        if (ret < 0)
1708                return ret;
1709
1710        /* Arbitrarily pick the first domain in the list for lookups */
1711        if (!list_empty(&iommu->domain_list))
1712                d = list_first_entry(&iommu->domain_list,
1713                                     struct vfio_domain, next);
1714
1715        vfio_batch_init(&batch);
1716
1717        n = rb_first(&iommu->dma_list);
1718
1719        for (; n; n = rb_next(n)) {
1720                struct vfio_dma *dma;
1721                dma_addr_t iova;
1722
1723                dma = rb_entry(n, struct vfio_dma, node);
1724                iova = dma->iova;
1725
1726                while (iova < dma->iova + dma->size) {
1727                        phys_addr_t phys;
1728                        size_t size;
1729
1730                        if (dma->iommu_mapped) {
1731                                phys_addr_t p;
1732                                dma_addr_t i;
1733
1734                                if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1735                                        ret = -EINVAL;
1736                                        goto unwind;
1737                                }
1738
1739                                phys = iommu_iova_to_phys(d->domain, iova);
1740
1741                                if (WARN_ON(!phys)) {
1742                                        iova += PAGE_SIZE;
1743                                        continue;
1744                                }
1745
1746                                size = PAGE_SIZE;
1747                                p = phys + size;
1748                                i = iova + size;
1749                                while (i < dma->iova + dma->size &&
1750                                       p == iommu_iova_to_phys(d->domain, i)) {
1751                                        size += PAGE_SIZE;
1752                                        p += PAGE_SIZE;
1753                                        i += PAGE_SIZE;
1754                                }
1755                        } else {
1756                                unsigned long pfn;
1757                                unsigned long vaddr = dma->vaddr +
1758                                                     (iova - dma->iova);
1759                                size_t n = dma->iova + dma->size - iova;
1760                                long npage;
1761
1762                                npage = vfio_pin_pages_remote(dma, vaddr,
1763                                                              n >> PAGE_SHIFT,
1764                                                              &pfn, limit,
1765                                                              &batch);
1766                                if (npage <= 0) {
1767                                        WARN_ON(!npage);
1768                                        ret = (int)npage;
1769                                        goto unwind;
1770                                }
1771
1772                                phys = pfn << PAGE_SHIFT;
1773                                size = npage << PAGE_SHIFT;
1774                        }
1775
1776                        ret = iommu_map(domain->domain, iova, phys,
1777                                        size, dma->prot | domain->prot);
1778                        if (ret) {
1779                                if (!dma->iommu_mapped) {
1780                                        vfio_unpin_pages_remote(dma, iova,
1781                                                        phys >> PAGE_SHIFT,
1782                                                        size >> PAGE_SHIFT,
1783                                                        true);
1784                                        vfio_batch_unpin(&batch, dma);
1785                                }
1786                                goto unwind;
1787                        }
1788
1789                        iova += size;
1790                }
1791        }
1792
1793        /* All dmas are now mapped, defer to second tree walk for unwind */
1794        for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1795                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1796
1797                dma->iommu_mapped = true;
1798        }
1799
1800        vfio_batch_fini(&batch);
1801        return 0;
1802
1803unwind:
1804        for (; n; n = rb_prev(n)) {
1805                struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1806                dma_addr_t iova;
1807
1808                if (dma->iommu_mapped) {
1809                        iommu_unmap(domain->domain, dma->iova, dma->size);
1810                        continue;
1811                }
1812
1813                iova = dma->iova;
1814                while (iova < dma->iova + dma->size) {
1815                        phys_addr_t phys, p;
1816                        size_t size;
1817                        dma_addr_t i;
1818
1819                        phys = iommu_iova_to_phys(domain->domain, iova);
1820                        if (!phys) {
1821                                iova += PAGE_SIZE;
1822                                continue;
1823                        }
1824
1825                        size = PAGE_SIZE;
1826                        p = phys + size;
1827                        i = iova + size;
1828                        while (i < dma->iova + dma->size &&
1829                               p == iommu_iova_to_phys(domain->domain, i)) {
1830                                size += PAGE_SIZE;
1831                                p += PAGE_SIZE;
1832                                i += PAGE_SIZE;
1833                        }
1834
1835                        iommu_unmap(domain->domain, iova, size);
1836                        vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1837                                                size >> PAGE_SHIFT, true);
1838                }
1839        }
1840
1841        vfio_batch_fini(&batch);
1842        return ret;
1843}
1844
1845/*
1846 * We change our unmap behavior slightly depending on whether the IOMMU
1847 * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1848 * for practically any contiguous power-of-two mapping we give it.  This means
1849 * we don't need to look for contiguous chunks ourselves to make unmapping
1850 * more efficient.  On IOMMUs with coarse-grained superpages, like Intel VT-d
1851 * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1852 * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1853 * hugetlbfs is in use.
1854 */
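/*
 * Probe sketch of the test below: map two contiguous pages at IOVA 0, then
 * ask the IOMMU to unmap only the first PAGE_SIZE.  If exactly PAGE_SIZE is
 * returned, the mapping consisted of individual pages and the second page is
 * unmapped separately.  If more than PAGE_SIZE is returned, the IOMMU
 * transparently used a larger page for the contiguous range, i.e. it
 * supports fine-grained superpages, and domain->fgsp is set.
 */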
1855static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1856{
1857        struct page *pages;
1858        int ret, order = get_order(PAGE_SIZE * 2);
1859
1860        pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1861        if (!pages)
1862                return;
1863
1864        ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1865                        IOMMU_READ | IOMMU_WRITE | domain->prot);
1866        if (!ret) {
1867                size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1868
1869                if (unmapped == PAGE_SIZE)
1870                        iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1871                else
1872                        domain->fgsp = true;
1873        }
1874
1875        __free_pages(pages, order);
1876}
1877
1878static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1879                                           struct iommu_group *iommu_group)
1880{
1881        struct vfio_group *g;
1882
1883        list_for_each_entry(g, &domain->group_list, next) {
1884                if (g->iommu_group == iommu_group)
1885                        return g;
1886        }
1887
1888        return NULL;
1889}
1890
1891static struct vfio_group *vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1892                                               struct iommu_group *iommu_group)
1893{
1894        struct vfio_domain *domain;
1895        struct vfio_group *group = NULL;
1896
1897        list_for_each_entry(domain, &iommu->domain_list, next) {
1898                group = find_iommu_group(domain, iommu_group);
1899                if (group)
1900                        return group;
1901        }
1902
1903        if (iommu->external_domain)
1904                group = find_iommu_group(iommu->external_domain, iommu_group);
1905
1906        return group;
1907}
1908
1909static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1910                                  phys_addr_t *base)
1911{
1912        struct iommu_resv_region *region;
1913        bool ret = false;
1914
1915        list_for_each_entry(region, group_resv_regions, list) {
1916                /*
1917                 * The presence of any 'real' MSI regions should take
1918                 * precedence over the software-managed one if the
1919                 * IOMMU driver happens to advertise both types.
1920                 */
1921                if (region->type == IOMMU_RESV_MSI) {
1922                        ret = false;
1923                        break;
1924                }
1925
1926                if (region->type == IOMMU_RESV_SW_MSI) {
1927                        *base = region->start;
1928                        ret = true;
1929                }
1930        }
1931
1932        return ret;
1933}
1934
1935static int vfio_mdev_attach_domain(struct device *dev, void *data)
1936{
1937        struct mdev_device *mdev = to_mdev_device(dev);
1938        struct iommu_domain *domain = data;
1939        struct device *iommu_device;
1940
1941        iommu_device = mdev_get_iommu_device(mdev);
1942        if (iommu_device) {
1943                if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1944                        return iommu_aux_attach_device(domain, iommu_device);
1945                else
1946                        return iommu_attach_device(domain, iommu_device);
1947        }
1948
1949        return -EINVAL;
1950}
1951
1952static int vfio_mdev_detach_domain(struct device *dev, void *data)
1953{
1954        struct mdev_device *mdev = to_mdev_device(dev);
1955        struct iommu_domain *domain = data;
1956        struct device *iommu_device;
1957
1958        iommu_device = mdev_get_iommu_device(mdev);
1959        if (iommu_device) {
1960                if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1961                        iommu_aux_detach_device(domain, iommu_device);
1962                else
1963                        iommu_detach_device(domain, iommu_device);
1964        }
1965
1966        return 0;
1967}
1968
1969static int vfio_iommu_attach_group(struct vfio_domain *domain,
1970                                   struct vfio_group *group)
1971{
1972        if (group->mdev_group)
1973                return iommu_group_for_each_dev(group->iommu_group,
1974                                                domain->domain,
1975                                                vfio_mdev_attach_domain);
1976        else
1977                return iommu_attach_group(domain->domain, group->iommu_group);
1978}
1979
1980static void vfio_iommu_detach_group(struct vfio_domain *domain,
1981                                    struct vfio_group *group)
1982{
1983        if (group->mdev_group)
1984                iommu_group_for_each_dev(group->iommu_group, domain->domain,
1985                                         vfio_mdev_detach_domain);
1986        else
1987                iommu_detach_group(domain->domain, group->iommu_group);
1988}
1989
1990static bool vfio_bus_is_mdev(struct bus_type *bus)
1991{
1992        struct bus_type *mdev_bus;
1993        bool ret = false;
1994
1995        mdev_bus = symbol_get(mdev_bus_type);
1996        if (mdev_bus) {
1997                ret = (bus == mdev_bus);
1998                symbol_put(mdev_bus_type);
1999        }
2000
2001        return ret;
2002}
2003
2004static int vfio_mdev_iommu_device(struct device *dev, void *data)
2005{
2006        struct mdev_device *mdev = to_mdev_device(dev);
2007        struct device **old = data, *new;
2008
2009        new = mdev_get_iommu_device(mdev);
2010        if (!new || (*old && *old != new))
2011                return -EINVAL;
2012
2013        *old = new;
2014
2015        return 0;
2016}
2017
2018/*
2019 * This is a helper function to insert an address range into the iova list.
2020 * The list is initially created with a single entry corresponding to
2021 * the IOMMU domain geometry to which the device group is attached.
2022 * The list aperture gets modified when a new domain is added to the
2023 * container if the new aperture doesn't conflict with the current one
2024 * or with any existing dma mappings. The list is also modified to
2025 * exclude any reserved regions associated with the device group.
2026 */
2027static int vfio_iommu_iova_insert(struct list_head *head,
2028                                  dma_addr_t start, dma_addr_t end)
2029{
2030        struct vfio_iova *region;
2031
2032        region = kmalloc(sizeof(*region), GFP_KERNEL);
2033        if (!region)
2034                return -ENOMEM;
2035
2036        INIT_LIST_HEAD(&region->list);
2037        region->start = start;
2038        region->end = end;
2039
2040        list_add_tail(&region->list, head);
2041        return 0;
2042}
2043
2044/*
2045 * Check whether the new iommu aperture conflicts with the existing
2046 * aperture or with any existing dma mappings.
2047 */
2048static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
2049                                     dma_addr_t start, dma_addr_t end)
2050{
2051        struct vfio_iova *first, *last;
2052        struct list_head *iova = &iommu->iova_list;
2053
2054        if (list_empty(iova))
2055                return false;
2056
2057        /* Disjoint sets, return conflict */
2058        first = list_first_entry(iova, struct vfio_iova, list);
2059        last = list_last_entry(iova, struct vfio_iova, list);
2060        if (start > last->end || end < first->start)
2061                return true;
2062
2063        /* Check for any existing dma mappings below the new start */
2064        if (start > first->start) {
2065                if (vfio_find_dma(iommu, first->start, start - first->start))
2066                        return true;
2067        }
2068
2069        /* Check for any existing dma mappings beyond the new end */
2070        if (end < last->end) {
2071                if (vfio_find_dma(iommu, end + 1, last->end - end))
2072                        return true;
2073        }
2074
2075        return false;
2076}
2077
2078/*
2079 * Resize iommu iova aperture window. This is called only if the new
2080 * aperture has no conflict with existing aperture and dma mappings.
2081 */
2082static int vfio_iommu_aper_resize(struct list_head *iova,
2083                                  dma_addr_t start, dma_addr_t end)
2084{
2085        struct vfio_iova *node, *next;
2086
2087        if (list_empty(iova))
2088                return vfio_iommu_iova_insert(iova, start, end);
2089
2090        /* Adjust iova list start */
2091        list_for_each_entry_safe(node, next, iova, list) {
2092                if (start < node->start)
2093                        break;
2094                if (start >= node->start && start < node->end) {
2095                        node->start = start;
2096                        break;
2097                }
2098                /* Delete nodes before new start */
2099                list_del(&node->list);
2100                kfree(node);
2101        }
2102
2103        /* Adjust iova list end */
2104        list_for_each_entry_safe(node, next, iova, list) {
2105                if (end > node->end)
2106                        continue;
2107                if (end > node->start && end <= node->end) {
2108                        node->end = end;
2109                        continue;
2110                }
2111                /* Delete nodes after new end */
2112                list_del(&node->list);
2113                kfree(node);
2114        }
2115
2116        return 0;
2117}
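
/*
 * Worked example with hypothetical values: given an existing list of
 * [0x0 - 0xffffffff] and a new aperture of [0x1000 - 0xfffffffff], the
 * first loop above trims the node's start up to 0x1000 and the second loop
 * leaves its end alone since the new end is larger.  Nodes lying entirely
 * below the new start or entirely above the new end are deleted instead.
 */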
2118
2119/*
2120 * Check whether reserved regions conflict with existing dma mappings
2121 */
2122static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
2123                                     struct list_head *resv_regions)
2124{
2125        struct iommu_resv_region *region;
2126
2127        /* Check for conflict with existing dma mappings */
2128        list_for_each_entry(region, resv_regions, list) {
2129                if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
2130                        continue;
2131
2132                if (vfio_find_dma(iommu, region->start, region->length))
2133                        return true;
2134        }
2135
2136        return false;
2137}
2138
2139/*
2140 * Check the iova regions for overlap with reserved regions and
2141 * exclude the reserved regions from the iommu iova ranges
2142 */
2143static int vfio_iommu_resv_exclude(struct list_head *iova,
2144                                   struct list_head *resv_regions)
2145{
2146        struct iommu_resv_region *resv;
2147        struct vfio_iova *n, *next;
2148
2149        list_for_each_entry(resv, resv_regions, list) {
2150                phys_addr_t start, end;
2151
2152                if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
2153                        continue;
2154
2155                start = resv->start;
2156                end = resv->start + resv->length - 1;
2157
2158                list_for_each_entry_safe(n, next, iova, list) {
2159                        int ret = 0;
2160
2161                        /* No overlap */
2162                        if (start > n->end || end < n->start)
2163                                continue;
2164                        /*
2165                         * If the current node overlaps with the reserved
2166                         * region, insert new node(s) to exclude it from the
2167                         * valid iova range.  Note that the new nodes are
2168                         * inserted before the current node and the current
2169                         * node is then deleted, keeping the list sorted.
2170                         */
2171                        if (start > n->start)
2172                                ret = vfio_iommu_iova_insert(&n->list, n->start,
2173                                                             start - 1);
2174                        if (!ret && end < n->end)
2175                                ret = vfio_iommu_iova_insert(&n->list, end + 1,
2176                                                             n->end);
2177                        if (ret)
2178                                return ret;
2179
2180                        list_del(&n->list);
2181                        kfree(n);
2182                }
2183        }
2184
2185        if (list_empty(iova))
2186                return -EINVAL;
2187
2188        return 0;
2189}
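
/*
 * Worked example with hypothetical values: excluding a reserved region
 * [0xfee00000 - 0xfeefffff] from a single valid node [0x0 - 0xffffffff]
 * inserts [0x0 - 0xfedfffff] and [0xfef00000 - 0xffffffff] before the
 * original node and then deletes it, leaving two valid iova ranges.
 */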
2190
2191static void vfio_iommu_resv_free(struct list_head *resv_regions)
2192{
2193        struct iommu_resv_region *n, *next;
2194
2195        list_for_each_entry_safe(n, next, resv_regions, list) {
2196                list_del(&n->list);
2197                kfree(n);
2198        }
2199}
2200
2201static void vfio_iommu_iova_free(struct list_head *iova)
2202{
2203        struct vfio_iova *n, *next;
2204
2205        list_for_each_entry_safe(n, next, iova, list) {
2206                list_del(&n->list);
2207                kfree(n);
2208        }
2209}
2210
2211static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
2212                                    struct list_head *iova_copy)
2213{
2214        struct list_head *iova = &iommu->iova_list;
2215        struct vfio_iova *n;
2216        int ret;
2217
2218        list_for_each_entry(n, iova, list) {
2219                ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2220                if (ret)
2221                        goto out_free;
2222        }
2223
2224        return 0;
2225
2226out_free:
2227        vfio_iommu_iova_free(iova_copy);
2228        return ret;
2229}
2230
2231static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2232                                        struct list_head *iova_copy)
2233{
2234        struct list_head *iova = &iommu->iova_list;
2235
2236        vfio_iommu_iova_free(iova);
2237
2238        list_splice_tail(iova_copy, iova);
2239}
2240
2241static int vfio_iommu_type1_attach_group(void *iommu_data,
2242                                         struct iommu_group *iommu_group)
2243{
2244        struct vfio_iommu *iommu = iommu_data;
2245        struct vfio_group *group;
2246        struct vfio_domain *domain, *d;
2247        struct bus_type *bus = NULL;
2248        int ret;
2249        bool resv_msi, msi_remap;
2250        phys_addr_t resv_msi_base = 0;
2251        struct iommu_domain_geometry *geo;
2252        LIST_HEAD(iova_copy);
2253        LIST_HEAD(group_resv_regions);
2254
2255        mutex_lock(&iommu->lock);
2256
2257        /* Check for duplicates */
2258        if (vfio_iommu_find_iommu_group(iommu, iommu_group)) {
2259                mutex_unlock(&iommu->lock);
2260                return -EINVAL;
2261        }
2262
2263        group = kzalloc(sizeof(*group), GFP_KERNEL);
2264        domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2265        if (!group || !domain) {
2266                ret = -ENOMEM;
2267                goto out_free;
2268        }
2269
2270        group->iommu_group = iommu_group;
2271
2272        /* Determine bus_type in order to allocate a domain */
2273        ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
2274        if (ret)
2275                goto out_free;
2276
2277        if (vfio_bus_is_mdev(bus)) {
2278                struct device *iommu_device = NULL;
2279
2280                group->mdev_group = true;
2281
2282                /* Determine the isolation type */
2283                ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
2284                                               vfio_mdev_iommu_device);
2285                if (ret || !iommu_device) {
2286                        if (!iommu->external_domain) {
2287                                INIT_LIST_HEAD(&domain->group_list);
2288                                iommu->external_domain = domain;
2289                                vfio_update_pgsize_bitmap(iommu);
2290                        } else {
2291                                kfree(domain);
2292                        }
2293
2294                        list_add(&group->next,
2295                                 &iommu->external_domain->group_list);
2296                        /*
2297                         * A non-iommu backed group cannot dirty memory
2298                         * directly; it can only use interfaces that provide
2299                         * dirty tracking.
2300                         * The iommu scope can only be promoted with the
2301                         * addition of a dirty tracking group.
2302                         */
2303                        group->pinned_page_dirty_scope = true;
2304                        mutex_unlock(&iommu->lock);
2305
2306                        return 0;
2307                }
2308
2309                bus = iommu_device->bus;
2310        }
2311
2312        domain->domain = iommu_domain_alloc(bus);
2313        if (!domain->domain) {
2314                ret = -EIO;
2315                goto out_free;
2316        }
2317
2318        if (iommu->nesting) {
2319                ret = iommu_enable_nesting(domain->domain);
2320                if (ret)
2321                        goto out_domain;
2322        }
2323
2324        ret = vfio_iommu_attach_group(domain, group);
2325        if (ret)
2326                goto out_domain;
2327
2328        /* Get aperture info */
2329        geo = &domain->domain->geometry;
2330        if (vfio_iommu_aper_conflict(iommu, geo->aperture_start,
2331                                     geo->aperture_end)) {
2332                ret = -EINVAL;
2333                goto out_detach;
2334        }
2335
2336        ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2337        if (ret)
2338                goto out_detach;
2339
2340        if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2341                ret = -EINVAL;
2342                goto out_detach;
2343        }
2344
2345        /*
2346         * We don't want to work on the original iova list as the list
2347         * gets modified and in case of failure we have to retain the
2348         * original list. Get a copy here.
2349         */
2350        ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2351        if (ret)
2352                goto out_detach;
2353
2354        ret = vfio_iommu_aper_resize(&iova_copy, geo->aperture_start,
2355                                     geo->aperture_end);
2356        if (ret)
2357                goto out_detach;
2358
2359        ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2360        if (ret)
2361                goto out_detach;
2362
2363        resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2364
2365        INIT_LIST_HEAD(&domain->group_list);
2366        list_add(&group->next, &domain->group_list);
2367
2368        msi_remap = irq_domain_check_msi_remap() ||
2369                    iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2370
2371        if (!allow_unsafe_interrupts && !msi_remap) {
2372                pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2373                       __func__);
2374                ret = -EPERM;
2375                goto out_detach;
2376        }
2377
2378        if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2379                domain->prot |= IOMMU_CACHE;
2380
2381        /*
2382         * Try to match an existing compatible domain.  We don't want to
2383         * preclude an IOMMU driver supporting multiple bus_types and being
2384         * able to include different bus_types in the same IOMMU domain, so
2385         * we test whether the domains use the same iommu_ops rather than
2386         * testing if they're on the same bus_type.
2387         */
2388        list_for_each_entry(d, &iommu->domain_list, next) {
2389                if (d->domain->ops == domain->domain->ops &&
2390                    d->prot == domain->prot) {
2391                        vfio_iommu_detach_group(domain, group);
2392                        if (!vfio_iommu_attach_group(d, group)) {
2393                                list_add(&group->next, &d->group_list);
2394                                iommu_domain_free(domain->domain);
2395                                kfree(domain);
2396                                goto done;
2397                        }
2398
2399                        ret = vfio_iommu_attach_group(domain, group);
2400                        if (ret)
2401                                goto out_domain;
2402                }
2403        }
2404
2405        vfio_test_domain_fgsp(domain);
2406
2407        /* replay mappings on new domains */
2408        ret = vfio_iommu_replay(iommu, domain);
2409        if (ret)
2410                goto out_detach;
2411
2412        if (resv_msi) {
2413                ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2414                if (ret && ret != -ENODEV)
2415                        goto out_detach;
2416        }
2417
2418        list_add(&domain->next, &iommu->domain_list);
2419        vfio_update_pgsize_bitmap(iommu);
2420done:
2421        /* Delete the old iova list and insert the new one */
2422        vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2423
2424        /*
2425         * An iommu backed group can dirty memory directly and therefore
2426         * demotes the iommu scope until it declares itself dirty tracking
2427         * capable via the page pinning interface.
2428         */
2429        iommu->num_non_pinned_groups++;
2430        mutex_unlock(&iommu->lock);
2431        vfio_iommu_resv_free(&group_resv_regions);
2432
2433        return 0;
2434
2435out_detach:
2436        vfio_iommu_detach_group(domain, group);
2437out_domain:
2438        iommu_domain_free(domain->domain);
2439        vfio_iommu_iova_free(&iova_copy);
2440        vfio_iommu_resv_free(&group_resv_regions);
2441out_free:
2442        kfree(domain);
2443        kfree(group);
2444        mutex_unlock(&iommu->lock);
2445        return ret;
2446}
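
/*
 * Illustrative userspace flow that reaches the attach path above (a minimal
 * sketch; the group number and the container/group fd names are placeholders
 * and error handling is omitted).  The attach callback runs at VFIO_SET_IOMMU
 * time for groups already in the container, or when a group joins a container
 * whose IOMMU is already set:
 *
 *        int container = open("/dev/vfio/vfio", O_RDWR);
 *        int group = open("/dev/vfio/26", O_RDWR);
 *
 *        ioctl(group, VFIO_GROUP_SET_CONTAINER, &container);
 *        ioctl(container, VFIO_SET_IOMMU, VFIO_TYPE1v2_IOMMU);
 */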
2447
2448static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2449{
2450        struct rb_node *node;
2451
2452        while ((node = rb_first(&iommu->dma_list)))
2453                vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2454}
2455
2456static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2457{
2458        struct rb_node *n, *p;
2459
2460        n = rb_first(&iommu->dma_list);
2461        for (; n; n = rb_next(n)) {
2462                struct vfio_dma *dma;
2463                long locked = 0, unlocked = 0;
2464
2465                dma = rb_entry(n, struct vfio_dma, node);
2466                unlocked += vfio_unmap_unpin(iommu, dma, false);
2467                p = rb_first(&dma->pfn_list);
2468                for (; p; p = rb_next(p)) {
2469                        struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2470                                                         node);
2471
2472                        if (!is_invalid_reserved_pfn(vpfn->pfn))
2473                                locked++;
2474                }
2475                vfio_lock_acct(dma, locked - unlocked, true);
2476        }
2477}
2478
2479/*
2480 * Called when a domain is removed in detach. It is possible that
2481 * the removed domain determined the iova aperture window. Recompute the
2482 * aperture as the intersection of the remaining domains' apertures.
2483 */
2484static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2485                                   struct list_head *iova_copy)
2486{
2487        struct vfio_domain *domain;
2488        struct vfio_iova *node;
2489        dma_addr_t start = 0;
2490        dma_addr_t end = (dma_addr_t)~0;
2491
2492        if (list_empty(iova_copy))
2493                return;
2494
2495        list_for_each_entry(domain, &iommu->domain_list, next) {
2496                struct iommu_domain_geometry *geo = &domain->domain->geometry;
2497
2498                if (geo->aperture_start > start)
2499                        start = geo->aperture_start;
2500                if (geo->aperture_end < end)
2501                        end = geo->aperture_end;
2502        }
2503
2504        /* Modify aperture limits. The new aperture is either the same or bigger */
2505        node = list_first_entry(iova_copy, struct vfio_iova, list);
2506        node->start = start;
2507        node = list_last_entry(iova_copy, struct vfio_iova, list);
2508        node->end = end;
2509}
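
/*
 * Worked example with hypothetical values: if the detached domain had limited
 * the aperture to [0x0 - 0x7fffffff] while the remaining domains advertise
 * [0x0 - 0xffffffff] and [0x1000 - 0xffffffff], the loop above settles on
 * [0x1000 - 0xffffffff], the largest window still valid for every remaining
 * domain.
 */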
2510
2511/*
2512 * Called when a group is detached. The reserved regions for that
2513 * group can be part of the valid iova ranges now. But since reserved regions
2514 * may be duplicated among groups, populate the valid iova regions
2515 * list again.
2516 */
2517static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2518                                   struct list_head *iova_copy)
2519{
2520        struct vfio_domain *d;
2521        struct vfio_group *g;
2522        struct vfio_iova *node;
2523        dma_addr_t start, end;
2524        LIST_HEAD(resv_regions);
2525        int ret;
2526
2527        if (list_empty(iova_copy))
2528                return -EINVAL;
2529
2530        list_for_each_entry(d, &iommu->domain_list, next) {
2531                list_for_each_entry(g, &d->group_list, next) {
2532                        ret = iommu_get_group_resv_regions(g->iommu_group,
2533                                                           &resv_regions);
2534                        if (ret)
2535                                goto done;
2536                }
2537        }
2538
2539        node = list_first_entry(iova_copy, struct vfio_iova, list);
2540        start = node->start;
2541        node = list_last_entry(iova_copy, struct vfio_iova, list);
2542        end = node->end;
2543
2544        /* Purge the iova list and create a new one */
2545        vfio_iommu_iova_free(iova_copy);
2546
2547        ret = vfio_iommu_aper_resize(iova_copy, start, end);
2548        if (ret)
2549                goto done;
2550
2551        /* Exclude current reserved regions from iova ranges */
2552        ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2553done:
2554        vfio_iommu_resv_free(&resv_regions);
2555        return ret;
2556}
2557
2558static void vfio_iommu_type1_detach_group(void *iommu_data,
2559                                          struct iommu_group *iommu_group)
2560{
2561        struct vfio_iommu *iommu = iommu_data;
2562        struct vfio_domain *domain;
2563        struct vfio_group *group;
2564        bool update_dirty_scope = false;
2565        LIST_HEAD(iova_copy);
2566
2567        mutex_lock(&iommu->lock);
2568
2569        if (iommu->external_domain) {
2570                group = find_iommu_group(iommu->external_domain, iommu_group);
2571                if (group) {
2572                        update_dirty_scope = !group->pinned_page_dirty_scope;
2573                        list_del(&group->next);
2574                        kfree(group);
2575
2576                        if (list_empty(&iommu->external_domain->group_list)) {
2577                                if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu)) {
2578                                        WARN_ON(iommu->notifier.head);
2579                                        vfio_iommu_unmap_unpin_all(iommu);
2580                                }
2581
2582                                kfree(iommu->external_domain);
2583                                iommu->external_domain = NULL;
2584                        }
2585                        goto detach_group_done;
2586                }
2587        }
2588
2589        /*
2590         * Get a copy of the iova list. This will be used to update
2591         * and then replace the current one later. Note that we will
2592         * leave the original list as it is if the update fails.
2593         */
2594        vfio_iommu_iova_get_copy(iommu, &iova_copy);
2595
2596        list_for_each_entry(domain, &iommu->domain_list, next) {
2597                group = find_iommu_group(domain, iommu_group);
2598                if (!group)
2599                        continue;
2600
2601                vfio_iommu_detach_group(domain, group);
2602                update_dirty_scope = !group->pinned_page_dirty_scope;
2603                list_del(&group->next);
2604                kfree(group);
2605                /*
2606                 * Group ownership provides privilege; if the group list is
2607                 * empty, the domain goes away. If it's the last domain with
2608                 * iommu and no external domain exists, then all the mappings
2609                 * go away too. If it's the last domain with iommu and an
2610                 * external domain exists, update the accounting.
2611                 */
2612                if (list_empty(&domain->group_list)) {
2613                        if (list_is_singular(&iommu->domain_list)) {
2614                                if (!iommu->external_domain) {
2615                                        WARN_ON(iommu->notifier.head);
2616                                        vfio_iommu_unmap_unpin_all(iommu);
2617                                } else {
2618                                        vfio_iommu_unmap_unpin_reaccount(iommu);
2619                                }
2620                        }
2621                        iommu_domain_free(domain->domain);
2622                        list_del(&domain->next);
2623                        kfree(domain);
2624                        vfio_iommu_aper_expand(iommu, &iova_copy);
2625                        vfio_update_pgsize_bitmap(iommu);
2626                }
2627                break;
2628        }
2629
2630        if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2631                vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2632        else
2633                vfio_iommu_iova_free(&iova_copy);
2634
2635detach_group_done:
2636        /*
2637         * Removal of a group without dirty tracking may allow the iommu scope
2638         * to be promoted.
2639         */
2640        if (update_dirty_scope) {
2641                iommu->num_non_pinned_groups--;
2642                if (iommu->dirty_page_tracking)
2643                        vfio_iommu_populate_bitmap_full(iommu);
2644        }
2645        mutex_unlock(&iommu->lock);
2646}
2647
2648static void *vfio_iommu_type1_open(unsigned long arg)
2649{
2650        struct vfio_iommu *iommu;
2651
2652        iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2653        if (!iommu)
2654                return ERR_PTR(-ENOMEM);
2655
2656        switch (arg) {
2657        case VFIO_TYPE1_IOMMU:
2658                break;
2659        case VFIO_TYPE1_NESTING_IOMMU:
2660                iommu->nesting = true;
2661                fallthrough;
2662        case VFIO_TYPE1v2_IOMMU:
2663                iommu->v2 = true;
2664                break;
2665        default:
2666                kfree(iommu);
2667                return ERR_PTR(-EINVAL);
2668        }
2669
2670        INIT_LIST_HEAD(&iommu->domain_list);
2671        INIT_LIST_HEAD(&iommu->iova_list);
2672        iommu->dma_list = RB_ROOT;
2673        iommu->dma_avail = dma_entry_limit;
2674        iommu->container_open = true;
2675        mutex_init(&iommu->lock);
2676        BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2677        init_waitqueue_head(&iommu->vaddr_wait);
2678
2679        return iommu;
2680}
2681
2682static void vfio_release_domain(struct vfio_domain *domain, bool external)
2683{
2684        struct vfio_group *group, *group_tmp;
2685
2686        list_for_each_entry_safe(group, group_tmp,
2687                                 &domain->group_list, next) {
2688                if (!external)
2689                        vfio_iommu_detach_group(domain, group);
2690                list_del(&group->next);
2691                kfree(group);
2692        }
2693
2694        if (!external)
2695                iommu_domain_free(domain->domain);
2696}
2697
2698static void vfio_iommu_type1_release(void *iommu_data)
2699{
2700        struct vfio_iommu *iommu = iommu_data;
2701        struct vfio_domain *domain, *domain_tmp;
2702
2703        if (iommu->external_domain) {
2704                vfio_release_domain(iommu->external_domain, true);
2705                kfree(iommu->external_domain);
2706        }
2707
2708        vfio_iommu_unmap_unpin_all(iommu);
2709
2710        list_for_each_entry_safe(domain, domain_tmp,
2711                                 &iommu->domain_list, next) {
2712                vfio_release_domain(domain, false);
2713                list_del(&domain->next);
2714                kfree(domain);
2715        }
2716
2717        vfio_iommu_iova_free(&iommu->iova_list);
2718
2719        kfree(iommu);
2720}
2721
2722static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2723{
2724        struct vfio_domain *domain;
2725        int ret = 1;
2726
2727        mutex_lock(&iommu->lock);
2728        list_for_each_entry(domain, &iommu->domain_list, next) {
2729                if (!(domain->prot & IOMMU_CACHE)) {
2730                        ret = 0;
2731                        break;
2732                }
2733        }
2734        mutex_unlock(&iommu->lock);
2735
2736        return ret;
2737}
2738
2739static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2740                                            unsigned long arg)
2741{
2742        switch (arg) {
2743        case VFIO_TYPE1_IOMMU:
2744        case VFIO_TYPE1v2_IOMMU:
2745        case VFIO_TYPE1_NESTING_IOMMU:
2746        case VFIO_UNMAP_ALL:
2747        case VFIO_UPDATE_VADDR:
2748                return 1;
2749        case VFIO_DMA_CC_IOMMU:
2750                if (!iommu)
2751                        return 0;
2752                return vfio_domains_have_iommu_cache(iommu);
2753        default:
2754                return 0;
2755        }
2756}
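
/*
 * Illustrative userspace probe (a minimal sketch; container_fd is a
 * placeholder and iommu_type, have_unmap_all and have_update_vaddr are the
 * caller's own variables).  The extensions handled above are queried on the
 * container fd, e.g. to pick an IOMMU type or to discover the unmap-all and
 * vaddr-update features:
 *
 *        if (ioctl(container_fd, VFIO_CHECK_EXTENSION, VFIO_TYPE1v2_IOMMU) > 0)
 *                iommu_type = VFIO_TYPE1v2_IOMMU;
 *        if (ioctl(container_fd, VFIO_CHECK_EXTENSION, VFIO_UNMAP_ALL) > 0)
 *                have_unmap_all = true;
 *        if (ioctl(container_fd, VFIO_CHECK_EXTENSION, VFIO_UPDATE_VADDR) > 0)
 *                have_update_vaddr = true;
 */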
2757
2758static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2759                 struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2760                 size_t size)
2761{
2762        struct vfio_info_cap_header *header;
2763        struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2764
2765        header = vfio_info_cap_add(caps, size,
2766                                   VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2767        if (IS_ERR(header))
2768                return PTR_ERR(header);
2769
2770        iova_cap = container_of(header,
2771                                struct vfio_iommu_type1_info_cap_iova_range,
2772                                header);
2773        iova_cap->nr_iovas = cap_iovas->nr_iovas;
2774        memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2775               cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2776        return 0;
2777}
2778
2779static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2780                                      struct vfio_info_cap *caps)
2781{
2782        struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2783        struct vfio_iova *iova;
2784        size_t size;
2785        int iovas = 0, i = 0, ret;
2786
2787        list_for_each_entry(iova, &iommu->iova_list, list)
2788                iovas++;
2789
2790        if (!iovas) {
2791                /*
2792                 * Return 0 as a container with a single mdev device
2793                 * will have an empty list
2794                 */
2795                return 0;
2796        }
2797
2798        size = struct_size(cap_iovas, iova_ranges, iovas);
2799
2800        cap_iovas = kzalloc(size, GFP_KERNEL);
2801        if (!cap_iovas)
2802                return -ENOMEM;
2803
2804        cap_iovas->nr_iovas = iovas;
2805
2806        list_for_each_entry(iova, &iommu->iova_list, list) {
2807                cap_iovas->iova_ranges[i].start = iova->start;
2808                cap_iovas->iova_ranges[i].end = iova->end;
2809                i++;
2810        }
2811
2812        ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2813
2814        kfree(cap_iovas);
2815        return ret;
2816}
2817
2818static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2819                                           struct vfio_info_cap *caps)
2820{
2821        struct vfio_iommu_type1_info_cap_migration cap_mig;
2822
2823        cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2824        cap_mig.header.version = 1;
2825
2826        cap_mig.flags = 0;
2827        /* support minimum pgsize */
2828        cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2829        cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2830
2831        return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2832}
2833
2834static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2835                                           struct vfio_info_cap *caps)
2836{
2837        struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2838
2839        cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2840        cap_dma_avail.header.version = 1;
2841
2842        cap_dma_avail.avail = iommu->dma_avail;
2843
2844        return vfio_info_add_capability(caps, &cap_dma_avail.header,
2845                                        sizeof(cap_dma_avail));
2846}
2847
2848static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2849                                     unsigned long arg)
2850{
2851        struct vfio_iommu_type1_info info;
2852        unsigned long minsz;
2853        struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2854        unsigned long capsz;
2855        int ret;
2856
2857        minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2858
2859        /* For backward compatibility, cannot require this */
2860        capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2861
2862        if (copy_from_user(&info, (void __user *)arg, minsz))
2863                return -EFAULT;
2864
2865        if (info.argsz < minsz)
2866                return -EINVAL;
2867
2868        if (info.argsz >= capsz) {
2869                minsz = capsz;
2870                info.cap_offset = 0; /* output, no-recopy necessary */
2871        }
2872
2873        mutex_lock(&iommu->lock);
2874        info.flags = VFIO_IOMMU_INFO_PGSIZES;
2875
2876        info.iova_pgsizes = iommu->pgsize_bitmap;
2877
2878        ret = vfio_iommu_migration_build_caps(iommu, &caps);
2879
2880        if (!ret)
2881                ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2882
2883        if (!ret)
2884                ret = vfio_iommu_iova_build_caps(iommu, &caps);
2885
2886        mutex_unlock(&iommu->lock);
2887
2888        if (ret)
2889                return ret;
2890
2891        if (caps.size) {
2892                info.flags |= VFIO_IOMMU_INFO_CAPS;
2893
2894                if (info.argsz < sizeof(info) + caps.size) {
2895                        info.argsz = sizeof(info) + caps.size;
2896                } else {
2897                        vfio_info_cap_shift(&caps, sizeof(info));
2898                        if (copy_to_user((void __user *)arg +
2899                                        sizeof(info), caps.buf,
2900                                        caps.size)) {
2901                                kfree(caps.buf);
2902                                return -EFAULT;
2903                        }
2904                        info.cap_offset = sizeof(info);
2905                }
2906
2907                kfree(caps.buf);
2908        }
2909
2910        return copy_to_user((void __user *)arg, &info, minsz) ?
2911                        -EFAULT : 0;
2912}
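
/*
 * Illustrative userspace sketch of consuming the info and capability chain
 * built above (a minimal sketch; container_fd is a placeholder,
 * use_iova_ranges() is a hypothetical caller function and error handling is
 * omitted).  The caller provides an argsz large enough for the capabilities
 * and then walks the chain via cap_offset and each header's next offset:
 *
 *        struct vfio_iommu_type1_info *info = calloc(1, 4096);
 *        struct vfio_info_cap_header *hdr;
 *
 *        info->argsz = 4096;
 *        ioctl(container_fd, VFIO_IOMMU_GET_INFO, info);
 *
 *        if ((info->flags & VFIO_IOMMU_INFO_CAPS) && info->cap_offset) {
 *                hdr = (void *)info + info->cap_offset;
 *                for (;;) {
 *                        if (hdr->id == VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE) {
 *                                struct vfio_iommu_type1_info_cap_iova_range *r =
 *                                                (void *)hdr;
 *                                use_iova_ranges(r->nr_iovas, r->iova_ranges);
 *                        }
 *                        if (!hdr->next)
 *                                break;
 *                        hdr = (void *)info + hdr->next;
 *                }
 *        }
 */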
2913
2914static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2915                                    unsigned long arg)
2916{
2917        struct vfio_iommu_type1_dma_map map;
2918        unsigned long minsz;
2919        uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
2920                        VFIO_DMA_MAP_FLAG_VADDR;
2921
2922        minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2923
2924        if (copy_from_user(&map, (void __user *)arg, minsz))
2925                return -EFAULT;
2926
2927        if (map.argsz < minsz || map.flags & ~mask)
2928                return -EINVAL;
2929
2930        return vfio_dma_do_map(iommu, &map);
2931}
2932
2933static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2934                                      unsigned long arg)
2935{
2936        struct vfio_iommu_type1_dma_unmap unmap;
2937        struct vfio_bitmap bitmap = { 0 };
2938        uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
2939                        VFIO_DMA_UNMAP_FLAG_VADDR |
2940                        VFIO_DMA_UNMAP_FLAG_ALL;
2941        unsigned long minsz;
2942        int ret;
2943
2944        minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2945
2946        if (copy_from_user(&unmap, (void __user *)arg, minsz))
2947                return -EFAULT;
2948
2949        if (unmap.argsz < minsz || unmap.flags & ~mask)
2950                return -EINVAL;
2951
2952        if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
2953            (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
2954                            VFIO_DMA_UNMAP_FLAG_VADDR)))
2955                return -EINVAL;
2956
2957        if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2958                unsigned long pgshift;
2959
2960                if (unmap.argsz < (minsz + sizeof(bitmap)))
2961                        return -EINVAL;
2962
2963                if (copy_from_user(&bitmap,
2964                                   (void __user *)(arg + minsz),
2965                                   sizeof(bitmap)))
2966                        return -EFAULT;
2967
2968                if (!access_ok((void __user *)bitmap.data, bitmap.size))
2969                        return -EINVAL;
2970
2971                pgshift = __ffs(bitmap.pgsize);
2972                ret = verify_bitmap_size(unmap.size >> pgshift,
2973                                         bitmap.size);
2974                if (ret)
2975                        return ret;
2976        }
2977
2978        ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2979        if (ret)
2980                return ret;
2981
2982        return copy_to_user((void __user *)arg, &unmap, minsz) ?
2983                        -EFAULT : 0;
2984}
2985
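/*
 * VFIO_IOMMU_DIRTY_PAGES handler; exactly one of the START, STOP or
 * GET_BITMAP flags may be set.  START allocates per-mapping bitmaps and
 * enables dirty page tracking, STOP frees them and disables it, and
 * GET_BITMAP copies the dirty state of an iova range to userspace.  The
 * range and the reported bitmap granularity must match the smallest
 * IOMMU page size supported by the container.
 */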
2986static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
2987                                        unsigned long arg)
2988{
2989        struct vfio_iommu_type1_dirty_bitmap dirty;
2990        uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
2991                        VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
2992                        VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
2993        unsigned long minsz;
2994        int ret = 0;
2995
2996        if (!iommu->v2)
2997                return -EACCES;
2998
2999        minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
3000
3001        if (copy_from_user(&dirty, (void __user *)arg, minsz))
3002                return -EFAULT;
3003
3004        if (dirty.argsz < minsz || dirty.flags & ~mask)
3005                return -EINVAL;
3006
3007        /* only one flag should be set at a time */
3008        if (__ffs(dirty.flags) != __fls(dirty.flags))
3009                return -EINVAL;
3010
3011        if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
3012                size_t pgsize;
3013
3014                mutex_lock(&iommu->lock);
3015                pgsize = 1 << __ffs(iommu->pgsize_bitmap);
3016                if (!iommu->dirty_page_tracking) {
3017                        ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
3018                        if (!ret)
3019                                iommu->dirty_page_tracking = true;
3020                }
3021                mutex_unlock(&iommu->lock);
3022                return ret;
3023        } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
3024                mutex_lock(&iommu->lock);
3025                if (iommu->dirty_page_tracking) {
3026                        iommu->dirty_page_tracking = false;
3027                        vfio_dma_bitmap_free_all(iommu);
3028                }
3029                mutex_unlock(&iommu->lock);
3030                return 0;
3031        } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
3032                struct vfio_iommu_type1_dirty_bitmap_get range;
3033                unsigned long pgshift;
3034                size_t data_size = dirty.argsz - minsz;
3035                size_t iommu_pgsize;
3036
3037                if (!data_size || data_size < sizeof(range))
3038                        return -EINVAL;
3039
3040                if (copy_from_user(&range, (void __user *)(arg + minsz),
3041                                   sizeof(range)))
3042                        return -EFAULT;
3043
3044                if (range.iova + range.size < range.iova)
3045                        return -EINVAL;
3046                if (!access_ok((void __user *)range.bitmap.data,
3047                               range.bitmap.size))
3048                        return -EINVAL;
3049
3050                pgshift = __ffs(range.bitmap.pgsize);
3051                ret = verify_bitmap_size(range.size >> pgshift,
3052                                         range.bitmap.size);
3053                if (ret)
3054                        return ret;
3055
3056                mutex_lock(&iommu->lock);
3057
3058                iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
3059
3060                /* only the smallest supported pgsize is allowed */
3061                if (range.bitmap.pgsize != iommu_pgsize) {
3062                        ret = -EINVAL;
3063                        goto out_unlock;
3064                }
3065                if (range.iova & (iommu_pgsize - 1)) {
3066                        ret = -EINVAL;
3067                        goto out_unlock;
3068                }
3069                if (!range.size || range.size & (iommu_pgsize - 1)) {
3070                        ret = -EINVAL;
3071                        goto out_unlock;
3072                }
3073
3074                if (iommu->dirty_page_tracking)
3075                        ret = vfio_iova_dirty_bitmap(range.bitmap.data,
3076                                                     iommu, range.iova,
3077                                                     range.size,
3078                                                     range.bitmap.pgsize);
3079                else
3080                        ret = -EINVAL;
3081out_unlock:
3082                mutex_unlock(&iommu->lock);
3083
3084                return ret;
3085        }
3086
3087        return -EINVAL;
3088}
3089
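/* Dispatch the type1 container ioctls to the handlers above. */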
3090static long vfio_iommu_type1_ioctl(void *iommu_data,
3091                                   unsigned int cmd, unsigned long arg)
3092{
3093        struct vfio_iommu *iommu = iommu_data;
3094
3095        switch (cmd) {
3096        case VFIO_CHECK_EXTENSION:
3097                return vfio_iommu_type1_check_extension(iommu, arg);
3098        case VFIO_IOMMU_GET_INFO:
3099                return vfio_iommu_type1_get_info(iommu, arg);
3100        case VFIO_IOMMU_MAP_DMA:
3101                return vfio_iommu_type1_map_dma(iommu, arg);
3102        case VFIO_IOMMU_UNMAP_DMA:
3103                return vfio_iommu_type1_unmap_dma(iommu, arg);
3104        case VFIO_IOMMU_DIRTY_PAGES:
3105                return vfio_iommu_type1_dirty_pages(iommu, arg);
3106        default:
3107                return -ENOTTY;
3108        }
3109}
3110
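/*
 * Only VFIO_IOMMU_NOTIFY_DMA_UNMAP is supported; registration is refused
 * if the caller asks for any other event.
 */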
3111static int vfio_iommu_type1_register_notifier(void *iommu_data,
3112                                              unsigned long *events,
3113                                              struct notifier_block *nb)
3114{
3115        struct vfio_iommu *iommu = iommu_data;
3116
3117        /* clear known events */
3118        *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
3119
3120        /* refuse to register if any unknown events remain */
3121        if (*events)
3122                return -EINVAL;
3123
3124        return blocking_notifier_chain_register(&iommu->notifier, nb);
3125}
3126
3127static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
3128                                                struct notifier_block *nb)
3129{
3130        struct vfio_iommu *iommu = iommu_data;
3131
3132        return blocking_notifier_chain_unregister(&iommu->notifier, nb);
3133}
3134
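/*
 * Copy up to @count bytes between @data and the user mapping backing the
 * single vfio_dma covering @user_iova, using the mm of the task that
 * created the mapping.  Writes mark the affected pages dirty when dirty
 * page tracking is enabled.  On success returns 0 with the number of
 * bytes transferred in @copied, otherwise a negative errno.
 */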
3135static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
3136                                         dma_addr_t user_iova, void *data,
3137                                         size_t count, bool write,
3138                                         size_t *copied)
3139{
3140        struct mm_struct *mm;
3141        unsigned long vaddr;
3142        struct vfio_dma *dma;
3143        bool kthread = current->mm == NULL;
3144        size_t offset;
3145        int ret;
3146
3147        *copied = 0;
3148
3149        ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
3150        if (ret < 0)
3151                return ret;
3152
3153        if ((write && !(dma->prot & IOMMU_WRITE)) ||
3154                        !(dma->prot & IOMMU_READ))
3155                return -EPERM;
3156
3157        mm = get_task_mm(dma->task);
3158
3159        if (!mm)
3160                return -EPERM;
3161
3162        if (kthread)
3163                kthread_use_mm(mm);
3164        else if (current->mm != mm)
3165                goto out;
3166
3167        offset = user_iova - dma->iova;
3168
3169        if (count > dma->size - offset)
3170                count = dma->size - offset;
3171
3172        vaddr = dma->vaddr + offset;
3173
3174        if (write) {
3175                *copied = copy_to_user((void __user *)vaddr, data,
3176                                         count) ? 0 : count;
3177                if (*copied && iommu->dirty_page_tracking) {
3178                        unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
3179                        /*
3180                         * Bitmap populated with the smallest supported page
3181                         * size
3182                         */
3183                        bitmap_set(dma->bitmap, offset >> pgshift,
3184                                   ((offset + *copied - 1) >> pgshift) -
3185                                   (offset >> pgshift) + 1);
3186                }
3187        } else
3188                *copied = copy_from_user(data, (void __user *)vaddr,
3189                                           count) ? 0 : count;
3190        if (kthread)
3191                kthread_unuse_mm(mm);
3192out:
3193        mmput(mm);
3194        return *copied ? 0 : -EFAULT;
3195}
3196
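/*
 * The .dma_rw callback backing vfio_dma_rw(): read from or write to the
 * memory mapped at @user_iova, splitting the transfer at vfio_dma
 * boundaries.
 */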
3197static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
3198                                   void *data, size_t count, bool write)
3199{
3200        struct vfio_iommu *iommu = iommu_data;
3201        int ret = 0;
3202        size_t done;
3203
3204        mutex_lock(&iommu->lock);
3205        while (count > 0) {
3206                ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3207                                                    count, write, &done);
3208                if (ret)
3209                        break;
3210
3211                count -= done;
3212                data += done;
3213                user_iova += done;
3214        }
3215
3216        mutex_unlock(&iommu->lock);
3217        return ret;
3218}
3219
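/*
 * Return the iommu_domain the given group is attached to, or an ERR_PTR
 * if the group is not part of this container.
 */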
3220static struct iommu_domain *
3221vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3222                                    struct iommu_group *iommu_group)
3223{
3224        struct iommu_domain *domain = ERR_PTR(-ENODEV);
3225        struct vfio_iommu *iommu = iommu_data;
3226        struct vfio_domain *d;
3227
3228        if (!iommu || !iommu_group)
3229                return ERR_PTR(-EINVAL);
3230
3231        mutex_lock(&iommu->lock);
3232        list_for_each_entry(d, &iommu->domain_list, next) {
3233                if (find_iommu_group(d, iommu_group)) {
3234                        domain = d->domain;
3235                        break;
3236                }
3237        }
3238        mutex_unlock(&iommu->lock);
3239
3240        return domain;
3241}
3242
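/*
 * Container close notification: mark the container closed and wake any
 * tasks waiting for invalidated vaddrs to be re-registered so they can
 * bail out instead of blocking forever.
 */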
3243static void vfio_iommu_type1_notify(void *iommu_data,
3244                                    enum vfio_iommu_notify_type event)
3245{
3246        struct vfio_iommu *iommu = iommu_data;
3247
3248        if (event != VFIO_IOMMU_CONTAINER_CLOSE)
3249                return;
3250        mutex_lock(&iommu->lock);
3251        iommu->container_open = false;
3252        mutex_unlock(&iommu->lock);
3253        wake_up_all(&iommu->vaddr_wait);
3254}
3255
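/* Callbacks registered with the VFIO core by vfio_register_iommu_driver(). */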
3256static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3257        .name                   = "vfio-iommu-type1",
3258        .owner                  = THIS_MODULE,
3259        .open                   = vfio_iommu_type1_open,
3260        .release                = vfio_iommu_type1_release,
3261        .ioctl                  = vfio_iommu_type1_ioctl,
3262        .attach_group           = vfio_iommu_type1_attach_group,
3263        .detach_group           = vfio_iommu_type1_detach_group,
3264        .pin_pages              = vfio_iommu_type1_pin_pages,
3265        .unpin_pages            = vfio_iommu_type1_unpin_pages,
3266        .register_notifier      = vfio_iommu_type1_register_notifier,
3267        .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
3268        .dma_rw                 = vfio_iommu_type1_dma_rw,
3269        .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3270        .notify                 = vfio_iommu_type1_notify,
3271};
3272
3273static int __init vfio_iommu_type1_init(void)
3274{
3275        return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3276}
3277
3278static void __exit vfio_iommu_type1_cleanup(void)
3279{
3280        vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3281}
3282
3283module_init(vfio_iommu_type1_init);
3284module_exit(vfio_iommu_type1_cleanup);
3285
3286MODULE_VERSION(DRIVER_VERSION);
3287MODULE_LICENSE("GPL v2");
3288MODULE_AUTHOR(DRIVER_AUTHOR);
3289MODULE_DESCRIPTION(DRIVER_DESC);
3290