linux/drivers/iommu/amd/iommu_v2.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/refcount.h>
#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES             0x10000
#define PRI_QUEUE_SIZE          512

struct pri_queue {
        atomic_t inflight;
        bool finish;
        int status;
};

struct pasid_state {
        struct list_head list;                  /* For global state-list */
        refcount_t count;                       /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
        struct mm_struct *mm;                   /* mm_struct for the faults */
        struct mmu_notifier mn;                 /* mmu_notifier handle */
        struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
        struct device_state *device_state;      /* Link to our device_state */
        u32 pasid;                              /* PASID index */
        bool invalid;                           /* Used during setup and
                                                   teardown of the pasid */
        spinlock_t lock;                        /* Protect pri_queues and
                                                   mmu_notifier_count */
        wait_queue_head_t wq;                   /* To wait for count == 0 */
};

struct device_state {
        struct list_head list;
        u16 devid;
        atomic_t count;
        struct pci_dev *pdev;
        struct pasid_state **states;
        struct iommu_domain *domain;
        int pasid_levels;
        int max_pasids;
        amd_iommu_invalid_ppr_cb inv_ppr_cb;
        amd_iommu_invalidate_ctx inv_ctx_cb;
        spinlock_t lock;
        wait_queue_head_t wq;
};

struct fault {
        struct work_struct work;
        struct device_state *dev_state;
        struct pasid_state *state;
        struct mm_struct *mm;
        u64 address;
        u16 devid;
        u32 pasid;
        u16 tag;
        u16 finish;
        u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

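/*
 * Build the 16-bit IOMMU device ID (bus number in the upper byte, devfn in
 * the lower byte) used to identify the device in the global state list.
 */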
static u16 device_id(struct pci_dev *pdev)
{
        u16 devid;

        devid = pdev->bus->number;
        devid = (devid << 8) | pdev->devfn;

        return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
        struct device_state *dev_state;

        list_for_each_entry(dev_state, &state_list, list) {
                if (dev_state->devid == devid)
                        return dev_state;
        }

        return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
        struct device_state *dev_state;
        unsigned long flags;

        spin_lock_irqsave(&state_lock, flags);
        dev_state = __get_device_state(devid);
        if (dev_state != NULL)
                atomic_inc(&dev_state->count);
        spin_unlock_irqrestore(&state_lock, flags);

        return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
        struct iommu_group *group;

        /*
         * First detach device from domain - No more PRI requests will arrive
         * from that device after it is unbound from the IOMMUv2 domain.
         */
        group = iommu_group_get(&dev_state->pdev->dev);
        if (WARN_ON(!group))
                return;

        iommu_detach_group(dev_state->domain, group);

        iommu_group_put(group);

        /* Everything is down now, free the IOMMUv2 domain */
        iommu_domain_free(dev_state->domain);

        /* Finally get rid of the device-state */
        kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
        if (atomic_dec_and_test(&dev_state->count))
                wake_up(&dev_state->wq);
}

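/*
 * The per-device PASID states live in a sparse, multi-level table: each
 * level is one zeroed page holding 512 pointers, and the PASID is consumed
 * 9 bits at a time starting from the top level. With alloc == true, missing
 * intermediate pages are allocated (GFP_ATOMIC); the leaf slot is returned.
 */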
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
                                                  u32 pasid, bool alloc)
{
        struct pasid_state **root, **ptr;
        int level, index;

        level = dev_state->pasid_levels;
        root  = dev_state->states;

        while (true) {

                index = (pasid >> (9 * level)) & 0x1ff;
                ptr   = &root[index];

                if (level == 0)
                        break;

                if (*ptr == NULL) {
                        if (!alloc)
                                return NULL;

                        *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
                        if (*ptr == NULL)
                                return NULL;
                }

                root   = (struct pasid_state **)*ptr;
                level -= 1;
        }

        return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u32 pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;
        int ret;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        ret = -ENOMEM;
        if (ptr == NULL)
                goto out_unlock;

        ret = -ENOMEM;
        if (*ptr != NULL)
                goto out_unlock;

        *ptr = pasid_state;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        if (ptr == NULL)
                goto out_unlock;

        *ptr = NULL;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
                                           u32 pasid)
{
        struct pasid_state **ptr, *ret = NULL;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, false);

        if (ptr == NULL)
                goto out_unlock;

        ret = *ptr;
        if (ret)
                refcount_inc(&ret->count);

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
        kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
        if (refcount_dec_and_test(&pasid_state->count))
                wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
        refcount_dec(&pasid_state->count);
        wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
        free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
        struct iommu_domain *domain;

        domain = pasid_state->device_state->domain;

        /*
         * Mark pasid_state as invalid, no more faults will be added to the
         * work queue after this is visible everywhere.
         */
        pasid_state->invalid = true;

        /* Make sure this is visible */
        smp_wmb();

        /* After this the device/pasid can't access the mm anymore */
        amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

        /* Make sure no more pending faults are in the queue */
        flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                free_page((unsigned long)tbl[i]);
        }
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
        struct pasid_state **ptr;
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                ptr = (struct pasid_state **)tbl[i];
                free_pasid_states_level1(ptr);
        }
}

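/*
 * Tear down every remaining PASID state of a device: unregister the
 * mmu_notifier of each bound PASID, wait for its reference count to reach
 * zero, drop the device reference taken at bind time, and finally free the
 * intermediate table pages themselves.
 */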
static void free_pasid_states(struct device_state *dev_state)
{
        struct pasid_state *pasid_state;
        int i;

        for (i = 0; i < dev_state->max_pasids; ++i) {
                pasid_state = get_pasid_state(dev_state, i);
                if (pasid_state == NULL)
                        continue;

                put_pasid_state(pasid_state);

                /*
                 * This will call the mn_release function and
                 * unbind the PASID
                 */
                mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

                put_pasid_state_wait(pasid_state); /* Reference taken in
                                                      amd_iommu_bind_pasid */

                /* Drop reference taken in amd_iommu_bind_pasid */
                put_device_state(dev_state);
        }

        if (dev_state->pasid_levels == 2)
                free_pasid_states_level2(dev_state->states);
        else if (dev_state->pasid_levels == 1)
                free_pasid_states_level1(dev_state->states);
        else
                BUG_ON(dev_state->pasid_levels != 0);

        free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
        return container_of(mn, struct pasid_state, mn);
}

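/*
 * MMU notifier callback: keep the IOMMU TLB in sync with the CPU page
 * tables. A range that fits within a single page is flushed with a
 * page-selective invalidation; anything larger flushes the whole TLB for
 * the PASID.
 */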
static void mn_invalidate_range(struct mmu_notifier *mn,
                                struct mm_struct *mm,
                                unsigned long start, unsigned long end)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

        if ((start ^ (end - 1)) < PAGE_SIZE)
                amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
                                     start);
        else
                amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        bool run_inv_ctx_cb;

        might_sleep();

        pasid_state    = mn_to_state(mn);
        dev_state      = pasid_state->device_state;
        run_inv_ctx_cb = !pasid_state->invalid;

        if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
                dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

        unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .invalidate_range       = mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
                               u16 tag, int status)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        pasid_state->pri[tag].status = status;
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

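/*
 * Drop one in-flight reference on a PRI tag. When the last fault for the
 * tag completes and the peripheral asked for a response (the "finish" bit),
 * send the accumulated status back to the device as a PPR completion.
 */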
static void finish_pri_tag(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u16 tag)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
            pasid_state->pri[tag].finish) {
                amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
                                       pasid_state->pri[tag].status, tag);
                pasid_state->pri[tag].finish = false;
                pasid_state->pri[tag].status = PPR_SUCCESS;
        }
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
        int status;

        if (!fault->dev_state->inv_ppr_cb) {
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                return;
        }

        status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
                                              fault->pasid,
                                              fault->address,
                                              fault->flags);
        switch (status) {
        case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
                set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
                break;
        case AMD_IOMMU_INV_PRI_RSP_INVALID:
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                break;
        case AMD_IOMMU_INV_PRI_RSP_FAIL:
                set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
                break;
        default:
                BUG();
        }
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
        unsigned long requested = 0;

        if (fault->flags & PPR_FAULT_EXEC)
                requested |= VM_EXEC;

        if (fault->flags & PPR_FAULT_READ)
                requested |= VM_READ;

        if (fault->flags & PPR_FAULT_WRITE)
                requested |= VM_WRITE;

        return (requested & ~vma->vm_flags) != 0;
}

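/*
 * Workqueue handler for one queued PRI fault: look up the VMA, check the
 * requested access against its permissions and resolve the fault through
 * handle_mm_fault() under mmap_read_lock(). Failures are reported back to
 * the device via handle_fault_error() and finish_pri_tag().
 */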
static void do_fault(struct work_struct *work)
{
        struct fault *fault = container_of(work, struct fault, work);
        struct vm_area_struct *vma;
        vm_fault_t ret = VM_FAULT_ERROR;
        unsigned int flags = 0;
        struct mm_struct *mm;
        u64 address;

        mm = fault->state->mm;
        address = fault->address;

        if (fault->flags & PPR_FAULT_USER)
                flags |= FAULT_FLAG_USER;
        if (fault->flags & PPR_FAULT_WRITE)
                flags |= FAULT_FLAG_WRITE;
        flags |= FAULT_FLAG_REMOTE;

        mmap_read_lock(mm);
        vma = find_extend_vma(mm, address);
        if (!vma || address < vma->vm_start)
                /* failed to get a vma in the right range */
                goto out;

        /* Check if we have the right permissions on the vma */
        if (access_error(vma, fault))
                goto out;

        ret = handle_mm_fault(vma, address, flags, NULL);
out:
        mmap_read_unlock(mm);

        if (ret & VM_FAULT_ERROR)
                /* failed to service fault */
                handle_fault_error(fault);

        finish_pri_tag(fault->dev_state, fault->state, fault->tag);

        put_pasid_state(fault->state);

        kfree(fault);
}

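/*
 * Notifier called by the core AMD IOMMU driver for each peripheral page
 * request (PPR) log entry. It validates the device and PASID state, marks
 * the PRI tag as in-flight and queues the actual fault handling on the
 * iommu_wq workqueue; unknown devices or PASIDs get a PPR_INVALID reply.
 */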
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
        struct amd_iommu_fault *iommu_fault;
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct pci_dev *pdev = NULL;
        unsigned long flags;
        struct fault *fault;
        bool finish;
        u16 tag, devid;
        int ret;

        iommu_fault = data;
        tag         = iommu_fault->tag & 0x1ff;
        finish      = (iommu_fault->tag >> 9) & 1;

        devid = iommu_fault->device_id;
        pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
                                           devid & 0xff);
        if (!pdev)
                return -ENODEV;

        ret = NOTIFY_DONE;

        /* In kdump kernel pci dev is not initialized yet -> send INVALID */
        if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) {
                amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out;
        }

        dev_state = get_device_state(iommu_fault->device_id);
        if (dev_state == NULL)
                goto out;

        pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
        if (pasid_state == NULL || pasid_state->invalid) {
                /* We know the device but not the PASID -> send INVALID */
                amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out_drop_state;
        }

        spin_lock_irqsave(&pasid_state->lock, flags);
        atomic_inc(&pasid_state->pri[tag].inflight);
        if (finish)
                pasid_state->pri[tag].finish = true;
        spin_unlock_irqrestore(&pasid_state->lock, flags);

        fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
        if (fault == NULL) {
                /* We are OOM - send success and let the device re-fault */
                finish_pri_tag(dev_state, pasid_state, tag);
                goto out_drop_state;
        }

        fault->dev_state = dev_state;
        fault->address   = iommu_fault->address;
        fault->state     = pasid_state;
        fault->tag       = tag;
        fault->finish    = finish;
        fault->pasid     = iommu_fault->pasid;
        fault->flags     = iommu_fault->flags;
        INIT_WORK(&fault->work, do_fault);

        queue_work(iommu_wq, &fault->work);

        ret = NOTIFY_OK;

out_drop_state:

        if (ret != NOTIFY_OK && pasid_state)
                put_pasid_state(pasid_state);

        put_device_state(dev_state);

out:
        return ret;
}

static struct notifier_block ppr_nb = {
        .notifier_call = ppr_notifier,
};

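/*
 * Bind a PASID of @pdev to the address space of @task: take a reference on
 * the task's mm, register an mmu_notifier for it and program the mm's page
 * table root into the GCR3 table so the device can fault pages in via PRI.
 */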
int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
                         struct task_struct *task)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct mm_struct *mm;
        u16 devid;
        int ret;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid     = device_id(pdev);
        dev_state = get_device_state(devid);

        if (dev_state == NULL)
                return -EINVAL;

        ret = -EINVAL;
        if (pasid >= dev_state->max_pasids)
                goto out;

        ret = -ENOMEM;
        pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
        if (pasid_state == NULL)
                goto out;

        refcount_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);

        mm                        = get_task_mm(task);
        pasid_state->mm           = mm;
        pasid_state->device_state = dev_state;
        pasid_state->pasid        = pasid;
        pasid_state->invalid      = true; /* Mark as valid only if we are
                                             done with setting up the pasid */
        pasid_state->mn.ops       = &iommu_mn;

        if (pasid_state->mm == NULL)
                goto out_free;

        mmu_notifier_register(&pasid_state->mn, mm);

        ret = set_pasid_state(dev_state, pasid_state, pasid);
        if (ret)
                goto out_unregister;

        ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
                                        __pa(pasid_state->mm->pgd));
        if (ret)
                goto out_clear_state;

        /* Now we are ready to handle faults */
        pasid_state->invalid = false;

        /*
         * Drop the reference to the mm_struct here. We rely on the
         * mmu_notifier release call-back to inform us when the mm
         * is going away.
         */
        mmput(mm);

        return 0;

out_clear_state:
        clear_pasid_state(dev_state, pasid);

out_unregister:
        mmu_notifier_unregister(&pasid_state->mn, mm);
        mmput(mm);

out_free:
        free_pasid_state(pasid_state);

out:
        put_device_state(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

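/*
 * Undo amd_iommu_bind_pasid(): clear the PASID slot so it can be reused,
 * unregister the mmu_notifier (which ends up in mn_release()/unbind_pasid()
 * to clear the GCR3 entry and drain the fault workqueue) and drop the
 * references taken at bind time.
 */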
void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);
        dev_state = get_device_state(devid);
        if (dev_state == NULL)
                return;

        if (pasid >= dev_state->max_pasids)
                goto out;

        pasid_state = get_pasid_state(dev_state, pasid);
        if (pasid_state == NULL)
                goto out;
        /*
         * Drop reference taken here. We are safe because we still hold
         * the reference taken in the amd_iommu_bind_pasid function.
         */
        put_pasid_state(pasid_state);

        /* Clear the pasid state so that the pasid can be re-used */
        clear_pasid_state(dev_state, pasid_state->pasid);

        /*
         * Call mmu_notifier_unregister to drop our reference
         * to pasid_state->mm
         */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

        put_pasid_state_wait(pasid_state); /* Reference taken in
                                              amd_iommu_bind_pasid */
out:
        /* Drop reference taken in this function */
        put_device_state(dev_state);

        /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

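/*
 * Prepare @pdev for IOMMUv2/PASID use: allocate the device state and the
 * root PASID table page, compute how many 9-bit table levels are needed for
 * @pasids, set up a direct-mapped v2 domain and attach the device's group
 * to it, then publish the state on the global list.
 */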
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
        struct device_state *dev_state;
        struct iommu_group *group;
        unsigned long flags;
        int ret, tmp;
        u16 devid;

        might_sleep();

        /*
         * When memory encryption is active the device is likely not in a
         * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
         */
        if (mem_encrypt_active())
                return -ENODEV;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        if (pasids <= 0 || pasids > (PASID_MASK + 1))
                return -EINVAL;

        devid = device_id(pdev);

        dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
        if (dev_state == NULL)
                return -ENOMEM;

        spin_lock_init(&dev_state->lock);
        init_waitqueue_head(&dev_state->wq);
        dev_state->pdev  = pdev;
        dev_state->devid = devid;

        tmp = pasids;
        for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
                dev_state->pasid_levels += 1;

        atomic_set(&dev_state->count, 1);
        dev_state->max_pasids = pasids;

        ret = -ENOMEM;
        dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
        if (dev_state->states == NULL)
                goto out_free_dev_state;

        dev_state->domain = iommu_domain_alloc(&pci_bus_type);
        if (dev_state->domain == NULL)
                goto out_free_states;

        amd_iommu_domain_direct_map(dev_state->domain);

        ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
        if (ret)
                goto out_free_domain;

        group = iommu_group_get(&pdev->dev);
        if (!group) {
                ret = -EINVAL;
                goto out_free_domain;
        }

        ret = iommu_attach_group(dev_state->domain, group);
        if (ret != 0)
                goto out_drop_group;

        iommu_group_put(group);

        spin_lock_irqsave(&state_lock, flags);

        if (__get_device_state(devid) != NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                ret = -EBUSY;
                goto out_free_domain;
        }

        list_add_tail(&dev_state->list, &state_list);

        spin_unlock_irqrestore(&state_lock, flags);

        return 0;

out_drop_group:
        iommu_group_put(group);

out_free_domain:
        iommu_domain_free(dev_state->domain);

out_free_states:
        free_page((unsigned long)dev_state->states);

out_free_dev_state:
        kfree(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

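/*
 * Counterpart of amd_iommu_init_device(): unlink the device state, release
 * any PASIDs that are still bound, wait for the last reference to go away
 * and free the domain and state structures.
 */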
void amd_iommu_free_device(struct pci_dev *pdev)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        dev_state = __get_device_state(devid);
        if (dev_state == NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                return;
        }

        list_del(&dev_state->list);

        spin_unlock_irqrestore(&state_lock, flags);

        /* Get rid of any remaining pasid states */
        free_pasid_states(dev_state);

        put_device_state(dev_state);
        /*
         * Wait until the last reference is dropped before freeing
         * the device state.
         */
        wait_event(dev_state->wq, !atomic_read(&dev_state->count));
        free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
                                 amd_iommu_invalid_ppr_cb cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ppr_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
                                    amd_iommu_invalidate_ctx cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ctx_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
        int ret;

        pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

        if (!amd_iommu_v2_supported()) {
                pr_info("AMD IOMMUv2 functionality not available on this system\n");
                /*
                 * Load anyway to provide the symbols to other modules
                 * which may use AMD IOMMUv2 optionally.
                 */
                return 0;
        }

        ret = -ENOMEM;
        iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
        if (iommu_wq == NULL)
                goto out;

        amd_iommu_register_ppr_notifier(&ppr_nb);

        return 0;

out:
        return ret;
}

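/*
 * Module exit: unregister the PPR notifier and drain the workqueue. No
 * device state should be left at this point; the loop below only cleans up
 * (and warns about) users that did not call amd_iommu_free_device().
 */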
static void __exit amd_iommu_v2_exit(void)
{
        struct device_state *dev_state;
        int i;

        if (!amd_iommu_v2_supported())
                return;

        amd_iommu_unregister_ppr_notifier(&ppr_nb);

        flush_workqueue(iommu_wq);

        /*
         * The loop below might call flush_workqueue(), so call
         * destroy_workqueue() after it
         */
        for (i = 0; i < MAX_DEVICES; ++i) {
                dev_state = get_device_state(i);

                if (dev_state == NULL)
                        continue;

                WARN_ON_ONCE(1);

                put_device_state(dev_state);
                amd_iommu_free_device(dev_state->pdev);
        }

        destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);
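
/*
 * Usage sketch (illustrative only, not part of this file): a client driver
 * would typically call, roughly in this order,
 *
 *      amd_iommu_init_device(pdev, pasids);
 *      amd_iommu_set_invalid_ppr_cb(pdev, my_ppr_cb);          (optional)
 *      amd_iommu_set_invalidate_ctx_cb(pdev, my_inv_ctx_cb);   (optional)
 *      amd_iommu_bind_pasid(pdev, pasid, task);
 *              ... device issues PRI requests tagged with the PASID ...
 *      amd_iommu_unbind_pasid(pdev, pasid);
 *      amd_iommu_free_device(pdev);
 *
 * my_ppr_cb and my_inv_ctx_cb stand for hypothetical callbacks supplied by
 * the client driver; they are not defined anywhere in this file.
 */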