linux/drivers/iommu/amd/iommu_v2.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <jroedel@suse.de>
 */

#define pr_fmt(fmt)     "AMD-Vi: " fmt

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <jroedel@suse.de>");

#define MAX_DEVICES             0x10000
#define PRI_QUEUE_SIZE          512

struct pri_queue {
        atomic_t inflight;
        bool finish;
        int status;
};

struct pasid_state {
        struct list_head list;                  /* For global state-list */
        atomic_t count;                         /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
        struct mm_struct *mm;                   /* mm_struct for the faults */
        struct mmu_notifier mn;                 /* mmu_notifier handle */
        struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
        struct device_state *device_state;      /* Link to our device_state */
        u32 pasid;                              /* PASID index */
        bool invalid;                           /* Used during setup and
                                                   teardown of the pasid */
        spinlock_t lock;                        /* Protect pri_queues and
                                                   mmu_notifier_count */
        wait_queue_head_t wq;                   /* To wait for count == 0 */
};

struct device_state {
        struct list_head list;
        u16 devid;
        atomic_t count;
        struct pci_dev *pdev;
        struct pasid_state **states;
        struct iommu_domain *domain;
        int pasid_levels;
        int max_pasids;
        amd_iommu_invalid_ppr_cb inv_ppr_cb;
        amd_iommu_invalidate_ctx inv_ctx_cb;
        spinlock_t lock;
        wait_queue_head_t wq;
};

struct fault {
        struct work_struct work;
        struct device_state *dev_state;
        struct pasid_state *state;
        struct mm_struct *mm;
        u64 address;
        u16 devid;
        u32 pasid;
        u16 tag;
        u16 finish;
        u16 flags;
};

static LIST_HEAD(state_list);
static DEFINE_SPINLOCK(state_lock);

static struct workqueue_struct *iommu_wq;

static void free_pasid_states(struct device_state *dev_state);

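/*
 * 16-bit requester ID of @pdev as the IOMMU sees it: bus number in the
 * high byte, devfn in the low byte (the PCI segment is not encoded).
 */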
static u16 device_id(struct pci_dev *pdev)
{
        u16 devid;

        devid = pdev->bus->number;
        devid = (devid << 8) | pdev->devfn;

        return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
        struct device_state *dev_state;

        list_for_each_entry(dev_state, &state_list, list) {
                if (dev_state->devid == devid)
                        return dev_state;
        }

        return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
        struct device_state *dev_state;
        unsigned long flags;

        spin_lock_irqsave(&state_lock, flags);
        dev_state = __get_device_state(devid);
        if (dev_state != NULL)
                atomic_inc(&dev_state->count);
        spin_unlock_irqrestore(&state_lock, flags);

        return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
        struct iommu_group *group;

        /*
         * First detach device from domain - No more PRI requests will arrive
         * from that device after it is unbound from the IOMMUv2 domain.
         */
        group = iommu_group_get(&dev_state->pdev->dev);
        if (WARN_ON(!group))
                return;

        iommu_detach_group(dev_state->domain, group);

        iommu_group_put(group);

        /* Everything is down now, free the IOMMUv2 domain */
        iommu_domain_free(dev_state->domain);

        /* Finally get rid of the device-state */
        kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
        if (atomic_dec_and_test(&dev_state->count))
                wake_up(&dev_state->wq);
}

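/*
 * The per-device PASID table is a sparse radix tree with 512 entries
 * (9 index bits) per level; dev_state->pasid_levels is the number of
 * intermediate levels above the leaf level.  The walk below consumes
 * the PASID 9 bits at a time and, if @alloc is true, fills in missing
 * intermediate pages with GFP_ATOMIC allocations.  With one level, for
 * example, PASID 0x12345 uses index 0x91 at level 1 and 0x145 at the
 * leaf.
 */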
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
                                                  u32 pasid, bool alloc)
{
        struct pasid_state **root, **ptr;
        int level, index;

        level = dev_state->pasid_levels;
        root  = dev_state->states;

        while (true) {

                index = (pasid >> (9 * level)) & 0x1ff;
                ptr   = &root[index];

                if (level == 0)
                        break;

                if (*ptr == NULL) {
                        if (!alloc)
                                return NULL;

                        *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
                        if (*ptr == NULL)
                                return NULL;
                }

                root   = (struct pasid_state **)*ptr;
                level -= 1;
        }

        return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u32 pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;
        int ret;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        ret = -ENOMEM;
        if (ptr == NULL)
                goto out_unlock;

        ret = -ENOMEM;
        if (*ptr != NULL)
                goto out_unlock;

        *ptr = pasid_state;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void clear_pasid_state(struct device_state *dev_state, u32 pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        if (ptr == NULL)
                goto out_unlock;

        *ptr = NULL;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
                                           u32 pasid)
{
        struct pasid_state **ptr, *ret = NULL;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, false);

        if (ptr == NULL)
                goto out_unlock;

        ret = *ptr;
        if (ret)
                atomic_inc(&ret->count);

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
        kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
        if (atomic_dec_and_test(&pasid_state->count))
                wake_up(&pasid_state->wq);
}

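/*
 * Drop the caller's reference, wait until every other reference has been
 * put (the final put_pasid_state() wakes the wait queue) and then free
 * the pasid_state.  May sleep.
 */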
static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
        atomic_dec(&pasid_state->count);
        wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
        free_pasid_state(pasid_state);
}

static void unbind_pasid(struct pasid_state *pasid_state)
{
        struct iommu_domain *domain;

        domain = pasid_state->device_state->domain;

        /*
         * Mark pasid_state as invalid; no more faults will be added to the
         * work queue after this is visible everywhere.
         */
        pasid_state->invalid = true;

        /* Make sure this is visible */
        smp_wmb();

        /* After this the device/pasid can't access the mm anymore */
        amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

        /* Make sure no more pending faults are in the queue */
        flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                free_page((unsigned long)tbl[i]);
        }
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
        struct pasid_state **ptr;
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                ptr = (struct pasid_state **)tbl[i];
                free_pasid_states_level1(ptr);
        }
}

static void free_pasid_states(struct device_state *dev_state)
{
        struct pasid_state *pasid_state;
        int i;

        for (i = 0; i < dev_state->max_pasids; ++i) {
                pasid_state = get_pasid_state(dev_state, i);
                if (pasid_state == NULL)
                        continue;

                put_pasid_state(pasid_state);

                /*
                 * This will call the mn_release function and
                 * unbind the PASID
                 */
                mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

                put_pasid_state_wait(pasid_state); /* Reference taken in
                                                      amd_iommu_bind_pasid */

                /* Drop reference taken in amd_iommu_bind_pasid */
                put_device_state(dev_state);
        }

        if (dev_state->pasid_levels == 2)
                free_pasid_states_level2(dev_state->states);
        else if (dev_state->pasid_levels == 1)
                free_pasid_states_level1(dev_state->states);
        else
                BUG_ON(dev_state->pasid_levels != 0);

        free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
        return container_of(mn, struct pasid_state, mn);
}

static void mn_invalidate_range(struct mmu_notifier *mn,
                                struct mm_struct *mm,
                                unsigned long start, unsigned long end)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

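        /*
         * If start and (end - 1) differ only in the page-offset bits, the
         * range lies within a single page and a single-page flush is
         * enough; otherwise flush the whole IO/TLB for this PASID.
         */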
        if ((start ^ (end - 1)) < PAGE_SIZE)
                amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
                                     start);
        else
                amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        bool run_inv_ctx_cb;

        might_sleep();

        pasid_state    = mn_to_state(mn);
        dev_state      = pasid_state->device_state;
        run_inv_ctx_cb = !pasid_state->invalid;

        if (run_inv_ctx_cb && dev_state->inv_ctx_cb)
                dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

        unbind_pasid(pasid_state);
}

static const struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .invalidate_range       = mn_invalidate_range,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
                               u16 tag, int status)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        pasid_state->pri[tag].status = status;
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void finish_pri_tag(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u16 tag)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
            pasid_state->pri[tag].finish) {
                amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
                                       pasid_state->pri[tag].status, tag);
                pasid_state->pri[tag].finish = false;
                pasid_state->pri[tag].status = PPR_SUCCESS;
        }
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void handle_fault_error(struct fault *fault)
{
        int status;

        if (!fault->dev_state->inv_ppr_cb) {
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                return;
        }

        status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
                                              fault->pasid,
                                              fault->address,
                                              fault->flags);
        switch (status) {
        case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
                set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
                break;
        case AMD_IOMMU_INV_PRI_RSP_INVALID:
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                break;
        case AMD_IOMMU_INV_PRI_RSP_FAIL:
                set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
                break;
        default:
                BUG();
        }
}

static bool access_error(struct vm_area_struct *vma, struct fault *fault)
{
        unsigned long requested = 0;

        if (fault->flags & PPR_FAULT_EXEC)
                requested |= VM_EXEC;

        if (fault->flags & PPR_FAULT_READ)
                requested |= VM_READ;

        if (fault->flags & PPR_FAULT_WRITE)
                requested |= VM_WRITE;

        return (requested & ~vma->vm_flags) != 0;
}

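/*
 * Worker for one queued PPR fault: resolve the faulting address in the
 * bound mm via handle_mm_fault() under the mmap read lock, report the
 * result for the PRI tag and drop the references taken in ppr_notifier().
 */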
static void do_fault(struct work_struct *work)
{
        struct fault *fault = container_of(work, struct fault, work);
        struct vm_area_struct *vma;
        vm_fault_t ret = VM_FAULT_ERROR;
        unsigned int flags = 0;
        struct mm_struct *mm;
        u64 address;

        mm = fault->state->mm;
        address = fault->address;

        if (fault->flags & PPR_FAULT_USER)
                flags |= FAULT_FLAG_USER;
        if (fault->flags & PPR_FAULT_WRITE)
                flags |= FAULT_FLAG_WRITE;
        flags |= FAULT_FLAG_REMOTE;

        mmap_read_lock(mm);
        vma = find_extend_vma(mm, address);
        if (!vma || address < vma->vm_start)
                /* failed to get a vma in the right range */
                goto out;

        /* Check if we have the right permissions on the vma */
        if (access_error(vma, fault))
                goto out;

        ret = handle_mm_fault(vma, address, flags, NULL);
out:
        mmap_read_unlock(mm);

        if (ret & VM_FAULT_ERROR)
                /* failed to service fault */
                handle_fault_error(fault);

        finish_pri_tag(fault->dev_state, fault->state, fault->tag);

        put_pasid_state(fault->state);

        kfree(fault);
}

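/*
 * Called by the core AMD IOMMU driver for every Peripheral Page Request.
 * The low 9 bits of the reported tag index the per-PASID pri[] array and
 * bit 9 is treated as the "finish" flag that requires a PPR completion
 * once all in-flight faults for this tag have been handled.  The real
 * work is deferred to the iommu_wq workqueue.
 */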
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
        struct amd_iommu_fault *iommu_fault;
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct pci_dev *pdev = NULL;
        unsigned long flags;
        struct fault *fault;
        bool finish;
        u16 tag, devid;
        int ret;

        iommu_fault = data;
        tag         = iommu_fault->tag & 0x1ff;
        finish      = (iommu_fault->tag >> 9) & 1;

        devid = iommu_fault->device_id;
        pdev = pci_get_domain_bus_and_slot(0, PCI_BUS_NUM(devid),
                                           devid & 0xff);
        if (!pdev)
                return -ENODEV;

        ret = NOTIFY_DONE;

        /* In a kdump kernel the PCI dev is not initialized yet -> send INVALID */
        if (amd_iommu_is_attach_deferred(NULL, &pdev->dev)) {
                amd_iommu_complete_ppr(pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out;
        }

        dev_state = get_device_state(iommu_fault->device_id);
        if (dev_state == NULL)
                goto out;

        pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
        if (pasid_state == NULL || pasid_state->invalid) {
                /* We know the device but not the PASID -> send INVALID */
                amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out_drop_state;
        }

        spin_lock_irqsave(&pasid_state->lock, flags);
        atomic_inc(&pasid_state->pri[tag].inflight);
        if (finish)
                pasid_state->pri[tag].finish = true;
        spin_unlock_irqrestore(&pasid_state->lock, flags);

        fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
        if (fault == NULL) {
                /* We are OOM - send success and let the device re-fault */
                finish_pri_tag(dev_state, pasid_state, tag);
                goto out_drop_state;
        }

        fault->dev_state = dev_state;
        fault->address   = iommu_fault->address;
        fault->state     = pasid_state;
        fault->tag       = tag;
        fault->finish    = finish;
        fault->pasid     = iommu_fault->pasid;
        fault->flags     = iommu_fault->flags;
        INIT_WORK(&fault->work, do_fault);

        queue_work(iommu_wq, &fault->work);

        ret = NOTIFY_OK;

out_drop_state:

        if (ret != NOTIFY_OK && pasid_state)
                put_pasid_state(pasid_state);

        put_device_state(dev_state);

out:
        return ret;
}

static struct notifier_block ppr_nb = {
        .notifier_call = ppr_notifier,
};

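/*
 * amd_iommu_bind_pasid - bind a PASID of @pdev to the mm of @task
 *
 * Registers an mmu_notifier on the task's mm and programs the mm's
 * page-table root into the device's GCR3 table so that PPR faults for
 * this PASID are resolved against that address space.  The mm reference
 * is dropped again before returning; teardown is driven by the
 * mmu_notifier release callback.  Returns 0 on success or a negative
 * errno.
 */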
int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
                         struct task_struct *task)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct mm_struct *mm;
        u16 devid;
        int ret;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid     = device_id(pdev);
        dev_state = get_device_state(devid);

        if (dev_state == NULL)
                return -EINVAL;

        ret = -EINVAL;
        if (pasid >= dev_state->max_pasids)
                goto out;

        ret = -ENOMEM;
        pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
        if (pasid_state == NULL)
                goto out;

        atomic_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);

        mm                        = get_task_mm(task);
        pasid_state->mm           = mm;
        pasid_state->device_state = dev_state;
        pasid_state->pasid        = pasid;
        pasid_state->invalid      = true; /* Mark as valid only if we are
                                             done with setting up the pasid */
        pasid_state->mn.ops       = &iommu_mn;

        if (pasid_state->mm == NULL)
                goto out_free;

        mmu_notifier_register(&pasid_state->mn, mm);

        ret = set_pasid_state(dev_state, pasid_state, pasid);
        if (ret)
                goto out_unregister;

        ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
                                        __pa(pasid_state->mm->pgd));
        if (ret)
                goto out_clear_state;

        /* Now we are ready to handle faults */
        pasid_state->invalid = false;

        /*
         * Drop the reference to the mm_struct here. We rely on the
         * mmu_notifier release call-back to inform us when the mm
         * is going away.
         */
        mmput(mm);

        return 0;

out_clear_state:
        clear_pasid_state(dev_state, pasid);

out_unregister:
        mmu_notifier_unregister(&pasid_state->mn, mm);
        mmput(mm);

out_free:
        free_pasid_state(pasid_state);

out:
        put_device_state(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

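/*
 * amd_iommu_unbind_pasid - undo a previous amd_iommu_bind_pasid() call
 *
 * Clears the PASID table entry, unregisters the mmu_notifier and drops
 * the references taken during bind.  May sleep.
 */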
void amd_iommu_unbind_pasid(struct pci_dev *pdev, u32 pasid)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);
        dev_state = get_device_state(devid);
        if (dev_state == NULL)
                return;

        if (pasid >= dev_state->max_pasids)
                goto out;

        pasid_state = get_pasid_state(dev_state, pasid);
        if (pasid_state == NULL)
                goto out;
        /*
         * Drop reference taken here. We are safe because we still hold
         * the reference taken in the amd_iommu_bind_pasid function.
         */
        put_pasid_state(pasid_state);

        /* Clear the pasid state so that the pasid can be re-used */
        clear_pasid_state(dev_state, pasid_state->pasid);

        /*
         * Call mmu_notifier_unregister to drop our reference
         * to pasid_state->mm
         */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

        put_pasid_state_wait(pasid_state); /* Reference taken in
                                              amd_iommu_bind_pasid */
out:
        /* Drop reference taken in this function */
        put_device_state(dev_state);

        /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

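/*
 * amd_iommu_init_device - enable IOMMUv2 (PASID/PRI) handling for @pdev
 *
 * Allocates the per-device state and PASID table, sets up a direct-mapped
 * v2 domain with room for @pasids PASIDs and attaches the device's group
 * to it.  Must be paired with amd_iommu_free_device().
 */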
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
        struct device_state *dev_state;
        struct iommu_group *group;
        unsigned long flags;
        int ret, tmp;
        u16 devid;

        might_sleep();

        /*
         * When memory encryption is active the device is likely not in a
         * direct-mapped domain. Forbid using IOMMUv2 functionality for now.
         */
        if (mem_encrypt_active())
                return -ENODEV;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        if (pasids <= 0 || pasids > (PASID_MASK + 1))
                return -EINVAL;

        devid = device_id(pdev);

        dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
        if (dev_state == NULL)
                return -ENOMEM;

        spin_lock_init(&dev_state->lock);
        init_waitqueue_head(&dev_state->wq);
        dev_state->pdev  = pdev;
        dev_state->devid = devid;

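        /*
         * Number of table levels needed above the 512-entry leaf level:
         * 0 for up to 512 PASIDs, 1 for up to 512 * 512, and so on.  For
         * example, pasids == 65536 needs one intermediate level.
         */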
        tmp = pasids;
        for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
                dev_state->pasid_levels += 1;

        atomic_set(&dev_state->count, 1);
        dev_state->max_pasids = pasids;

        ret = -ENOMEM;
        dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
        if (dev_state->states == NULL)
                goto out_free_dev_state;

        dev_state->domain = iommu_domain_alloc(&pci_bus_type);
        if (dev_state->domain == NULL)
                goto out_free_states;

        amd_iommu_domain_direct_map(dev_state->domain);

        ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
        if (ret)
                goto out_free_domain;

        group = iommu_group_get(&pdev->dev);
        if (!group) {
                ret = -EINVAL;
                goto out_free_domain;
        }

        ret = iommu_attach_group(dev_state->domain, group);
        if (ret != 0)
                goto out_drop_group;

        iommu_group_put(group);

        spin_lock_irqsave(&state_lock, flags);

        if (__get_device_state(devid) != NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                ret = -EBUSY;
                goto out_free_domain;
        }

        list_add_tail(&dev_state->list, &state_list);

        spin_unlock_irqrestore(&state_lock, flags);

        return 0;

out_drop_group:
        iommu_group_put(group);

out_free_domain:
        iommu_domain_free(dev_state->domain);

out_free_states:
        free_page((unsigned long)dev_state->states);

out_free_dev_state:
        kfree(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

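/*
 * Rough usage sketch for a client driver (illustrative only: error
 * handling is elided and "my_pdev", "MY_NR_PASIDS" and "my_pasid" are
 * placeholder names, not part of this API):
 *
 *      ret = amd_iommu_init_device(my_pdev, MY_NR_PASIDS);
 *      ...
 *      ret = amd_iommu_bind_pasid(my_pdev, my_pasid, current);
 *      ...     (device issues PASID-tagged DMA, faults handled via PRI)
 *      amd_iommu_unbind_pasid(my_pdev, my_pasid);
 *      amd_iommu_free_device(my_pdev);
 */

/*
 * amd_iommu_free_device - disable IOMMUv2 handling for @pdev again
 *
 * Tears down any PASID states that are still bound, waits for the last
 * reference to the device state to go away and frees it.
 */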
void amd_iommu_free_device(struct pci_dev *pdev)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        dev_state = __get_device_state(devid);
        if (dev_state == NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                return;
        }

        list_del(&dev_state->list);

        spin_unlock_irqrestore(&state_lock, flags);

        /* Get rid of any remaining pasid states */
        free_pasid_states(dev_state);

        put_device_state(dev_state);
        /*
         * Wait until the last reference is dropped before freeing
         * the device state.
         */
        wait_event(dev_state->wq, !atomic_read(&dev_state->count));
        free_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

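/*
 * amd_iommu_set_invalid_ppr_cb - install a callback invoked when a PPR
 * fault cannot be resolved by handle_mm_fault(); its return value selects
 * the PPR response (success, invalid or failure) sent back to the device.
 */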
int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
                                 amd_iommu_invalid_ppr_cb cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ppr_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

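/*
 * amd_iommu_set_invalidate_ctx_cb - install a callback invoked from the
 * mmu_notifier release path when a bound mm goes away, so the driver can
 * stop using the PASID before it is unbound.
 */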
int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
                                    amd_iommu_invalidate_ctx cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ctx_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);

static int __init amd_iommu_v2_init(void)
{
        int ret;

        pr_info("AMD IOMMUv2 driver by Joerg Roedel <jroedel@suse.de>\n");

        if (!amd_iommu_v2_supported()) {
                pr_info("AMD IOMMUv2 functionality not available on this system\n");
                /*
                 * Load anyway to provide the symbols to other modules
                 * which may use AMD IOMMUv2 optionally.
                 */
                return 0;
        }

        ret = -ENOMEM;
        iommu_wq = alloc_workqueue("amd_iommu_v2", WQ_MEM_RECLAIM, 0);
        if (iommu_wq == NULL)
                goto out;

        amd_iommu_register_ppr_notifier(&ppr_nb);

        return 0;

out:
        return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
        struct device_state *dev_state;
        int i;

        if (!amd_iommu_v2_supported())
                return;

        amd_iommu_unregister_ppr_notifier(&ppr_nb);

        flush_workqueue(iommu_wq);

        /*
         * The loop below might call flush_workqueue(), so call
         * destroy_workqueue() after it
         */
        for (i = 0; i < MAX_DEVICES; ++i) {
                dev_state = get_device_state(i);

                if (dev_state == NULL)
                        continue;

                WARN_ON_ONCE(1);

                put_device_state(dev_state);
                amd_iommu_free_device(dev_state->pdev);
        }

        destroy_workqueue(iommu_wq);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);
