linux/drivers/vfio/vfio_iommu_type1.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * VFIO: IOMMU DMA mapping support for Type1 IOMMU
   4 *
   5 * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
   6 *     Author: Alex Williamson <alex.williamson@redhat.com>
   7 *
   8 * Derived from original vfio:
   9 * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
  10 * Author: Tom Lyon, pugs@cisco.com
  11 *
  12 * We arbitrarily define a Type1 IOMMU as one matching the below code.
  13 * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
  14 * VT-d, but that makes it harder to re-use as theoretically anyone
  15 * implementing a similar IOMMU could make use of this.  We expect the
  16 * IOMMU to support the IOMMU API and have few to no restrictions around
  17 * the IOVA range that can be mapped.  The Type1 IOMMU is currently
  18 * optimized for relatively static mappings of a userspace process with
   19 * userspace pages pinned into memory.  We also assume devices and IOMMU
  20 * domains are PCI based as the IOMMU API is still centered around a
  21 * device/bus interface rather than a group interface.
  22 */
  23
  24#include <linux/compat.h>
  25#include <linux/device.h>
  26#include <linux/fs.h>
  27#include <linux/iommu.h>
  28#include <linux/module.h>
  29#include <linux/mm.h>
  30#include <linux/rbtree.h>
  31#include <linux/sched/signal.h>
  32#include <linux/sched/mm.h>
  33#include <linux/slab.h>
  34#include <linux/uaccess.h>
  35#include <linux/vfio.h>
  36#include <linux/workqueue.h>
  37#include <linux/mdev.h>
  38#include <linux/notifier.h>
  39#include <linux/dma-iommu.h>
  40#include <linux/irqdomain.h>
  41
  42#define DRIVER_VERSION  "0.2"
  43#define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
  44#define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
  45
  46static bool allow_unsafe_interrupts;
  47module_param_named(allow_unsafe_interrupts,
  48                   allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
  49MODULE_PARM_DESC(allow_unsafe_interrupts,
  50                 "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
  51
  52static bool disable_hugepages;
  53module_param_named(disable_hugepages,
  54                   disable_hugepages, bool, S_IRUGO | S_IWUSR);
  55MODULE_PARM_DESC(disable_hugepages,
  56                 "Disable VFIO IOMMU support for IOMMU hugepages.");
  57
  58static unsigned int dma_entry_limit __read_mostly = U16_MAX;
  59module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
  60MODULE_PARM_DESC(dma_entry_limit,
  61                 "Maximum number of user DMA mappings per container (65535).");
  62
  63struct vfio_iommu {
  64        struct list_head        domain_list;
  65        struct vfio_domain      *external_domain; /* domain for external user */
  66        struct mutex            lock;
  67        struct rb_root          dma_list;
  68        struct blocking_notifier_head notifier;
  69        unsigned int            dma_avail;
  70        bool                    v2;
  71        bool                    nesting;
  72};
  73
  74struct vfio_domain {
  75        struct iommu_domain     *domain;
  76        struct list_head        next;
  77        struct list_head        group_list;
  78        int                     prot;           /* IOMMU_CACHE */
  79        bool                    fgsp;           /* Fine-grained super pages */
  80};
  81
  82struct vfio_dma {
  83        struct rb_node          node;
  84        dma_addr_t              iova;           /* Device address */
  85        unsigned long           vaddr;          /* Process virtual addr */
  86        size_t                  size;           /* Map size (bytes) */
  87        int                     prot;           /* IOMMU_READ/WRITE */
  88        bool                    iommu_mapped;
  89        bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
  90        struct task_struct      *task;
  91        struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
  92};
  93
  94struct vfio_group {
  95        struct iommu_group      *iommu_group;
  96        struct list_head        next;
  97        bool                    mdev_group;     /* An mdev group */
  98};
  99
 100/*
 101 * Guest RAM pinning working set or DMA target
 102 */
 103struct vfio_pfn {
 104        struct rb_node          node;
 105        dma_addr_t              iova;           /* Device address */
 106        unsigned long           pfn;            /* Host pfn */
 107        atomic_t                ref_count;
 108};
 109
 110struct vfio_regions {
 111        struct list_head list;
 112        dma_addr_t iova;
 113        phys_addr_t phys;
 114        size_t len;
 115};
 116
 117#define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
 118                                        (!list_empty(&iommu->domain_list))
 119
 120static int put_pfn(unsigned long pfn, int prot);
 121
 122/*
 123 * This code handles mapping and unmapping of user data buffers
 124 * into DMA'ble space using the IOMMU
 125 */
 126
 127static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
 128                                      dma_addr_t start, size_t size)
 129{
 130        struct rb_node *node = iommu->dma_list.rb_node;
 131
 132        while (node) {
 133                struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
 134
 135                if (start + size <= dma->iova)
 136                        node = node->rb_left;
 137                else if (start >= dma->iova + dma->size)
 138                        node = node->rb_right;
 139                else
 140                        return dma;
 141        }
 142
 143        return NULL;
 144}
 145
 146static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
 147{
 148        struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
 149        struct vfio_dma *dma;
 150
 151        while (*link) {
 152                parent = *link;
 153                dma = rb_entry(parent, struct vfio_dma, node);
 154
 155                if (new->iova + new->size <= dma->iova)
 156                        link = &(*link)->rb_left;
 157                else
 158                        link = &(*link)->rb_right;
 159        }
 160
 161        rb_link_node(&new->node, parent, link);
 162        rb_insert_color(&new->node, &iommu->dma_list);
 163}
 164
 165static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
 166{
 167        rb_erase(&old->node, &iommu->dma_list);
 168}
 169
 170/*
 171 * Helper Functions for host iova-pfn list
 172 */
 173static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
 174{
 175        struct vfio_pfn *vpfn;
 176        struct rb_node *node = dma->pfn_list.rb_node;
 177
 178        while (node) {
 179                vpfn = rb_entry(node, struct vfio_pfn, node);
 180
 181                if (iova < vpfn->iova)
 182                        node = node->rb_left;
 183                else if (iova > vpfn->iova)
 184                        node = node->rb_right;
 185                else
 186                        return vpfn;
 187        }
 188        return NULL;
 189}
 190
 191static void vfio_link_pfn(struct vfio_dma *dma,
 192                          struct vfio_pfn *new)
 193{
 194        struct rb_node **link, *parent = NULL;
 195        struct vfio_pfn *vpfn;
 196
 197        link = &dma->pfn_list.rb_node;
 198        while (*link) {
 199                parent = *link;
 200                vpfn = rb_entry(parent, struct vfio_pfn, node);
 201
 202                if (new->iova < vpfn->iova)
 203                        link = &(*link)->rb_left;
 204                else
 205                        link = &(*link)->rb_right;
 206        }
 207
 208        rb_link_node(&new->node, parent, link);
 209        rb_insert_color(&new->node, &dma->pfn_list);
 210}
 211
 212static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
 213{
 214        rb_erase(&old->node, &dma->pfn_list);
 215}
 216
 217static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
 218                                unsigned long pfn)
 219{
 220        struct vfio_pfn *vpfn;
 221
 222        vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
 223        if (!vpfn)
 224                return -ENOMEM;
 225
 226        vpfn->iova = iova;
 227        vpfn->pfn = pfn;
 228        atomic_set(&vpfn->ref_count, 1);
 229        vfio_link_pfn(dma, vpfn);
 230        return 0;
 231}
 232
 233static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
 234                                      struct vfio_pfn *vpfn)
 235{
 236        vfio_unlink_pfn(dma, vpfn);
 237        kfree(vpfn);
 238}
 239
 240static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
 241                                               unsigned long iova)
 242{
 243        struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
 244
 245        if (vpfn)
 246                atomic_inc(&vpfn->ref_count);
 247        return vpfn;
 248}
 249
 250static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
 251{
 252        int ret = 0;
 253
 254        if (atomic_dec_and_test(&vpfn->ref_count)) {
 255                ret = put_pfn(vpfn->pfn, dma->prot);
 256                vfio_remove_from_pfn_list(dma, vpfn);
 257        }
 258        return ret;
 259}
 260
 261static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
 262{
 263        struct mm_struct *mm;
 264        int ret;
 265
 266        if (!npage)
 267                return 0;
 268
 269        mm = async ? get_task_mm(dma->task) : dma->task->mm;
 270        if (!mm)
 271                return -ESRCH; /* process exited */
 272
 273        ret = down_write_killable(&mm->mmap_sem);
 274        if (!ret) {
 275                ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
 276                                          dma->lock_cap);
 277                up_write(&mm->mmap_sem);
 278        }
 279
 280        if (async)
 281                mmput(mm);
 282
 283        return ret;
 284}
 285
 286/*
 287 * Some mappings aren't backed by a struct page, for example an mmap'd
 288 * MMIO range for our own or another device.  These use a different
 289 * pfn conversion and shouldn't be tracked as locked pages.
 290 */
 291static bool is_invalid_reserved_pfn(unsigned long pfn)
 292{
 293        if (pfn_valid(pfn)) {
 294                bool reserved;
 295                struct page *tail = pfn_to_page(pfn);
 296                struct page *head = compound_head(tail);
 297                reserved = !!(PageReserved(head));
 298                if (head != tail) {
 299                        /*
 300                         * "head" is not a dangling pointer
 301                         * (compound_head takes care of that)
 302                         * but the hugepage may have been split
 303                         * from under us (and we may not hold a
 304                         * reference count on the head page so it can
 305                         * be reused before we run PageReferenced), so
 306                         * we've to check PageTail before returning
 307                         * what we just read.
 308                         */
 309                        smp_rmb();
 310                        if (PageTail(tail))
 311                                return reserved;
 312                }
 313                return PageReserved(tail);
 314        }
 315
 316        return true;
 317}
 318
 319static int put_pfn(unsigned long pfn, int prot)
 320{
 321        if (!is_invalid_reserved_pfn(pfn)) {
 322                struct page *page = pfn_to_page(pfn);
 323                if (prot & IOMMU_WRITE)
 324                        SetPageDirty(page);
 325                put_page(page);
 326                return 1;
 327        }
 328        return 0;
 329}
 330
 331static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 332                         int prot, unsigned long *pfn)
 333{
 334        struct page *page[1];
 335        struct vm_area_struct *vma;
 336        struct vm_area_struct *vmas[1];
 337        unsigned int flags = 0;
 338        int ret;
 339
 340        if (prot & IOMMU_WRITE)
 341                flags |= FOLL_WRITE;
 342
 343        down_read(&mm->mmap_sem);
 344        if (mm == current->mm) {
 345                ret = get_user_pages(vaddr, 1, flags | FOLL_LONGTERM, page,
 346                                     vmas);
 347        } else {
 348                ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
 349                                            vmas, NULL);
 350                /*
 351                 * The lifetime of a vaddr_get_pfn() page pin is
 352                 * userspace-controlled. In the fs-dax case this could
 353                 * lead to indefinite stalls in filesystem operations.
 354                 * Disallow attempts to pin fs-dax pages via this
 355                 * interface.
 356                 */
 357                if (ret > 0 && vma_is_fsdax(vmas[0])) {
 358                        ret = -EOPNOTSUPP;
 359                        put_page(page[0]);
 360                }
 361        }
 362        up_read(&mm->mmap_sem);
 363
 364        if (ret == 1) {
 365                *pfn = page_to_pfn(page[0]);
 366                return 0;
 367        }
 368
 369        down_read(&mm->mmap_sem);
 370
 371        vma = find_vma_intersection(mm, vaddr, vaddr + 1);
 372
 373        if (vma && vma->vm_flags & VM_PFNMAP) {
 374                *pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) + vma->vm_pgoff;
 375                if (is_invalid_reserved_pfn(*pfn))
 376                        ret = 0;
 377        }
 378
 379        up_read(&mm->mmap_sem);
 380        return ret;
 381}
 382
 383/*
 384 * Attempt to pin pages.  We really don't want to track all the pfns and
 385 * the iommu can only map chunks of consecutive pfns anyway, so get the
 386 * first page and all consecutive pages with the same locking.
 387 */
 388static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
 389                                  long npage, unsigned long *pfn_base,
 390                                  unsigned long limit)
 391{
 392        unsigned long pfn = 0;
 393        long ret, pinned = 0, lock_acct = 0;
 394        bool rsvd;
 395        dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 396
 397        /* This code path is only user initiated */
 398        if (!current->mm)
 399                return -ENODEV;
 400
 401        ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
 402        if (ret)
 403                return ret;
 404
 405        pinned++;
 406        rsvd = is_invalid_reserved_pfn(*pfn_base);
 407
 408        /*
 409         * Reserved pages aren't counted against the user, externally pinned
 410         * pages are already counted against the user.
 411         */
 412        if (!rsvd && !vfio_find_vpfn(dma, iova)) {
 413                if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
 414                        put_pfn(*pfn_base, dma->prot);
 415                        pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
 416                                        limit << PAGE_SHIFT);
 417                        return -ENOMEM;
 418                }
 419                lock_acct++;
 420        }
 421
 422        if (unlikely(disable_hugepages))
 423                goto out;
 424
 425        /* Lock all the consecutive pages from pfn_base */
 426        for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
 427             pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
 428                ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
 429                if (ret)
 430                        break;
 431
 432                if (pfn != *pfn_base + pinned ||
 433                    rsvd != is_invalid_reserved_pfn(pfn)) {
 434                        put_pfn(pfn, dma->prot);
 435                        break;
 436                }
 437
 438                if (!rsvd && !vfio_find_vpfn(dma, iova)) {
 439                        if (!dma->lock_cap &&
 440                            current->mm->locked_vm + lock_acct + 1 > limit) {
 441                                put_pfn(pfn, dma->prot);
 442                                pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
 443                                        __func__, limit << PAGE_SHIFT);
 444                                ret = -ENOMEM;
 445                                goto unpin_out;
 446                        }
 447                        lock_acct++;
 448                }
 449        }
 450
 451out:
 452        ret = vfio_lock_acct(dma, lock_acct, false);
 453
 454unpin_out:
 455        if (ret) {
 456                if (!rsvd) {
 457                        for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
 458                                put_pfn(pfn, dma->prot);
 459                }
 460
 461                return ret;
 462        }
 463
 464        return pinned;
 465}
 466
 467static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
 468                                    unsigned long pfn, long npage,
 469                                    bool do_accounting)
 470{
 471        long unlocked = 0, locked = 0;
 472        long i;
 473
 474        for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
 475                if (put_pfn(pfn++, dma->prot)) {
 476                        unlocked++;
 477                        if (vfio_find_vpfn(dma, iova))
 478                                locked++;
 479                }
 480        }
 481
 482        if (do_accounting)
 483                vfio_lock_acct(dma, locked - unlocked, true);
 484
 485        return unlocked;
 486}
 487
 488static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
 489                                  unsigned long *pfn_base, bool do_accounting)
 490{
 491        struct mm_struct *mm;
 492        int ret;
 493
 494        mm = get_task_mm(dma->task);
 495        if (!mm)
 496                return -ENODEV;
 497
 498        ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
 499        if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
 500                ret = vfio_lock_acct(dma, 1, true);
 501                if (ret) {
 502                        put_pfn(*pfn_base, dma->prot);
 503                        if (ret == -ENOMEM)
 504                                pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
 505                                        "(%ld) exceeded\n", __func__,
 506                                        dma->task->comm, task_pid_nr(dma->task),
 507                                        task_rlimit(dma->task, RLIMIT_MEMLOCK));
 508                }
 509        }
 510
 511        mmput(mm);
 512        return ret;
 513}
 514
 515static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
 516                                    bool do_accounting)
 517{
 518        int unlocked;
 519        struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
 520
 521        if (!vpfn)
 522                return 0;
 523
 524        unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
 525
 526        if (do_accounting)
 527                vfio_lock_acct(dma, -unlocked, true);
 528
 529        return unlocked;
 530}
 531
 532static int vfio_iommu_type1_pin_pages(void *iommu_data,
 533                                      unsigned long *user_pfn,
 534                                      int npage, int prot,
 535                                      unsigned long *phys_pfn)
 536{
 537        struct vfio_iommu *iommu = iommu_data;
 538        int i, j, ret;
 539        unsigned long remote_vaddr;
 540        struct vfio_dma *dma;
 541        bool do_accounting;
 542
 543        if (!iommu || !user_pfn || !phys_pfn)
 544                return -EINVAL;
 545
 546        /* Supported for v2 version only */
 547        if (!iommu->v2)
 548                return -EACCES;
 549
 550        mutex_lock(&iommu->lock);
 551
 552        /* Fail if notifier list is empty */
 553        if (!iommu->notifier.head) {
 554                ret = -EINVAL;
 555                goto pin_done;
 556        }
 557
 558        /*
 559         * If iommu capable domain exist in the container then all pages are
 560         * already pinned and accounted. Accouting should be done if there is no
 561         * iommu capable domain in the container.
 562         */
 563        do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 564
 565        for (i = 0; i < npage; i++) {
 566                dma_addr_t iova;
 567                struct vfio_pfn *vpfn;
 568
 569                iova = user_pfn[i] << PAGE_SHIFT;
 570                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 571                if (!dma) {
 572                        ret = -EINVAL;
 573                        goto pin_unwind;
 574                }
 575
 576                if ((dma->prot & prot) != prot) {
 577                        ret = -EPERM;
 578                        goto pin_unwind;
 579                }
 580
 581                vpfn = vfio_iova_get_vfio_pfn(dma, iova);
 582                if (vpfn) {
 583                        phys_pfn[i] = vpfn->pfn;
 584                        continue;
 585                }
 586
 587                remote_vaddr = dma->vaddr + iova - dma->iova;
 588                ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
 589                                             do_accounting);
 590                if (ret)
 591                        goto pin_unwind;
 592
 593                ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
 594                if (ret) {
 595                        vfio_unpin_page_external(dma, iova, do_accounting);
 596                        goto pin_unwind;
 597                }
 598        }
 599
 600        ret = i;
 601        goto pin_done;
 602
 603pin_unwind:
 604        phys_pfn[i] = 0;
 605        for (j = 0; j < i; j++) {
 606                dma_addr_t iova;
 607
 608                iova = user_pfn[j] << PAGE_SHIFT;
 609                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 610                vfio_unpin_page_external(dma, iova, do_accounting);
 611                phys_pfn[j] = 0;
 612        }
 613pin_done:
 614        mutex_unlock(&iommu->lock);
 615        return ret;
 616}
 617
 618static int vfio_iommu_type1_unpin_pages(void *iommu_data,
 619                                        unsigned long *user_pfn,
 620                                        int npage)
 621{
 622        struct vfio_iommu *iommu = iommu_data;
 623        bool do_accounting;
 624        int i;
 625
 626        if (!iommu || !user_pfn)
 627                return -EINVAL;
 628
 629        /* Supported for v2 version only */
 630        if (!iommu->v2)
 631                return -EACCES;
 632
 633        mutex_lock(&iommu->lock);
 634
 635        do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
 636        for (i = 0; i < npage; i++) {
 637                struct vfio_dma *dma;
 638                dma_addr_t iova;
 639
 640                iova = user_pfn[i] << PAGE_SHIFT;
 641                dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
 642                if (!dma)
 643                        goto unpin_exit;
 644                vfio_unpin_page_external(dma, iova, do_accounting);
 645        }
 646
 647unpin_exit:
 648        mutex_unlock(&iommu->lock);
 649        return i > npage ? npage : (i > 0 ? i : -EINVAL);
 650}
 651
 652static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
 653                                struct list_head *regions)
 654{
 655        long unlocked = 0;
 656        struct vfio_regions *entry, *next;
 657
 658        iommu_tlb_sync(domain->domain);
 659
 660        list_for_each_entry_safe(entry, next, regions, list) {
 661                unlocked += vfio_unpin_pages_remote(dma,
 662                                                    entry->iova,
 663                                                    entry->phys >> PAGE_SHIFT,
 664                                                    entry->len >> PAGE_SHIFT,
 665                                                    false);
 666                list_del(&entry->list);
 667                kfree(entry);
 668        }
 669
 670        cond_resched();
 671
 672        return unlocked;
 673}
 674
 675/*
 676 * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
 677 * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
 678 * of these regions (currently using a list).
 679 *
 680 * This value specifies maximum number of regions for each IOTLB flush sync.
 681 */
 682#define VFIO_IOMMU_TLB_SYNC_MAX         512
 683
 684static size_t unmap_unpin_fast(struct vfio_domain *domain,
 685                               struct vfio_dma *dma, dma_addr_t *iova,
 686                               size_t len, phys_addr_t phys, long *unlocked,
 687                               struct list_head *unmapped_list,
 688                               int *unmapped_cnt)
 689{
 690        size_t unmapped = 0;
 691        struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
 692
 693        if (entry) {
 694                unmapped = iommu_unmap_fast(domain->domain, *iova, len);
 695
 696                if (!unmapped) {
 697                        kfree(entry);
 698                } else {
 699                        iommu_tlb_range_add(domain->domain, *iova, unmapped);
 700                        entry->iova = *iova;
 701                        entry->phys = phys;
 702                        entry->len  = unmapped;
 703                        list_add_tail(&entry->list, unmapped_list);
 704
 705                        *iova += unmapped;
 706                        (*unmapped_cnt)++;
 707                }
 708        }
 709
 710        /*
 711         * Sync if the number of fast-unmap regions hits the limit
 712         * or in case of errors.
 713         */
 714        if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
 715                *unlocked += vfio_sync_unpin(dma, domain,
 716                                             unmapped_list);
 717                *unmapped_cnt = 0;
 718        }
 719
 720        return unmapped;
 721}
 722
 723static size_t unmap_unpin_slow(struct vfio_domain *domain,
 724                               struct vfio_dma *dma, dma_addr_t *iova,
 725                               size_t len, phys_addr_t phys,
 726                               long *unlocked)
 727{
 728        size_t unmapped = iommu_unmap(domain->domain, *iova, len);
 729
 730        if (unmapped) {
 731                *unlocked += vfio_unpin_pages_remote(dma, *iova,
 732                                                     phys >> PAGE_SHIFT,
 733                                                     unmapped >> PAGE_SHIFT,
 734                                                     false);
 735                *iova += unmapped;
 736                cond_resched();
 737        }
 738        return unmapped;
 739}
 740
 741static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
 742                             bool do_accounting)
 743{
 744        dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
 745        struct vfio_domain *domain, *d;
 746        LIST_HEAD(unmapped_region_list);
 747        int unmapped_region_cnt = 0;
 748        long unlocked = 0;
 749
 750        if (!dma->size)
 751                return 0;
 752
 753        if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
 754                return 0;
 755
 756        /*
 757         * We use the IOMMU to track the physical addresses, otherwise we'd
 758         * need a much more complicated tracking system.  Unfortunately that
 759         * means we need to use one of the iommu domains to figure out the
 760         * pfns to unpin.  The rest need to be unmapped in advance so we have
 761         * no iommu translations remaining when the pages are unpinned.
 762         */
 763        domain = d = list_first_entry(&iommu->domain_list,
 764                                      struct vfio_domain, next);
 765
 766        list_for_each_entry_continue(d, &iommu->domain_list, next) {
 767                iommu_unmap(d->domain, dma->iova, dma->size);
 768                cond_resched();
 769        }
 770
 771        while (iova < end) {
 772                size_t unmapped, len;
 773                phys_addr_t phys, next;
 774
 775                phys = iommu_iova_to_phys(domain->domain, iova);
 776                if (WARN_ON(!phys)) {
 777                        iova += PAGE_SIZE;
 778                        continue;
 779                }
 780
 781                /*
 782                 * To optimize for fewer iommu_unmap() calls, each of which
 783                 * may require hardware cache flushing, try to find the
 784                 * largest contiguous physical memory chunk to unmap.
 785                 */
 786                for (len = PAGE_SIZE;
 787                     !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
 788                        next = iommu_iova_to_phys(domain->domain, iova + len);
 789                        if (next != phys + len)
 790                                break;
 791                }
 792
 793                /*
 794                 * First, try to use fast unmap/unpin. In case of failure,
 795                 * switch to slow unmap/unpin path.
 796                 */
 797                unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
 798                                            &unlocked, &unmapped_region_list,
 799                                            &unmapped_region_cnt);
 800                if (!unmapped) {
 801                        unmapped = unmap_unpin_slow(domain, dma, &iova, len,
 802                                                    phys, &unlocked);
 803                        if (WARN_ON(!unmapped))
 804                                break;
 805                }
 806        }
 807
 808        dma->iommu_mapped = false;
 809
 810        if (unmapped_region_cnt)
 811                unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list);
 812
 813        if (do_accounting) {
 814                vfio_lock_acct(dma, -unlocked, true);
 815                return 0;
 816        }
 817        return unlocked;
 818}
 819
 820static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
 821{
 822        vfio_unmap_unpin(iommu, dma, true);
 823        vfio_unlink_dma(iommu, dma);
 824        put_task_struct(dma->task);
 825        kfree(dma);
 826        iommu->dma_avail++;
 827}
 828
 829static unsigned long vfio_pgsize_bitmap(struct vfio_iommu *iommu)
 830{
 831        struct vfio_domain *domain;
 832        unsigned long bitmap = ULONG_MAX;
 833
 834        mutex_lock(&iommu->lock);
 835        list_for_each_entry(domain, &iommu->domain_list, next)
 836                bitmap &= domain->domain->pgsize_bitmap;
 837        mutex_unlock(&iommu->lock);
 838
 839        /*
 840         * In case the IOMMU supports page sizes smaller than PAGE_SIZE
 841         * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
 842         * That way the user will be able to map/unmap buffers whose size/
 843         * start address is aligned with PAGE_SIZE. Pinning code uses that
 844         * granularity while iommu driver can use the sub-PAGE_SIZE size
 845         * to map the buffer.
 846         */
 847        if (bitmap & ~PAGE_MASK) {
 848                bitmap &= PAGE_MASK;
 849                bitmap |= PAGE_SIZE;
 850        }
 851
 852        return bitmap;
 853}
 854
 855static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 856                             struct vfio_iommu_type1_dma_unmap *unmap)
 857{
 858        uint64_t mask;
 859        struct vfio_dma *dma, *dma_last = NULL;
 860        size_t unmapped = 0;
 861        int ret = 0, retries = 0;
 862
 863        mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
 864
 865        if (unmap->iova & mask)
 866                return -EINVAL;
 867        if (!unmap->size || unmap->size & mask)
 868                return -EINVAL;
 869        if (unmap->iova + unmap->size - 1 < unmap->iova ||
 870            unmap->size > SIZE_MAX)
 871                return -EINVAL;
 872
 873        WARN_ON(mask & PAGE_MASK);
 874again:
 875        mutex_lock(&iommu->lock);
 876
 877        /*
 878         * vfio-iommu-type1 (v1) - User mappings were coalesced together to
 879         * avoid tracking individual mappings.  This means that the granularity
 880         * of the original mapping was lost and the user was allowed to attempt
 881         * to unmap any range.  Depending on the contiguousness of physical
 882         * memory and page sizes supported by the IOMMU, arbitrary unmaps may
 883         * or may not have worked.  We only guaranteed unmap granularity
 884         * matching the original mapping; even though it was untracked here,
 885         * the original mappings are reflected in IOMMU mappings.  This
 886         * resulted in a couple unusual behaviors.  First, if a range is not
 887         * able to be unmapped, ex. a set of 4k pages that was mapped as a
 888         * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
 889         * a zero sized unmap.  Also, if an unmap request overlaps the first
 890         * address of a hugepage, the IOMMU will unmap the entire hugepage.
 891         * This also returns success and the returned unmap size reflects the
 892         * actual size unmapped.
 893         *
 894         * We attempt to maintain compatibility with this "v1" interface, but
 895         * we take control out of the hands of the IOMMU.  Therefore, an unmap
 896         * request offset from the beginning of the original mapping will
 897         * return success with zero sized unmap.  And an unmap request covering
 898         * the first iova of mapping will unmap the entire range.
 899         *
 900         * The v2 version of this interface intends to be more deterministic.
 901         * Unmap requests must fully cover previous mappings.  Multiple
 902         * mappings may still be unmaped by specifying large ranges, but there
 903         * must not be any previous mappings bisected by the range.  An error
 904         * will be returned if these conditions are not met.  The v2 interface
 905         * will only return success and a size of zero if there were no
 906         * mappings within the range.
 907         */
 908        if (iommu->v2) {
 909                dma = vfio_find_dma(iommu, unmap->iova, 1);
 910                if (dma && dma->iova != unmap->iova) {
 911                        ret = -EINVAL;
 912                        goto unlock;
 913                }
 914                dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
 915                if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
 916                        ret = -EINVAL;
 917                        goto unlock;
 918                }
 919        }
 920
 921        while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
 922                if (!iommu->v2 && unmap->iova > dma->iova)
 923                        break;
 924                /*
 925                 * Task with same address space who mapped this iova range is
 926                 * allowed to unmap the iova range.
 927                 */
 928                if (dma->task->mm != current->mm)
 929                        break;
 930
 931                if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
 932                        struct vfio_iommu_type1_dma_unmap nb_unmap;
 933
 934                        if (dma_last == dma) {
 935                                BUG_ON(++retries > 10);
 936                        } else {
 937                                dma_last = dma;
 938                                retries = 0;
 939                        }
 940
 941                        nb_unmap.iova = dma->iova;
 942                        nb_unmap.size = dma->size;
 943
 944                        /*
 945                         * Notify anyone (mdev vendor drivers) to invalidate and
 946                         * unmap iovas within the range we're about to unmap.
 947                         * Vendor drivers MUST unpin pages in response to an
 948                         * invalidation.
 949                         */
 950                        mutex_unlock(&iommu->lock);
 951                        blocking_notifier_call_chain(&iommu->notifier,
 952                                                    VFIO_IOMMU_NOTIFY_DMA_UNMAP,
 953                                                    &nb_unmap);
 954                        goto again;
 955                }
 956                unmapped += dma->size;
 957                vfio_remove_dma(iommu, dma);
 958        }
 959
 960unlock:
 961        mutex_unlock(&iommu->lock);
 962
 963        /* Report how much was unmapped */
 964        unmap->size = unmapped;
 965
 966        return ret;
 967}
 968
 969static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
 970                          unsigned long pfn, long npage, int prot)
 971{
 972        struct vfio_domain *d;
 973        int ret;
 974
 975        list_for_each_entry(d, &iommu->domain_list, next) {
 976                ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
 977                                npage << PAGE_SHIFT, prot | d->prot);
 978                if (ret)
 979                        goto unwind;
 980
 981                cond_resched();
 982        }
 983
 984        return 0;
 985
 986unwind:
 987        list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
 988                iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
 989
 990        return ret;
 991}
 992
 993static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
 994                            size_t map_size)
 995{
 996        dma_addr_t iova = dma->iova;
 997        unsigned long vaddr = dma->vaddr;
 998        size_t size = map_size;
 999        long npage;
1000        unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1001        int ret = 0;
1002
1003        while (size) {
1004                /* Pin a contiguous chunk of memory */
1005                npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1006                                              size >> PAGE_SHIFT, &pfn, limit);
1007                if (npage <= 0) {
1008                        WARN_ON(!npage);
1009                        ret = (int)npage;
1010                        break;
1011                }
1012
1013                /* Map it! */
1014                ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1015                                     dma->prot);
1016                if (ret) {
1017                        vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1018                                                npage, true);
1019                        break;
1020                }
1021
1022                size -= npage << PAGE_SHIFT;
1023                dma->size += npage << PAGE_SHIFT;
1024        }
1025
1026        dma->iommu_mapped = true;
1027
1028        if (ret)
1029                vfio_remove_dma(iommu, dma);
1030
1031        return ret;
1032}
1033
1034static int vfio_dma_do_map(struct vfio_iommu *iommu,
1035                           struct vfio_iommu_type1_dma_map *map)
1036{
1037        dma_addr_t iova = map->iova;
1038        unsigned long vaddr = map->vaddr;
1039        size_t size = map->size;
1040        int ret = 0, prot = 0;
1041        uint64_t mask;
1042        struct vfio_dma *dma;
1043
1044        /* Verify that none of our __u64 fields overflow */
1045        if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1046                return -EINVAL;
1047
1048        mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
1049
1050        WARN_ON(mask & PAGE_MASK);
1051
1052        /* READ/WRITE from device perspective */
1053        if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1054                prot |= IOMMU_WRITE;
1055        if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1056                prot |= IOMMU_READ;
1057
1058        if (!prot || !size || (size | iova | vaddr) & mask)
1059                return -EINVAL;
1060
1061        /* Don't allow IOVA or virtual address wrap */
1062        if (iova + size - 1 < iova || vaddr + size - 1 < vaddr)
1063                return -EINVAL;
1064
1065        mutex_lock(&iommu->lock);
1066
1067        if (vfio_find_dma(iommu, iova, size)) {
1068                ret = -EEXIST;
1069                goto out_unlock;
1070        }
1071
1072        if (!iommu->dma_avail) {
1073                ret = -ENOSPC;
1074                goto out_unlock;
1075        }
1076
1077        dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1078        if (!dma) {
1079                ret = -ENOMEM;
1080                goto out_unlock;
1081        }
1082
1083        iommu->dma_avail--;
1084        dma->iova = iova;
1085        dma->vaddr = vaddr;
1086        dma->prot = prot;
1087
1088        /*
1089         * We need to be able to both add to a task's locked memory and test
1090         * against the locked memory limit and we need to be able to do both
1091         * outside of this call path as pinning can be asynchronous via the
1092         * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1093         * task_struct and VM locked pages requires an mm_struct, however
1094         * holding an indefinite mm reference is not recommended, therefore we
1095         * only hold a reference to a task.  We could hold a reference to
1096         * current, however QEMU uses this call path through vCPU threads,
1097         * which can be killed resulting in a NULL mm and failure in the unmap
1098         * path when called via a different thread.  Avoid this problem by
1099         * using the group_leader as threads within the same group require
1100         * both CLONE_THREAD and CLONE_VM and will therefore use the same
1101         * mm_struct.
1102         *
1103         * Previously we also used the task for testing CAP_IPC_LOCK at the
1104         * time of pinning and accounting, however has_capability() makes use
1105         * of real_cred, a copy-on-write field, so we can't guarantee that it
1106         * matches group_leader, or in fact that it might not change by the
1107         * time it's evaluated.  If a process were to call MAP_DMA with
1108         * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1109         * possibly see different results for an iommu_mapped vfio_dma vs
1110         * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1111         * time of calling MAP_DMA.
1112         */
1113        get_task_struct(current->group_leader);
1114        dma->task = current->group_leader;
1115        dma->lock_cap = capable(CAP_IPC_LOCK);
1116
1117        dma->pfn_list = RB_ROOT;
1118
1119        /* Insert zero-sized and grow as we map chunks of it */
1120        vfio_link_dma(iommu, dma);
1121
1122        /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1123        if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1124                dma->size = size;
1125        else
1126                ret = vfio_pin_map_dma(iommu, dma, size);
1127
1128out_unlock:
1129        mutex_unlock(&iommu->lock);
1130        return ret;
1131}
1132
1133static int vfio_bus_type(struct device *dev, void *data)
1134{
1135        struct bus_type **bus = data;
1136
1137        if (*bus && *bus != dev->bus)
1138                return -EINVAL;
1139
1140        *bus = dev->bus;
1141
1142        return 0;
1143}
1144
1145static int vfio_iommu_replay(struct vfio_iommu *iommu,
1146                             struct vfio_domain *domain)
1147{
1148        struct vfio_domain *d;
1149        struct rb_node *n;
1150        unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1151        int ret;
1152
1153        /* Arbitrarily pick the first domain in the list for lookups */
1154        d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1155        n = rb_first(&iommu->dma_list);
1156
1157        for (; n; n = rb_next(n)) {
1158                struct vfio_dma *dma;
1159                dma_addr_t iova;
1160
1161                dma = rb_entry(n, struct vfio_dma, node);
1162                iova = dma->iova;
1163
1164                while (iova < dma->iova + dma->size) {
1165                        phys_addr_t phys;
1166                        size_t size;
1167
1168                        if (dma->iommu_mapped) {
1169                                phys_addr_t p;
1170                                dma_addr_t i;
1171
1172                                phys = iommu_iova_to_phys(d->domain, iova);
1173
1174                                if (WARN_ON(!phys)) {
1175                                        iova += PAGE_SIZE;
1176                                        continue;
1177                                }
1178
1179                                size = PAGE_SIZE;
1180                                p = phys + size;
1181                                i = iova + size;
1182                                while (i < dma->iova + dma->size &&
1183                                       p == iommu_iova_to_phys(d->domain, i)) {
1184                                        size += PAGE_SIZE;
1185                                        p += PAGE_SIZE;
1186                                        i += PAGE_SIZE;
1187                                }
1188                        } else {
1189                                unsigned long pfn;
1190                                unsigned long vaddr = dma->vaddr +
1191                                                     (iova - dma->iova);
1192                                size_t n = dma->iova + dma->size - iova;
1193                                long npage;
1194
1195                                npage = vfio_pin_pages_remote(dma, vaddr,
1196                                                              n >> PAGE_SHIFT,
1197                                                              &pfn, limit);
1198                                if (npage <= 0) {
1199                                        WARN_ON(!npage);
1200                                        ret = (int)npage;
1201                                        return ret;
1202                                }
1203
1204                                phys = pfn << PAGE_SHIFT;
1205                                size = npage << PAGE_SHIFT;
1206                        }
1207
1208                        ret = iommu_map(domain->domain, iova, phys,
1209                                        size, dma->prot | domain->prot);
1210                        if (ret)
1211                                return ret;
1212
1213                        iova += size;
1214                }
1215                dma->iommu_mapped = true;
1216        }
1217        return 0;
1218}
1219
1220/*
1221 * We change our unmap behavior slightly depending on whether the IOMMU
1222 * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1223 * for practically any contiguous power-of-two mapping we give it.  This means
1224 * we don't need to look for contiguous chunks ourselves to make unmapping
1225 * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1226 * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1227 * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1228 * hugetlbfs is in use.
1229 */
1230static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1231{
1232        struct page *pages;
1233        int ret, order = get_order(PAGE_SIZE * 2);
1234
1235        pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1236        if (!pages)
1237                return;
1238
1239        ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1240                        IOMMU_READ | IOMMU_WRITE | domain->prot);
1241        if (!ret) {
1242                size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1243
1244                if (unmapped == PAGE_SIZE)
1245                        iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1246                else
1247                        domain->fgsp = true;
1248        }
1249
1250        __free_pages(pages, order);
1251}
1252
1253static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1254                                           struct iommu_group *iommu_group)
1255{
1256        struct vfio_group *g;
1257
1258        list_for_each_entry(g, &domain->group_list, next) {
1259                if (g->iommu_group == iommu_group)
1260                        return g;
1261        }
1262
1263        return NULL;
1264}
1265
1266static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
1267{
1268        struct list_head group_resv_regions;
1269        struct iommu_resv_region *region, *next;
1270        bool ret = false;
1271
1272        INIT_LIST_HEAD(&group_resv_regions);
1273        iommu_get_group_resv_regions(group, &group_resv_regions);
1274        list_for_each_entry(region, &group_resv_regions, list) {
1275                /*
1276                 * The presence of any 'real' MSI regions should take
1277                 * precedence over the software-managed one if the
1278                 * IOMMU driver happens to advertise both types.
1279                 */
1280                if (region->type == IOMMU_RESV_MSI) {
1281                        ret = false;
1282                        break;
1283                }
1284
1285                if (region->type == IOMMU_RESV_SW_MSI) {
1286                        *base = region->start;
1287                        ret = true;
1288                }
1289        }
1290        list_for_each_entry_safe(region, next, &group_resv_regions, list)
1291                kfree(region);
1292        return ret;
1293}
1294
1295static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1296{
1297        struct device *(*fn)(struct device *dev);
1298        struct device *iommu_device;
1299
1300        fn = symbol_get(mdev_get_iommu_device);
1301        if (fn) {
1302                iommu_device = fn(dev);
1303                symbol_put(mdev_get_iommu_device);
1304
1305                return iommu_device;
1306        }
1307
1308        return NULL;
1309}
1310
1311static int vfio_mdev_attach_domain(struct device *dev, void *data)
1312{
1313        struct iommu_domain *domain = data;
1314        struct device *iommu_device;
1315
1316        iommu_device = vfio_mdev_get_iommu_device(dev);
1317        if (iommu_device) {
1318                if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1319                        return iommu_aux_attach_device(domain, iommu_device);
1320                else
1321                        return iommu_attach_device(domain, iommu_device);
1322        }
1323
1324        return -EINVAL;
1325}
1326
1327static int vfio_mdev_detach_domain(struct device *dev, void *data)
1328{
1329        struct iommu_domain *domain = data;
1330        struct device *iommu_device;
1331
1332        iommu_device = vfio_mdev_get_iommu_device(dev);
1333        if (iommu_device) {
1334                if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1335                        iommu_aux_detach_device(domain, iommu_device);
1336                else
1337                        iommu_detach_device(domain, iommu_device);
1338        }
1339
1340        return 0;
1341}
1342
1343static int vfio_iommu_attach_group(struct vfio_domain *domain,
1344                                   struct vfio_group *group)
1345{
1346        if (group->mdev_group)
1347                return iommu_group_for_each_dev(group->iommu_group,
1348                                                domain->domain,
1349                                                vfio_mdev_attach_domain);
1350        else
1351                return iommu_attach_group(domain->domain, group->iommu_group);
1352}
1353
1354static void vfio_iommu_detach_group(struct vfio_domain *domain,
1355                                    struct vfio_group *group)
1356{
1357        if (group->mdev_group)
1358                iommu_group_for_each_dev(group->iommu_group, domain->domain,
1359                                         vfio_mdev_detach_domain);
1360        else
1361                iommu_detach_group(domain->domain, group->iommu_group);
1362}
1363
1364static bool vfio_bus_is_mdev(struct bus_type *bus)
1365{
1366        struct bus_type *mdev_bus;
1367        bool ret = false;
1368
1369        mdev_bus = symbol_get(mdev_bus_type);
1370        if (mdev_bus) {
1371                ret = (bus == mdev_bus);
1372                symbol_put(mdev_bus_type);
1373        }
1374
1375        return ret;
1376}
1377
1378static int vfio_mdev_iommu_device(struct device *dev, void *data)
1379{
1380        struct device **old = data, *new;
1381
1382        new = vfio_mdev_get_iommu_device(dev);
1383        if (!new || (*old && *old != new))
1384                return -EINVAL;
1385
1386        *old = new;
1387
1388        return 0;
1389}
1390
1391static int vfio_iommu_type1_attach_group(void *iommu_data,
1392                                         struct iommu_group *iommu_group)
1393{
1394        struct vfio_iommu *iommu = iommu_data;
1395        struct vfio_group *group;
1396        struct vfio_domain *domain, *d;
1397        struct bus_type *bus = NULL;
1398        int ret;
1399        bool resv_msi, msi_remap;
1400        phys_addr_t resv_msi_base;
1401
1402        mutex_lock(&iommu->lock);
1403
1404        list_for_each_entry(d, &iommu->domain_list, next) {
1405                if (find_iommu_group(d, iommu_group)) {
1406                        mutex_unlock(&iommu->lock);
1407                        return -EINVAL;
1408                }
1409        }
1410
1411        if (iommu->external_domain) {
1412                if (find_iommu_group(iommu->external_domain, iommu_group)) {
1413                        mutex_unlock(&iommu->lock);
1414                        return -EINVAL;
1415                }
1416        }
1417
1418        group = kzalloc(sizeof(*group), GFP_KERNEL);
1419        domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1420        if (!group || !domain) {
1421                ret = -ENOMEM;
1422                goto out_free;
1423        }
1424
1425        group->iommu_group = iommu_group;
1426
1427        /* Determine bus_type in order to allocate a domain */
1428        ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1429        if (ret)
1430                goto out_free;
1431
1432        if (vfio_bus_is_mdev(bus)) {
1433                struct device *iommu_device = NULL;
1434
1435                group->mdev_group = true;
1436
1437                /* Determine the isolation type */
1438                ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
1439                                               vfio_mdev_iommu_device);
1440                if (ret || !iommu_device) {
1441                        if (!iommu->external_domain) {
1442                                INIT_LIST_HEAD(&domain->group_list);
1443                                iommu->external_domain = domain;
1444                        } else {
1445                                kfree(domain);
1446                        }
1447
1448                        list_add(&group->next,
1449                                 &iommu->external_domain->group_list);
1450                        mutex_unlock(&iommu->lock);
1451
1452                        return 0;
1453                }
1454
1455                bus = iommu_device->bus;
1456        }
1457
1458        domain->domain = iommu_domain_alloc(bus);
1459        if (!domain->domain) {
1460                ret = -EIO;
1461                goto out_free;
1462        }
1463
1464        if (iommu->nesting) {
1465                int attr = 1;
1466
1467                ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
1468                                            &attr);
1469                if (ret)
1470                        goto out_domain;
1471        }
1472
1473        ret = vfio_iommu_attach_group(domain, group);
1474        if (ret)
1475                goto out_domain;
1476
1477        resv_msi = vfio_iommu_has_sw_msi(iommu_group, &resv_msi_base);
1478
1479        INIT_LIST_HEAD(&domain->group_list);
1480        list_add(&group->next, &domain->group_list);
1481
1482        msi_remap = irq_domain_check_msi_remap() ||
1483                    iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
1484
1485        if (!allow_unsafe_interrupts && !msi_remap) {
1486                pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
1487                       __func__);
1488                ret = -EPERM;
1489                goto out_detach;
1490        }
1491
1492        if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
1493                domain->prot |= IOMMU_CACHE;
1494
1495        /*
1496         * Try to match an existing compatible domain.  We don't want to
1497         * preclude an IOMMU driver supporting multiple bus_types and being
1498         * able to include different bus_types in the same IOMMU domain, so
1499         * we test whether the domains use the same iommu_ops rather than
1500         * testing if they're on the same bus_type.
1501         */
1502        list_for_each_entry(d, &iommu->domain_list, next) {
1503                if (d->domain->ops == domain->domain->ops &&
1504                    d->prot == domain->prot) {
1505                        vfio_iommu_detach_group(domain, group);
1506                        if (!vfio_iommu_attach_group(d, group)) {
1507                                list_add(&group->next, &d->group_list);
1508                                iommu_domain_free(domain->domain);
1509                                kfree(domain);
1510                                mutex_unlock(&iommu->lock);
1511                                return 0;
1512                        }
1513
1514                        ret = vfio_iommu_attach_group(domain, group);
1515                        if (ret)
1516                                goto out_domain;
1517                }
1518        }
1519
1520        vfio_test_domain_fgsp(domain);
1521
1522        /* replay mappings on new domains */
1523        ret = vfio_iommu_replay(iommu, domain);
1524        if (ret)
1525                goto out_detach;
1526
1527        if (resv_msi) {
1528                ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
1529                if (ret)
1530                        goto out_detach;
1531        }
1532
1533        list_add(&domain->next, &iommu->domain_list);
1534
1535        mutex_unlock(&iommu->lock);
1536
1537        return 0;
1538
1539out_detach:
1540        vfio_iommu_detach_group(domain, group);
1541out_domain:
1542        iommu_domain_free(domain->domain);
1543out_free:
1544        kfree(domain);
1545        kfree(group);
1546        mutex_unlock(&iommu->lock);
1547        return ret;
1548}
1549
1550static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
1551{
1552        struct rb_node *node;
1553
1554        while ((node = rb_first(&iommu->dma_list)))
1555                vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
1556}
1557
1558static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
1559{
1560        struct rb_node *n, *p;
1561
1562        n = rb_first(&iommu->dma_list);
1563        for (; n; n = rb_next(n)) {
1564                struct vfio_dma *dma;
1565                long locked = 0, unlocked = 0;
1566
1567                dma = rb_entry(n, struct vfio_dma, node);
1568                unlocked += vfio_unmap_unpin(iommu, dma, false);
1569                p = rb_first(&dma->pfn_list);
1570                for (; p; p = rb_next(p)) {
1571                        struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
1572                                                         node);
1573
1574                        if (!is_invalid_reserved_pfn(vpfn->pfn))
1575                                locked++;
1576                }
1577                vfio_lock_acct(dma, locked - unlocked, true);
1578        }
1579}
1580
1581static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
1582{
1583        struct rb_node *n;
1584
1585        n = rb_first(&iommu->dma_list);
1586        for (; n; n = rb_next(n)) {
1587                struct vfio_dma *dma;
1588
1589                dma = rb_entry(n, struct vfio_dma, node);
1590
1591                if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
1592                        break;
1593        }
1594        /* mdev vendor driver must unregister notifier */
1595        WARN_ON(iommu->notifier.head);
1596}
1597
1598static void vfio_iommu_type1_detach_group(void *iommu_data,
1599                                          struct iommu_group *iommu_group)
1600{
1601        struct vfio_iommu *iommu = iommu_data;
1602        struct vfio_domain *domain;
1603        struct vfio_group *group;
1604
1605        mutex_lock(&iommu->lock);
1606
1607        if (iommu->external_domain) {
1608                group = find_iommu_group(iommu->external_domain, iommu_group);
1609                if (group) {
1610                        list_del(&group->next);
1611                        kfree(group);
1612
1613                        if (list_empty(&iommu->external_domain->group_list)) {
1614                                vfio_sanity_check_pfn_list(iommu);
1615
1616                                if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1617                                        vfio_iommu_unmap_unpin_all(iommu);
1618
1619                                kfree(iommu->external_domain);
1620                                iommu->external_domain = NULL;
1621                        }
1622                        goto detach_group_done;
1623                }
1624        }
1625
1626        list_for_each_entry(domain, &iommu->domain_list, next) {
1627                group = find_iommu_group(domain, iommu_group);
1628                if (!group)
1629                        continue;
1630
1631                vfio_iommu_detach_group(domain, group);
1632                list_del(&group->next);
1633                kfree(group);
1634                /*
1635                 * Group ownership provides privilege, if the group list is
1636                 * empty, the domain goes away. If it's the last domain with
1637                 * iommu and external domain doesn't exist, then all the
1638                 * mappings go away too. If it's the last domain with iommu and
1639                 * external domain exist, update accounting
1640                 */
1641                if (list_empty(&domain->group_list)) {
1642                        if (list_is_singular(&iommu->domain_list)) {
1643                                if (!iommu->external_domain)
1644                                        vfio_iommu_unmap_unpin_all(iommu);
1645                                else
1646                                        vfio_iommu_unmap_unpin_reaccount(iommu);
1647                        }
1648                        iommu_domain_free(domain->domain);
1649                        list_del(&domain->next);
1650                        kfree(domain);
1651                }
1652                break;
1653        }
1654
1655detach_group_done:
1656        mutex_unlock(&iommu->lock);
1657}
1658
1659static void *vfio_iommu_type1_open(unsigned long arg)
1660{
1661        struct vfio_iommu *iommu;
1662
1663        iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
1664        if (!iommu)
1665                return ERR_PTR(-ENOMEM);
1666
1667        switch (arg) {
1668        case VFIO_TYPE1_IOMMU:
1669                break;
1670        case VFIO_TYPE1_NESTING_IOMMU:
1671                iommu->nesting = true;
1672                /* fall through */
1673        case VFIO_TYPE1v2_IOMMU:
1674                iommu->v2 = true;
1675                break;
1676        default:
1677                kfree(iommu);
1678                return ERR_PTR(-EINVAL);
1679        }
1680
1681        INIT_LIST_HEAD(&iommu->domain_list);
1682        iommu->dma_list = RB_ROOT;
1683        iommu->dma_avail = dma_entry_limit;
1684        mutex_init(&iommu->lock);
1685        BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
1686
1687        return iommu;
1688}
1689
1690static void vfio_release_domain(struct vfio_domain *domain, bool external)
1691{
1692        struct vfio_group *group, *group_tmp;
1693
1694        list_for_each_entry_safe(group, group_tmp,
1695                                 &domain->group_list, next) {
1696                if (!external)
1697                        vfio_iommu_detach_group(domain, group);
1698                list_del(&group->next);
1699                kfree(group);
1700        }
1701
1702        if (!external)
1703                iommu_domain_free(domain->domain);
1704}
1705
1706static void vfio_iommu_type1_release(void *iommu_data)
1707{
1708        struct vfio_iommu *iommu = iommu_data;
1709        struct vfio_domain *domain, *domain_tmp;
1710
1711        if (iommu->external_domain) {
1712                vfio_release_domain(iommu->external_domain, true);
1713                vfio_sanity_check_pfn_list(iommu);
1714                kfree(iommu->external_domain);
1715        }
1716
1717        vfio_iommu_unmap_unpin_all(iommu);
1718
1719        list_for_each_entry_safe(domain, domain_tmp,
1720                                 &iommu->domain_list, next) {
1721                vfio_release_domain(domain, false);
1722                list_del(&domain->next);
1723                kfree(domain);
1724        }
1725        kfree(iommu);
1726}
1727
1728static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
1729{
1730        struct vfio_domain *domain;
1731        int ret = 1;
1732
1733        mutex_lock(&iommu->lock);
1734        list_for_each_entry(domain, &iommu->domain_list, next) {
1735                if (!(domain->prot & IOMMU_CACHE)) {
1736                        ret = 0;
1737                        break;
1738                }
1739        }
1740        mutex_unlock(&iommu->lock);
1741
1742        return ret;
1743}
1744
1745static long vfio_iommu_type1_ioctl(void *iommu_data,
1746                                   unsigned int cmd, unsigned long arg)
1747{
1748        struct vfio_iommu *iommu = iommu_data;
1749        unsigned long minsz;
1750
1751        if (cmd == VFIO_CHECK_EXTENSION) {
1752                switch (arg) {
1753                case VFIO_TYPE1_IOMMU:
1754                case VFIO_TYPE1v2_IOMMU:
1755                case VFIO_TYPE1_NESTING_IOMMU:
1756                        return 1;
1757                case VFIO_DMA_CC_IOMMU:
1758                        if (!iommu)
1759                                return 0;
1760                        return vfio_domains_have_iommu_cache(iommu);
1761                default:
1762                        return 0;
1763                }
1764        } else if (cmd == VFIO_IOMMU_GET_INFO) {
1765                struct vfio_iommu_type1_info info;
1766
1767                minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
1768
1769                if (copy_from_user(&info, (void __user *)arg, minsz))
1770                        return -EFAULT;
1771
1772                if (info.argsz < minsz)
1773                        return -EINVAL;
1774
1775                info.flags = VFIO_IOMMU_INFO_PGSIZES;
1776
1777                info.iova_pgsizes = vfio_pgsize_bitmap(iommu);
1778
1779                return copy_to_user((void __user *)arg, &info, minsz) ?
1780                        -EFAULT : 0;
1781
1782        } else if (cmd == VFIO_IOMMU_MAP_DMA) {
1783                struct vfio_iommu_type1_dma_map map;
1784                uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
1785                                VFIO_DMA_MAP_FLAG_WRITE;
1786
1787                minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
1788
1789                if (copy_from_user(&map, (void __user *)arg, minsz))
1790                        return -EFAULT;
1791
1792                if (map.argsz < minsz || map.flags & ~mask)
1793                        return -EINVAL;
1794
1795                return vfio_dma_do_map(iommu, &map);
1796
1797        } else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
1798                struct vfio_iommu_type1_dma_unmap unmap;
1799                long ret;
1800
1801                minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
1802
1803                if (copy_from_user(&unmap, (void __user *)arg, minsz))
1804                        return -EFAULT;
1805
1806                if (unmap.argsz < minsz || unmap.flags)
1807                        return -EINVAL;
1808
1809                ret = vfio_dma_do_unmap(iommu, &unmap);
1810                if (ret)
1811                        return ret;
1812
1813                return copy_to_user((void __user *)arg, &unmap, minsz) ?
1814                        -EFAULT : 0;
1815        }
1816
1817        return -ENOTTY;
1818}
1819
1820static int vfio_iommu_type1_register_notifier(void *iommu_data,
1821                                              unsigned long *events,
1822                                              struct notifier_block *nb)
1823{
1824        struct vfio_iommu *iommu = iommu_data;
1825
1826        /* clear known events */
1827        *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
1828
1829        /* refuse to register if still events remaining */
1830        if (*events)
1831                return -EINVAL;
1832
1833        return blocking_notifier_chain_register(&iommu->notifier, nb);
1834}
1835
1836static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
1837                                                struct notifier_block *nb)
1838{
1839        struct vfio_iommu *iommu = iommu_data;
1840
1841        return blocking_notifier_chain_unregister(&iommu->notifier, nb);
1842}
1843
1844static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
1845        .name                   = "vfio-iommu-type1",
1846        .owner                  = THIS_MODULE,
1847        .open                   = vfio_iommu_type1_open,
1848        .release                = vfio_iommu_type1_release,
1849        .ioctl                  = vfio_iommu_type1_ioctl,
1850        .attach_group           = vfio_iommu_type1_attach_group,
1851        .detach_group           = vfio_iommu_type1_detach_group,
1852        .pin_pages              = vfio_iommu_type1_pin_pages,
1853        .unpin_pages            = vfio_iommu_type1_unpin_pages,
1854        .register_notifier      = vfio_iommu_type1_register_notifier,
1855        .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
1856};
1857
1858static int __init vfio_iommu_type1_init(void)
1859{
1860        return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
1861}
1862
1863static void __exit vfio_iommu_type1_cleanup(void)
1864{
1865        vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
1866}
1867
1868module_init(vfio_iommu_type1_init);
1869module_exit(vfio_iommu_type1_cleanup);
1870
1871MODULE_VERSION(DRIVER_VERSION);
1872MODULE_LICENSE("GPL v2");
1873MODULE_AUTHOR(DRIVER_AUTHOR);
1874MODULE_DESCRIPTION(DRIVER_DESC);
1875