linux/drivers/iommu/amd_iommu_v2.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/iommu.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu_types.h"
#include "amd_iommu_proto.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES             0x10000
#define PRI_QUEUE_SIZE          512

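/*
 * Per-tag state for PCIe Page Request Interface (PRI) requests: how many
 * page faults are still in flight for the tag, whether a PPR completion
 * must be sent once they have all finished, and the status to report then.
 */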
struct pri_queue {
        atomic_t inflight;
        bool finish;
        int status;
};

struct pasid_state {
        struct list_head list;                  /* For global state-list */
        atomic_t count;                         /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
        struct mm_struct *mm;                   /* mm_struct for the faults */
        struct mmu_notifier mn;                 /* mmu_notifier handle */
        struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
        struct device_state *device_state;      /* Link to our device_state */
        int pasid;                              /* PASID index */
        bool invalid;                           /* Used during setup and
                                                   teardown of the pasid */
        spinlock_t lock;                        /* Protect pri_queues and
                                                   mmu_notifier_count */
        wait_queue_head_t wq;                   /* To wait for count == 0 */
};

struct device_state {
        struct list_head list;
        u16 devid;
        atomic_t count;
        struct pci_dev *pdev;
        struct pasid_state **states;
        struct iommu_domain *domain;
        int pasid_levels;
        int max_pasids;
        amd_iommu_invalid_ppr_cb inv_ppr_cb;
        amd_iommu_invalidate_ctx inv_ctx_cb;
        spinlock_t lock;
        wait_queue_head_t wq;
};

struct fault {
        struct work_struct work;
        struct device_state *dev_state;
        struct pasid_state *state;
        struct mm_struct *mm;
        u64 address;
        u16 devid;
        u16 pasid;
        u16 tag;
        u16 finish;
        u16 flags;
};

static LIST_HEAD(state_list);
static spinlock_t state_lock;

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

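/*
 * Build the 16-bit IOMMU device ID from the PCI bus number and devfn
 * (bus in bits 15:8, devfn in bits 7:0); the PCI domain/segment is not
 * part of this ID.
 */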
static u16 device_id(struct pci_dev *pdev)
{
        u16 devid;

        devid = pdev->bus->number;
        devid = (devid << 8) | pdev->devfn;

        return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
        struct device_state *dev_state;

        list_for_each_entry(dev_state, &state_list, list) {
                if (dev_state->devid == devid)
                        return dev_state;
        }

        return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
        struct device_state *dev_state;
        unsigned long flags;

        spin_lock_irqsave(&state_lock, flags);
        dev_state = __get_device_state(devid);
        if (dev_state != NULL)
                atomic_inc(&dev_state->count);
        spin_unlock_irqrestore(&state_lock, flags);

        return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
        struct iommu_group *group;

        /*
         * First detach device from domain - No more PRI requests will arrive
         * from that device after it is unbound from the IOMMUv2 domain.
         */
        group = iommu_group_get(&dev_state->pdev->dev);
        if (WARN_ON(!group))
                return;

        iommu_detach_group(dev_state->domain, group);

        iommu_group_put(group);

        /* Everything is down now, free the IOMMUv2 domain */
        iommu_domain_free(dev_state->domain);

        /* Finally get rid of the device-state */
        kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
        if (atomic_dec_and_test(&dev_state->count))
                wake_up(&dev_state->wq);
}

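/*
 * Per-device PASID states live in a radix table rooted at dev_state->states:
 * each level is one zeroed page holding 512 pointers, indexed by 9 bits of
 * the PASID. With pasid_levels == 1, for example, bits 17:9 select the root
 * slot and bits 8:0 select the leaf slot that holds the pasid_state pointer.
 * Intermediate pages are allocated on demand when alloc is true.
 */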
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
                                                  int pasid, bool alloc)
{
        struct pasid_state **root, **ptr;
        int level, index;

        level = dev_state->pasid_levels;
        root  = dev_state->states;

        while (true) {

                index = (pasid >> (9 * level)) & 0x1ff;
                ptr   = &root[index];

                if (level == 0)
                        break;

                if (*ptr == NULL) {
                        if (!alloc)
                                return NULL;

                        *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
                        if (*ptr == NULL)
                                return NULL;
                }

                root   = (struct pasid_state **)*ptr;
                level -= 1;
        }

        return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           int pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;
        int ret;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        ret = -ENOMEM;
        if (ptr == NULL)
                goto out_unlock;

        ret = -ENOMEM;
        if (*ptr != NULL)
                goto out_unlock;

        *ptr = pasid_state;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void clear_pasid_state(struct device_state *dev_state, int pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        if (ptr == NULL)
                goto out_unlock;

        *ptr = NULL;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
                                           int pasid)
{
        struct pasid_state **ptr, *ret = NULL;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, false);

        if (ptr == NULL)
                goto out_unlock;

        ret = *ptr;
        if (ret)
                atomic_inc(&ret->count);

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
        kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
        if (atomic_dec_and_test(&pasid_state->count))
                wake_up(&pasid_state->wq);
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
        atomic_dec(&pasid_state->count);
        wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
        free_pasid_state(pasid_state);
}

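/*
 * Tear down a PASID binding: mark the pasid_state invalid so ppr_notifier()
 * stops queueing new faults for it, clear the GCR3 entry so the device can
 * no longer walk the mm's page tables, then flush iommu_wq to drain any
 * fault work that is already queued.
 */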
static void unbind_pasid(struct pasid_state *pasid_state)
{
        struct iommu_domain *domain;

        domain = pasid_state->device_state->domain;

        /*
         * Mark pasid_state as invalid, no more faults will be added to the
         * work queue after this is visible everywhere.
         */
        pasid_state->invalid = true;

        /* Make sure this is visible */
        smp_wmb();

        /* After this the device/pasid can't access the mm anymore */
        amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

        /* Make sure no more pending faults are in the queue */
        flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                free_page((unsigned long)tbl[i]);
        }
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
        struct pasid_state **ptr;
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                ptr = (struct pasid_state **)tbl[i];
                free_pasid_states_level1(ptr);

                /* Also free the level-1 table page itself */
                free_page((unsigned long)tbl[i]);
        }
}

static void free_pasid_states(struct device_state *dev_state)
{
        struct pasid_state *pasid_state;
        int i;

        for (i = 0; i < dev_state->max_pasids; ++i) {
                pasid_state = get_pasid_state(dev_state, i);
                if (pasid_state == NULL)
                        continue;

                put_pasid_state(pasid_state);

                /*
                 * This will call the mn_release function and
                 * unbind the PASID
                 */
                mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

                put_pasid_state_wait(pasid_state); /* Reference taken in
                                                      amd_iommu_bind_pasid */

                /* Drop reference taken in amd_iommu_bind_pasid */
                put_device_state(dev_state);
        }

        if (dev_state->pasid_levels == 2)
                free_pasid_states_level2(dev_state->states);
        else if (dev_state->pasid_levels == 1)
                free_pasid_states_level1(dev_state->states);
        else
                BUG_ON(dev_state->pasid_levels != 0);

        free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
        return container_of(mn, struct pasid_state, mn);
}

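/*
 * (start ^ (end - 1)) < PAGE_SIZE holds only when start and end - 1 agree
 * in all bits above the page offset, i.e. the invalidated range lies within
 * a single page. In that case flush just that page for the PASID, otherwise
 * flush the whole IOTLB for the PASID.
 */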
static void mn_invalidate_range(struct mmu_notifier *mn,
                                struct mm_struct *mm,
                                unsigned long start, unsigned long end)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

        if ((start ^ (end - 1)) < PAGE_SIZE)
                amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
                                     start);
        else
                amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        bool run_inv_ctx_cb;

        might_sleep();

        pasid_state    = mn_to_state(mn);
        dev_state      = pasid_state->device_state;
        run_inv_ctx_cb = !pasid_state->invalid;

        if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
                dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

        unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .invalidate_range       = mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
                               u16 tag, int status)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        pasid_state->pri[tag].status = status;
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

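/*
 * Drop one in-flight fault for this PRI tag. When the last one completes
 * and a completion was requested (finish was set by ppr_notifier()), send
 * the accumulated status back via amd_iommu_complete_ppr() and reset the
 * tag state for reuse.
 */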
static void finish_pri_tag(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u16 tag)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
            pasid_state->pri[tag].finish) {
                amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
                                       pasid_state->pri[tag].status, tag);
                pasid_state->pri[tag].finish = false;
                pasid_state->pri[tag].status = PPR_SUCCESS;
        }
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
        int status;

        if (!fault->dev_state->inv_ppr_cb) {
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                return;
        }

        status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
                                              fault->pasid,
                                              fault->address,
                                              fault->flags);
        switch (status) {
        case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
                set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
                break;
        case AMD_IOMMU_INV_PRI_RSP_INVALID:
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                break;
        case AMD_IOMMU_INV_PRI_RSP_FAIL:
                set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
                break;
        default:
                BUG();
        }
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
        unsigned long requested = 0;

        if (fault->flags & PPR_FAULT_EXEC)
                requested |= VM_EXEC;

        if (fault->flags & PPR_FAULT_READ)
                requested |= VM_READ;

        if (fault->flags & PPR_FAULT_WRITE)
                requested |= VM_WRITE;

        return (requested & ~vma->vm_flags) != 0;
}

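/*
 * Workqueue handler for one queued device fault: resolve it in process
 * context by faulting the page into the bound mm with handle_mm_fault(),
 * report unrecoverable faults through handle_fault_error(), then complete
 * the PRI tag and drop the pasid_state reference taken in ppr_notifier().
 */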
static void do_fault(struct work_struct *work)
{
        struct fault *fault = container_of(work, struct fault, work);
        struct vm_area_struct *vma;
        vm_fault_t ret = VM_FAULT_ERROR;
        unsigned int flags = 0;
        struct mm_struct *mm;
        u64 address;

        mm = fault->state->mm;
        address = fault->address;

        if (fault->flags & PPR_FAULT_USER)
                flags |= FAULT_FLAG_USER;
        if (fault->flags & PPR_FAULT_WRITE)
                flags |= FAULT_FLAG_WRITE;
        flags |= FAULT_FLAG_REMOTE;

        down_read(&mm->mmap_sem);
        vma = find_extend_vma(mm, address);
        if (!vma || address < vma->vm_start)
                /* failed to get a vma in the right range */
                goto out;

        /* Check if we have the right permissions on the vma */
        if (access_error(vma, fault))
                goto out;

        ret = handle_mm_fault(vma, address, flags);
out:
        up_read(&mm->mmap_sem);

        if (ret & VM_FAULT_ERROR)
                /* failed to service fault */
                handle_fault_error(fault);

        finish_pri_tag(fault->dev_state, fault->state, fault->tag);

        put_pasid_state(fault->state);

        kfree(fault);
}

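/*
 * Notifier called by the core AMD IOMMU driver for each reported PPR fault.
 * The low 9 bits of iommu_fault->tag carry the PRI tag; bit 9 is treated as
 * the 'finish' flag, meaning a PPR completion is sent once all faults for
 * the tag are done. Faults are pushed onto iommu_wq so the actual page-fault
 * handling runs in sleepable process context.
 */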
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
        struct amd_iommu_fault *iommu_fault;
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        unsigned long flags;
        struct fault *fault;
        bool finish;
        u16 tag, devid;
        int ret;
        struct iommu_dev_data *dev_data;
        struct pci_dev *pdev = NULL;

        iommu_fault = data;
        tag         = iommu_fault->tag & 0x1ff;
        finish      = (iommu_fault->tag >> 9) & 1;

        devid = iommu_fault->device_id;
        pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
                                           devid & 0xff);
        if (!pdev)
                return -ENODEV;
        dev_data = get_dev_data(&pdev->dev);

        /* In a kdump kernel the pci dev is not yet initialized -> send INVALID */
        ret = NOTIFY_DONE;
        if (translation_pre_enabled(amd_iommu_rlookup_table[devid])
                && dev_data->defer_attach) {
                amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out;
        }

        dev_state = get_device_state(iommu_fault->device_id);
        if (dev_state == NULL)
                goto out;

        pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
        if (pasid_state == NULL || pasid_state->invalid) {
                /* We know the device but not the PASID -> send INVALID */
                amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out_drop_state;
        }

        spin_lock_irqsave(&pasid_state->lock, flags);
        atomic_inc(&pasid_state->pri[tag].inflight);
        if (finish)
                pasid_state->pri[tag].finish = true;
        spin_unlock_irqrestore(&pasid_state->lock, flags);

        fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
        if (fault == NULL) {
                /* We are OOM - send success and let the device re-fault */
                finish_pri_tag(dev_state, pasid_state, tag);
                goto out_drop_state;
        }

        fault->dev_state = dev_state;
        fault->address   = iommu_fault->address;
        fault->state     = pasid_state;
        fault->tag       = tag;
        fault->finish    = finish;
        fault->pasid     = iommu_fault->pasid;
        fault->flags     = iommu_fault->flags;
        INIT_WORK(&fault->work, do_fault);

        queue_work(iommu_wq, &fault->work);

        ret = NOTIFY_OK;

out_drop_state:

        if (ret != NOTIFY_OK && pasid_state)
                put_pasid_state(pasid_state);

        put_device_state(dev_state);

out:
        /* Drop the reference taken by pci_get_domain_bus_and_slot() above */
        pci_dev_put(pdev);

        return ret;
}

static struct notifier_block ppr_nb = {
        .notifier_call = ppr_notifier,
};

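/*
 * Bind a process address space to a device PASID: take a reference on the
 * task's mm, register an mmu_notifier on it, hook the pasid_state into the
 * device's PASID table and point the IOMMU's GCR3 table for this PASID at
 * the mm's page-table root. The mm reference is dropped again at the end;
 * mn_release() tells us when the address space goes away.
 */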
int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
                         struct task_struct *task)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct mm_struct *mm;
        u16 devid;
        int ret;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid     = device_id(pdev);
        dev_state = get_device_state(devid);

        if (dev_state == NULL)
                return -EINVAL;

        ret = -EINVAL;
        if (pasid < 0 || pasid >= dev_state->max_pasids)
                goto out;

        ret = -ENOMEM;
        pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
        if (pasid_state == NULL)
                goto out;


        atomic_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);

        mm                        = get_task_mm(task);
        pasid_state->mm           = mm;
        pasid_state->device_state = dev_state;
        pasid_state->pasid        = pasid;
        pasid_state->invalid      = true; /* Mark as valid only when we are
                                             done with setting up the pasid */
        pasid_state->mn.ops       = &iommu_mn;

        if (pasid_state->mm == NULL)
                goto out_free;

        mmu_notifier_register(&pasid_state->mn, mm);

        ret = set_pasid_state(dev_state, pasid_state, pasid);
        if (ret)
                goto out_unregister;

        ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
                                        __pa(pasid_state->mm->pgd));
        if (ret)
                goto out_clear_state;

        /* Now we are ready to handle faults */
        pasid_state->invalid = false;

        /*
         * Drop the reference to the mm_struct here. We rely on the
         * mmu_notifier release call-back to inform us when the mm
         * is going away.
         */
        mmput(mm);

        return 0;

out_clear_state:
        clear_pasid_state(dev_state, pasid);

out_unregister:
        mmu_notifier_unregister(&pasid_state->mn, mm);
        mmput(mm);

out_free:
        free_pasid_state(pasid_state);

out:
        put_device_state(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);
        dev_state = get_device_state(devid);
        if (dev_state == NULL)
                return;

        if (pasid < 0 || pasid >= dev_state->max_pasids)
                goto out;

        pasid_state = get_pasid_state(dev_state, pasid);
        if (pasid_state == NULL)
                goto out;
        /*
         * Drop the reference taken by get_pasid_state() above. We are safe
         * because we still hold the reference taken in amd_iommu_bind_pasid.
         */
        put_pasid_state(pasid_state);

        /* Clear the pasid state so that the pasid can be re-used */
        clear_pasid_state(dev_state, pasid_state->pasid);

        /*
         * Call mmu_notifier_unregister to drop our reference
         * to pasid_state->mm
         */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

        put_pasid_state_wait(pasid_state); /* Reference taken in
                                              amd_iommu_bind_pasid */
out:
        /* Drop reference taken in this function */
        put_device_state(dev_state);

        /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

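/*
 * Set up IOMMUv2 (PRI/PASID) state for a device. pasid_levels is the number
 * of intermediate table levels needed for 'pasids' entries at 512 pointers
 * per page: e.g. pasids <= 512 gives 0 levels (the root page alone), 65536
 * gives 1. The device is attached to a direct-mapped v2 domain and the
 * device_state is published on the global state_list.
 */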
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
        struct device_state *dev_state;
        struct iommu_group *group;
        unsigned long flags;
        int ret, tmp;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        if (pasids <= 0 || pasids > (PASID_MASK + 1))
                return -EINVAL;

        devid = device_id(pdev);

        dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
        if (dev_state == NULL)
                return -ENOMEM;

        spin_lock_init(&dev_state->lock);
        init_waitqueue_head(&dev_state->wq);
        dev_state->pdev  = pdev;
        dev_state->devid = devid;

        tmp = pasids;
        for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
                dev_state->pasid_levels += 1;

        atomic_set(&dev_state->count, 1);
        dev_state->max_pasids = pasids;

        ret = -ENOMEM;
        dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
        if (dev_state->states == NULL)
                goto out_free_dev_state;

        dev_state->domain = iommu_domain_alloc(&pci_bus_type);
        if (dev_state->domain == NULL)
                goto out_free_states;

        amd_iommu_domain_direct_map(dev_state->domain);

        ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
        if (ret)
                goto out_free_domain;

        group = iommu_group_get(&pdev->dev);
        if (!group) {
                ret = -EINVAL;
                goto out_free_domain;
        }

        ret = iommu_attach_group(dev_state->domain, group);
        if (ret != 0)
                goto out_drop_group;

        iommu_group_put(group);

        spin_lock_irqsave(&state_lock, flags);

        if (__get_device_state(devid) != NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                ret = -EBUSY;
                goto out_free_domain;
        }

        list_add_tail(&dev_state->list, &state_list);

        spin_unlock_irqrestore(&state_lock, flags);

        return 0;

out_drop_group:
        iommu_group_put(group);

out_free_domain:
        iommu_domain_free(dev_state->domain);

out_free_states:
        free_page((unsigned long)dev_state->states);

out_free_dev_state:
        kfree(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        dev_state = __get_device_state(devid);
        if (dev_state == NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                return;
        }

        list_del(&dev_state->list);

        spin_unlock_irqrestore(&state_lock, flags);

        /* Get rid of any remaining pasid states */
        free_pasid_states(dev_state);

        put_device_state(dev_state);
        /*
         * Wait until the last reference is dropped before freeing
         * the device state.
         */
        wait_event(dev_state->wq, !atomic_read(&dev_state->count));
        free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
                                 amd_iommu_invalid_ppr_cb cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ppr_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
                                    amd_iommu_invalidate_ctx cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ctx_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
        int ret;

        pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

        if (!amd_iommu_v2_supported()) {
                pr_info("AMD IOMMUv2 functionality not available on this system\n");
                /*
                 * Load anyway to provide the symbols to other modules
                 * which may use AMD IOMMUv2 optionally.
                 */
                return 0;
        }

        spin_lock_init(&state_lock);

        ret = -ENOMEM;
        iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
        if (iommu_wq == NULL)
                goto out;

        amd_iommu_register_ppr_notifier(&ppr_nb);

        return 0;

out:
        return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
        struct device_state *dev_state;
        int i;

        if (!amd_iommu_v2_supported())
                return;

        amd_iommu_unregister_ppr_notifier(&ppr_nb);

        flush_workqueue(iommu_wq);

        /*
         * The loop below might call flush_workqueue(), so call
         * destroy_workqueue() after it
         */
        for (i = 0; i < MAX_DEVICES; ++i) {
                dev_state = get_device_state(i);

                if (dev_state == NULL)
                        continue;

                WARN_ON_ONCE(1);

                put_device_state(dev_state);
                amd_iommu_free_device(dev_state->pdev);
        }

        destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);