linux/drivers/iommu/amd_iommu_v2.c
/*
 * Copyright (C) 2010-2012 Advanced Micro Devices, Inc.
 * Author: Joerg Roedel <joerg.roedel@amd.com>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 as published
 * by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 */

#include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h>
#include <linux/mm_types.h>
#include <linux/profile.h>
#include <linux/module.h>
#include <linux/sched.h>
#include <linux/iommu.h>
#include <linux/wait.h>
#include <linux/pci.h>
#include <linux/gfp.h>

#include "amd_iommu_types.h"
#include "amd_iommu_proto.h"

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Joerg Roedel <joerg.roedel@amd.com>");

#define MAX_DEVICES             0x10000
#define PRI_QUEUE_SIZE          512

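/*
 * Per-tag bookkeeping for Page Request Interface (PRI) requests that are
 * still being handled: how many handlers are in flight for the tag, whether
 * the hardware asked for a completion message, and the response status to
 * send back once the last handler finishes.
 */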
struct pri_queue {
        atomic_t inflight;
        bool finish;
        int status;
};

struct pasid_state {
        struct list_head list;                  /* For global state-list */
        atomic_t count;                         /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
        struct mm_struct *mm;                   /* mm_struct for the faults */
        struct mmu_notifier mn;                 /* mmu_notifier handle */
        struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
        struct device_state *device_state;      /* Link to our device_state */
        int pasid;                              /* PASID index */
        bool invalid;                           /* Used during setup and
                                                   teardown of the pasid */
        spinlock_t lock;                        /* Protect pri_queues and
                                                   mmu_notifier_count */
        wait_queue_head_t wq;                   /* To wait for count == 0 */
};

struct device_state {
        struct list_head list;
        u16 devid;
        atomic_t count;
        struct pci_dev *pdev;
        struct pasid_state **states;
        struct iommu_domain *domain;
        int pasid_levels;
        int max_pasids;
        amd_iommu_invalid_ppr_cb inv_ppr_cb;
        amd_iommu_invalidate_ctx inv_ctx_cb;
        spinlock_t lock;
        wait_queue_head_t wq;
};

struct fault {
        struct work_struct work;
        struct device_state *dev_state;
        struct pasid_state *state;
        struct mm_struct *mm;
        u64 address;
        u16 devid;
        u16 pasid;
        u16 tag;
        u16 finish;
        u16 flags;
};

static LIST_HEAD(state_list);
static spinlock_t state_lock;

static struct workqueue_struct *iommu_wq;

/*
 * Empty page table - Used between
 * mmu_notifier_invalidate_range_start and
 * mmu_notifier_invalidate_range_end
 */
static u64 *empty_page_table;

static void free_pasid_states(struct device_state *dev_state);

static u16 device_id(struct pci_dev *pdev)
{
        u16 devid;

        devid = pdev->bus->number;
        devid = (devid << 8) | pdev->devfn;

        return devid;
}

static struct device_state *__get_device_state(u16 devid)
{
        struct device_state *dev_state;

        list_for_each_entry(dev_state, &state_list, list) {
                if (dev_state->devid == devid)
                        return dev_state;
        }

        return NULL;
}

static struct device_state *get_device_state(u16 devid)
{
        struct device_state *dev_state;
        unsigned long flags;

        spin_lock_irqsave(&state_lock, flags);
        dev_state = __get_device_state(devid);
        if (dev_state != NULL)
                atomic_inc(&dev_state->count);
        spin_unlock_irqrestore(&state_lock, flags);

        return dev_state;
}

static void free_device_state(struct device_state *dev_state)
{
        /*
         * First detach device from domain - No more PRI requests will arrive
         * from that device after it is unbound from the IOMMUv2 domain.
         */
        iommu_detach_device(dev_state->domain, &dev_state->pdev->dev);

        /* Everything is down now, free the IOMMUv2 domain */
        iommu_domain_free(dev_state->domain);

        /* Finally get rid of the device-state */
        kfree(dev_state);
}

static void put_device_state(struct device_state *dev_state)
{
        if (atomic_dec_and_test(&dev_state->count))
                wake_up(&dev_state->wq);
}

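/*
 * Drop the caller's reference and wait until all remaining references are
 * gone, then tear the device_state down.  Used on the device-free path so
 * the state cannot be released while a fault handler still uses it.
 */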
static void put_device_state_wait(struct device_state *dev_state)
{
        DEFINE_WAIT(wait);

        prepare_to_wait(&dev_state->wq, &wait, TASK_UNINTERRUPTIBLE);
        if (!atomic_dec_and_test(&dev_state->count))
                schedule();
        finish_wait(&dev_state->wq, &wait);

        free_device_state(dev_state);
}

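/*
 * PASID states live in a sparse, radix-tree-like table rooted at
 * dev_state->states: each level consumes 9 bits of the PASID and holds 512
 * pointers (one page).  Intermediate levels are allocated on demand when
 * 'alloc' is true; the walk below returns a pointer to the leaf slot.
 */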
/* Must be called under dev_state->lock */
static struct pasid_state **__get_pasid_state_ptr(struct device_state *dev_state,
                                                  int pasid, bool alloc)
{
        struct pasid_state **root, **ptr;
        int level, index;

        level = dev_state->pasid_levels;
        root  = dev_state->states;

        while (true) {

                index = (pasid >> (9 * level)) & 0x1ff;
                ptr   = &root[index];

                if (level == 0)
                        break;

                if (*ptr == NULL) {
                        if (!alloc)
                                return NULL;

                        *ptr = (void *)get_zeroed_page(GFP_ATOMIC);
                        if (*ptr == NULL)
                                return NULL;
                }

                root   = (struct pasid_state **)*ptr;
                level -= 1;
        }

        return ptr;
}

static int set_pasid_state(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           int pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;
        int ret;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        ret = -ENOMEM;
        if (ptr == NULL)
                goto out_unlock;

        ret = -ENOMEM;
        if (*ptr != NULL)
                goto out_unlock;

        *ptr = pasid_state;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void clear_pasid_state(struct device_state *dev_state, int pasid)
{
        struct pasid_state **ptr;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, true);

        if (ptr == NULL)
                goto out_unlock;

        *ptr = NULL;

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);
}

static struct pasid_state *get_pasid_state(struct device_state *dev_state,
                                           int pasid)
{
        struct pasid_state **ptr, *ret = NULL;
        unsigned long flags;

        spin_lock_irqsave(&dev_state->lock, flags);
        ptr = __get_pasid_state_ptr(dev_state, pasid, false);

        if (ptr == NULL)
                goto out_unlock;

        ret = *ptr;
        if (ret)
                atomic_inc(&ret->count);

out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);

        return ret;
}

static void free_pasid_state(struct pasid_state *pasid_state)
{
        kfree(pasid_state);
}

static void put_pasid_state(struct pasid_state *pasid_state)
{
        if (atomic_dec_and_test(&pasid_state->count)) {
                put_device_state(pasid_state->device_state);
                wake_up(&pasid_state->wq);
        }
}

static void put_pasid_state_wait(struct pasid_state *pasid_state)
{
        DEFINE_WAIT(wait);

        prepare_to_wait(&pasid_state->wq, &wait, TASK_UNINTERRUPTIBLE);

        if (atomic_dec_and_test(&pasid_state->count))
                put_device_state(pasid_state->device_state);
        else
                schedule();

        finish_wait(&pasid_state->wq, &wait);
        free_pasid_state(pasid_state);
}

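/*
 * Detach a PASID from its mm: mark the state invalid so the PPR notifier
 * stops queueing new faults, clear the GCR3 entry so the device can no
 * longer walk the mm's page tables, and drain the fault workqueue so that
 * no handler still references the state afterwards.
 */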
static void unbind_pasid(struct pasid_state *pasid_state)
{
        struct iommu_domain *domain;

        domain = pasid_state->device_state->domain;

        /*
         * Mark pasid_state as invalid; no more faults will be added to the
         * work queue after this is visible everywhere.
         */
        pasid_state->invalid = true;

        /* Make sure this is visible */
        smp_wmb();

        /* After this the device/pasid can't access the mm anymore */
        amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);

        /* Make sure no more pending faults are in the queue */
        flush_workqueue(iommu_wq);
}

static void free_pasid_states_level1(struct pasid_state **tbl)
{
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                free_page((unsigned long)tbl[i]);
        }
}

static void free_pasid_states_level2(struct pasid_state **tbl)
{
        struct pasid_state **ptr;
        int i;

        for (i = 0; i < 512; ++i) {
                if (tbl[i] == NULL)
                        continue;

                ptr = (struct pasid_state **)tbl[i];
                free_pasid_states_level1(ptr);
        }
}

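/*
 * Tear down every PASID that is still bound to this device (which also
 * unregisters the mmu_notifiers), then free the one- or two-level table
 * that held the pasid_state pointers.
 */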
static void free_pasid_states(struct device_state *dev_state)
{
        struct pasid_state *pasid_state;
        int i;

        for (i = 0; i < dev_state->max_pasids; ++i) {
                pasid_state = get_pasid_state(dev_state, i);
                if (pasid_state == NULL)
                        continue;

                put_pasid_state(pasid_state);

                /*
                 * This will call the mn_release function and
                 * unbind the PASID
                 */
                mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

                put_pasid_state_wait(pasid_state); /* Reference taken in
                                                      amd_iommu_bind_pasid */

                /* Drop reference taken in amd_iommu_bind_pasid */
                put_device_state(dev_state);
        }

        if (dev_state->pasid_levels == 2)
                free_pasid_states_level2(dev_state->states);
        else if (dev_state->pasid_levels == 1)
                free_pasid_states_level1(dev_state->states);
        else if (dev_state->pasid_levels != 0)
                BUG();

        free_page((unsigned long)dev_state->states);
}

static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
{
        return container_of(mn, struct pasid_state, mn);
}

static void __mn_flush_page(struct mmu_notifier *mn,
                            unsigned long address)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

        amd_iommu_flush_page(dev_state->domain, pasid_state->pasid, address);
}

static int mn_clear_flush_young(struct mmu_notifier *mn,
                                struct mm_struct *mm,
                                unsigned long address)
{
        __mn_flush_page(mn, address);

        return 0;
}

static void mn_invalidate_page(struct mmu_notifier *mn,
                               struct mm_struct *mm,
                               unsigned long address)
{
        __mn_flush_page(mn, address);
}

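/*
 * While a range invalidation is in progress the PASID must not use the
 * (partially stale) page tables.  range_start therefore points the GCR3
 * entry at an empty page table, so accesses from the device fault and are
 * retried via PRI instead of hitting stale translations; range_end restores
 * the mm's pgd.  mmu_notifier_count tracks nested start/end pairs so the
 * switch only happens on the outermost pair.
 */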
static void mn_invalidate_range_start(struct mmu_notifier *mn,
                                      struct mm_struct *mm,
                                      unsigned long start, unsigned long end)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        unsigned long flags;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

        spin_lock_irqsave(&pasid_state->lock, flags);
        if (pasid_state->mmu_notifier_count == 0) {
                amd_iommu_domain_set_gcr3(dev_state->domain,
                                          pasid_state->pasid,
                                          __pa(empty_page_table));
        }
        pasid_state->mmu_notifier_count += 1;
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void mn_invalidate_range_end(struct mmu_notifier *mn,
                                    struct mm_struct *mm,
                                    unsigned long start, unsigned long end)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        unsigned long flags;

        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;

        spin_lock_irqsave(&pasid_state->lock, flags);
        pasid_state->mmu_notifier_count -= 1;
        if (pasid_state->mmu_notifier_count == 0) {
                amd_iommu_domain_set_gcr3(dev_state->domain,
                                          pasid_state->pasid,
                                          __pa(pasid_state->mm->pgd));
        }
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        bool run_inv_ctx_cb;

        might_sleep();

        pasid_state    = mn_to_state(mn);
        dev_state      = pasid_state->device_state;
        run_inv_ctx_cb = !pasid_state->invalid;

        if (run_inv_ctx_cb && pasid_state->device_state->inv_ctx_cb)
                dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);

        unbind_pasid(pasid_state);
}

static struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .clear_flush_young      = mn_clear_flush_young,
        .invalidate_page        = mn_invalidate_page,
        .invalidate_range_start = mn_invalidate_range_start,
        .invalidate_range_end   = mn_invalidate_range_end,
};

static void set_pri_tag_status(struct pasid_state *pasid_state,
                               u16 tag, int status)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        pasid_state->pri[tag].status = status;
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

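/*
 * Called once per handler that was accounted in pri[tag].inflight.  The PPR
 * completion is sent to the device only when the last in-flight handler for
 * the tag finishes and the hardware asked for a response ('finish' bit set
 * in the original request).
 */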
static void finish_pri_tag(struct device_state *dev_state,
                           struct pasid_state *pasid_state,
                           u16 tag)
{
        unsigned long flags;

        spin_lock_irqsave(&pasid_state->lock, flags);
        if (atomic_dec_and_test(&pasid_state->pri[tag].inflight) &&
            pasid_state->pri[tag].finish) {
                amd_iommu_complete_ppr(dev_state->pdev, pasid_state->pasid,
                                       pasid_state->pri[tag].status, tag);
                pasid_state->pri[tag].finish = false;
                pasid_state->pri[tag].status = PPR_SUCCESS;
        }
        spin_unlock_irqrestore(&pasid_state->lock, flags);
}

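/*
 * Workqueue handler for a single peripheral page request: try to fault the
 * page in with get_user_pages().  If that fails, let the device driver's
 * inv_ppr_cb decide the response code, or report INVALID when no callback
 * is registered.  Then complete the PRI tag, drop the pasid_state reference
 * taken in ppr_notifier() and free the fault descriptor.
 */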
static void do_fault(struct work_struct *work)
{
        struct fault *fault = container_of(work, struct fault, work);
        int npages, write;
        struct page *page;

        write = !!(fault->flags & PPR_FAULT_WRITE);

        down_read(&fault->state->mm->mmap_sem);
        npages = get_user_pages(NULL, fault->state->mm,
                                fault->address, 1, write, 0, &page, NULL);
        up_read(&fault->state->mm->mmap_sem);

        if (npages == 1) {
                put_page(page);
        } else if (fault->dev_state->inv_ppr_cb) {
                int status;

                status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
                                                      fault->pasid,
                                                      fault->address,
                                                      fault->flags);
                switch (status) {
                case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
                        set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
                        break;
                case AMD_IOMMU_INV_PRI_RSP_INVALID:
                        set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
                        break;
                case AMD_IOMMU_INV_PRI_RSP_FAIL:
                        set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
                        break;
                default:
                        BUG();
                }
        } else {
                set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
        }

        finish_pri_tag(fault->dev_state, fault->state, fault->tag);

        put_pasid_state(fault->state);

        kfree(fault);
}

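/*
 * Upper half of PPR handling, invoked from the IOMMU driver's PPR-log
 * notifier chain in atomic context: look up the device and PASID state,
 * account the request against its PRI tag and hand the actual page-fault
 * work off to the amd_iommu_v2 workqueue.  Requests for unknown or invalid
 * PASIDs are answered with PPR_INVALID right away.
 */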
static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
{
        struct amd_iommu_fault *iommu_fault;
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        unsigned long flags;
        struct fault *fault;
        bool finish;
        u16 tag;
        int ret;

        iommu_fault = data;
        tag         = iommu_fault->tag & 0x1ff;
        finish      = (iommu_fault->tag >> 9) & 1;

        ret = NOTIFY_DONE;
        dev_state = get_device_state(iommu_fault->device_id);
        if (dev_state == NULL)
                goto out;

        pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
        if (pasid_state == NULL || pasid_state->invalid) {
                /* We know the device but not the PASID -> send INVALID */
                amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
                goto out_drop_state;
        }

        spin_lock_irqsave(&pasid_state->lock, flags);
        atomic_inc(&pasid_state->pri[tag].inflight);
        if (finish)
                pasid_state->pri[tag].finish = true;
        spin_unlock_irqrestore(&pasid_state->lock, flags);

        fault = kzalloc(sizeof(*fault), GFP_ATOMIC);
        if (fault == NULL) {
                /* We are OOM - send success and let the device re-fault */
                finish_pri_tag(dev_state, pasid_state, tag);
                goto out_drop_state;
        }

        fault->dev_state = dev_state;
        fault->address   = iommu_fault->address;
        fault->state     = pasid_state;
        fault->tag       = tag;
        fault->finish    = finish;
        fault->pasid     = iommu_fault->pasid;
        fault->flags     = iommu_fault->flags;
        INIT_WORK(&fault->work, do_fault);

        queue_work(iommu_wq, &fault->work);

        ret = NOTIFY_OK;

out_drop_state:

        if (ret != NOTIFY_OK && pasid_state)
                put_pasid_state(pasid_state);

        put_device_state(dev_state);

out:
        return ret;
}

static struct notifier_block ppr_nb = {
        .notifier_call = ppr_notifier,
};

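/*
 * Bind a PASID of the device to the address space of 'task': the task's mm
 * page-table root is programmed into the GCR3 table for this PASID and an
 * mmu_notifier keeps the IOMMU in sync with the mm until the mm goes away
 * or amd_iommu_unbind_pasid() is called.
 */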
int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
                         struct task_struct *task)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        struct mm_struct *mm;
        u16 devid;
        int ret;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid     = device_id(pdev);
        dev_state = get_device_state(devid);

        if (dev_state == NULL)
                return -EINVAL;

        ret = -EINVAL;
        if (pasid < 0 || pasid >= dev_state->max_pasids)
                goto out;

        ret = -ENOMEM;
        pasid_state = kzalloc(sizeof(*pasid_state), GFP_KERNEL);
        if (pasid_state == NULL)
                goto out;

        atomic_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);

        mm                        = get_task_mm(task);
        pasid_state->mm           = mm;
        pasid_state->device_state = dev_state;
        pasid_state->pasid        = pasid;
        pasid_state->invalid      = true; /* Mark as valid only if we are
                                             done with setting up the pasid */
        pasid_state->mn.ops       = &iommu_mn;

        if (pasid_state->mm == NULL)
                goto out_free;

        mmu_notifier_register(&pasid_state->mn, mm);

        ret = set_pasid_state(dev_state, pasid_state, pasid);
        if (ret)
                goto out_unregister;

        ret = amd_iommu_domain_set_gcr3(dev_state->domain, pasid,
                                        __pa(pasid_state->mm->pgd));
        if (ret)
                goto out_clear_state;

        /* Now we are ready to handle faults */
        pasid_state->invalid = false;

        /*
         * Drop the reference to the mm_struct here. We rely on the
         * mmu_notifier release call-back to inform us when the mm
         * is going away.
         */
        mmput(mm);

        return 0;

out_clear_state:
        clear_pasid_state(dev_state, pasid);

out_unregister:
        mmu_notifier_unregister(&pasid_state->mn, mm);

out_free:
        if (mm)         /* get_task_mm() may have returned NULL */
                mmput(mm);
        free_pasid_state(pasid_state);

out:
        put_device_state(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_bind_pasid);

void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
{
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);
        dev_state = get_device_state(devid);
        if (dev_state == NULL)
                return;

        if (pasid < 0 || pasid >= dev_state->max_pasids)
                goto out;

        pasid_state = get_pasid_state(dev_state, pasid);
        if (pasid_state == NULL)
                goto out;
        /*
         * Drop reference taken here. We are safe because we still hold
         * the reference taken in the amd_iommu_bind_pasid function.
         */
        put_pasid_state(pasid_state);

        /* Clear the pasid state so that the pasid can be re-used */
        clear_pasid_state(dev_state, pasid_state->pasid);

        /*
         * Call mmu_notifier_unregister to drop our reference
         * to pasid_state->mm
         */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);

        put_pasid_state_wait(pasid_state); /* Reference taken in
                                              amd_iommu_bind_pasid */
out:
        /* Drop reference taken in this function */
        put_device_state(dev_state);

        /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
}
EXPORT_SYMBOL(amd_iommu_unbind_pasid);

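/*
 * Per-device setup for IOMMUv2 use: allocate the device_state, size the
 * PASID table (pasid_levels), allocate a direct-mapped IOMMUv2 domain,
 * attach the device to it and publish the state on the global list.
 * Must be paired with amd_iommu_free_device().
 */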
int amd_iommu_init_device(struct pci_dev *pdev, int pasids)
{
        struct device_state *dev_state;
        unsigned long flags;
        int ret, tmp;
        u16 devid;

        might_sleep();

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        if (pasids <= 0 || pasids > (PASID_MASK + 1))
                return -EINVAL;

        devid = device_id(pdev);

        dev_state = kzalloc(sizeof(*dev_state), GFP_KERNEL);
        if (dev_state == NULL)
                return -ENOMEM;

        spin_lock_init(&dev_state->lock);
        init_waitqueue_head(&dev_state->wq);
        dev_state->pdev  = pdev;
        dev_state->devid = devid;

        tmp = pasids;
        for (dev_state->pasid_levels = 0; (tmp - 1) & ~0x1ff; tmp >>= 9)
                dev_state->pasid_levels += 1;

        atomic_set(&dev_state->count, 1);
        dev_state->max_pasids = pasids;

        ret = -ENOMEM;
        dev_state->states = (void *)get_zeroed_page(GFP_KERNEL);
        if (dev_state->states == NULL)
                goto out_free_dev_state;

        dev_state->domain = iommu_domain_alloc(&pci_bus_type);
        if (dev_state->domain == NULL)
                goto out_free_states;

        amd_iommu_domain_direct_map(dev_state->domain);

        ret = amd_iommu_domain_enable_v2(dev_state->domain, pasids);
        if (ret)
                goto out_free_domain;

        ret = iommu_attach_device(dev_state->domain, &pdev->dev);
        if (ret != 0)
                goto out_free_domain;

        spin_lock_irqsave(&state_lock, flags);

        if (__get_device_state(devid) != NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                ret = -EBUSY;
                goto out_free_domain;
        }

        list_add_tail(&dev_state->list, &state_list);

        spin_unlock_irqrestore(&state_lock, flags);

        return 0;

out_free_domain:
        iommu_domain_free(dev_state->domain);

out_free_states:
        free_page((unsigned long)dev_state->states);

out_free_dev_state:
        kfree(dev_state);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_init_device);

void amd_iommu_free_device(struct pci_dev *pdev)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;

        if (!amd_iommu_v2_supported())
                return;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        dev_state = __get_device_state(devid);
        if (dev_state == NULL) {
                spin_unlock_irqrestore(&state_lock, flags);
                return;
        }

        list_del(&dev_state->list);

        spin_unlock_irqrestore(&state_lock, flags);

        /* Get rid of any remaining pasid states */
        free_pasid_states(dev_state);

        put_device_state_wait(dev_state);
}
EXPORT_SYMBOL(amd_iommu_free_device);

int amd_iommu_set_invalid_ppr_cb(struct pci_dev *pdev,
                                 amd_iommu_invalid_ppr_cb cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ppr_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalid_ppr_cb);

int amd_iommu_set_invalidate_ctx_cb(struct pci_dev *pdev,
                                    amd_iommu_invalidate_ctx cb)
{
        struct device_state *dev_state;
        unsigned long flags;
        u16 devid;
        int ret;

        if (!amd_iommu_v2_supported())
                return -ENODEV;

        devid = device_id(pdev);

        spin_lock_irqsave(&state_lock, flags);

        ret = -EINVAL;
        dev_state = __get_device_state(devid);
        if (dev_state == NULL)
                goto out_unlock;

        dev_state->inv_ctx_cb = cb;

        ret = 0;

out_unlock:
        spin_unlock_irqrestore(&state_lock, flags);

        return ret;
}
EXPORT_SYMBOL(amd_iommu_set_invalidate_ctx_cb);
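
/*
 * Sketch of how a client driver would typically use the exported API
 * (a hypothetical example, not taken from an in-tree user; callback
 * registration is optional):
 *
 *	err = amd_iommu_init_device(pdev, 16);           // allow up to 16 PASIDs
 *	if (err)
 *		return err;
 *	amd_iommu_set_invalid_ppr_cb(pdev, my_ppr_cb);
 *	amd_iommu_set_invalidate_ctx_cb(pdev, my_ctx_cb);
 *
 *	err = amd_iommu_bind_pasid(pdev, pasid, current); // device now shares
 *	...                                               // current->mm
 *	amd_iommu_unbind_pasid(pdev, pasid);
 *	amd_iommu_free_device(pdev);
 */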

static int __init amd_iommu_v2_init(void)
{
        int ret;

        pr_info("AMD IOMMUv2 driver by Joerg Roedel <joerg.roedel@amd.com>\n");

        if (!amd_iommu_v2_supported()) {
                pr_info("AMD IOMMUv2 functionality not available on this system\n");
                /*
                 * Load anyway to provide the symbols to other modules
                 * which may use AMD IOMMUv2 optionally.
                 */
                return 0;
        }

        spin_lock_init(&state_lock);

        ret = -ENOMEM;
        iommu_wq = create_workqueue("amd_iommu_v2");
        if (iommu_wq == NULL)
                goto out;

        ret = -ENOMEM;
        empty_page_table = (u64 *)get_zeroed_page(GFP_KERNEL);
        if (empty_page_table == NULL)
                goto out_destroy_wq;

        amd_iommu_register_ppr_notifier(&ppr_nb);

        return 0;

out_destroy_wq:
        destroy_workqueue(iommu_wq);

out:
        return ret;
}

static void __exit amd_iommu_v2_exit(void)
{
        struct device_state *dev_state;
        int i;

        if (!amd_iommu_v2_supported())
                return;

        amd_iommu_unregister_ppr_notifier(&ppr_nb);

        flush_workqueue(iommu_wq);

        /*
         * The loop below might call flush_workqueue(), so call
         * destroy_workqueue() after it
         */
        for (i = 0; i < MAX_DEVICES; ++i) {
                dev_state = get_device_state(i);

                if (dev_state == NULL)
                        continue;

                WARN_ON_ONCE(1);

                put_device_state(dev_state);
                amd_iommu_free_device(dev_state->pdev);
        }

        destroy_workqueue(iommu_wq);

        free_page((unsigned long)empty_page_table);
}

module_init(amd_iommu_v2_init);
module_exit(amd_iommu_v2_exit);