linux/lib/test_hmm.c
// SPDX-License-Identifier: GPL-2.0
/*
 * This is a module to test the HMM (Heterogeneous Memory Management)
 * mirror and zone device private memory migration APIs of the kernel.
 * Userspace programs can register with the driver to mirror their own address
 * space and can use the device to read/write any valid virtual address.
 */
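
/*
 * A minimal userspace sketch of how this driver can be exercised,
 * assuming a character device node such as /dev/hmm_dmirror0 has been
 * created for the first instance (node creation is done by the
 * selftest scripts, not by this module):
 *
 *	struct hmm_dmirror_cmd cmd = {
 *		.addr	= (__u64)(uintptr_t)buf,  // page-aligned buffer in our mm
 *		.ptr	= (__u64)(uintptr_t)out,  // where the driver copies the data
 *		.npages	= 1,
 *	};
 *	int fd = open("/dev/hmm_dmirror0", O_RDWR);
 *
 *	ioctl(fd, HMM_DMIRROR_READ, &cmd);	// read buf through the "device"
 *
 * HMM_DMIRROR_* and struct hmm_dmirror_cmd come from test_hmm_uapi.h.
 */
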
#include <linux/init.h>
#include <linux/fs.h>
#include <linux/mm.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/cdev.h>
#include <linux/device.h>
#include <linux/mutex.h>
#include <linux/rwsem.h>
#include <linux/sched.h>
#include <linux/slab.h>
#include <linux/highmem.h>
#include <linux/delay.h>
#include <linux/pagemap.h>
#include <linux/hmm.h>
#include <linux/vmalloc.h>
#include <linux/swap.h>
#include <linux/swapops.h>
#include <linux/sched/mm.h>
#include <linux/platform_device.h>
#include <linux/rmap.h>

#include "test_hmm_uapi.h"

#define DMIRROR_NDEVICES		2
#define DMIRROR_RANGE_FAULT_TIMEOUT	1000
#define DEVMEM_CHUNK_SIZE		(256 * 1024 * 1024U)
#define DEVMEM_CHUNKS_RESERVE		16

static const struct dev_pagemap_ops dmirror_devmem_ops;
static const struct mmu_interval_notifier_ops dmirror_min_ops;
static dev_t dmirror_dev;

struct dmirror_device;

struct dmirror_bounce {
	void			*ptr;
	unsigned long		size;
	unsigned long		addr;
	unsigned long		cpages;
};

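/*
 * XArray pointer tags occupy the two low bits of an aligned pointer,
 * so only the values 0-3 are usable. An untagged entry (tag 0) is a
 * read-only mirrored page, DPT_XA_TAG_WRITE marks a writable one and
 * DPT_XA_TAG_ATOMIC marks a page mapped for exclusive device access.
 */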
#define DPT_XA_TAG_ATOMIC 1UL
#define DPT_XA_TAG_WRITE 3UL

/*
 * Data structure to track address ranges and register for mmu interval
 * notifier updates.
 */
struct dmirror_interval {
	struct mmu_interval_notifier	notifier;
	struct dmirror			*dmirror;
};

/*
 * Data attached to the open device file.
 * Note that it might be shared after a fork().
 */
struct dmirror {
	struct dmirror_device		*mdevice;
	struct xarray			pt;
	struct mmu_interval_notifier	notifier;
	struct mutex			mutex;
};
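
/*
 * The pt XArray above is the simulated device page table: it is
 * indexed by the mirrored process's virtual page number
 * (address >> PAGE_SHIFT) and holds tagged struct page pointers
 * rather than real device PTEs.
 */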

/*
 * ZONE_DEVICE pages for migration and simulating device memory.
 */
struct dmirror_chunk {
	struct dev_pagemap	pagemap;
	struct dmirror_device	*mdevice;
};

/*
 * Per device data.
 */
struct dmirror_device {
	struct cdev		cdevice;
	struct hmm_devmem	*devmem;

	unsigned int		devmem_capacity;
	unsigned int		devmem_count;
	struct dmirror_chunk	**devmem_chunks;
	struct mutex		devmem_lock;	/* protects the above */

	unsigned long		calloc;
	unsigned long		cfree;
	struct page		*free_pages;
	spinlock_t		lock;		/* protects the above */
};

static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];

static int dmirror_bounce_init(struct dmirror_bounce *bounce,
			       unsigned long addr,
			       unsigned long size)
{
	bounce->addr = addr;
	bounce->size = size;
	bounce->cpages = 0;
	bounce->ptr = vmalloc(size);
	if (!bounce->ptr)
		return -ENOMEM;
	return 0;
}

static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
{
	vfree(bounce->ptr);
}

static int dmirror_fops_open(struct inode *inode, struct file *filp)
{
	struct cdev *cdev = inode->i_cdev;
	struct dmirror *dmirror;
	int ret;

	/* Mirror this process address space */
	dmirror = kzalloc(sizeof(*dmirror), GFP_KERNEL);
	if (dmirror == NULL)
		return -ENOMEM;

	dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
	mutex_init(&dmirror->mutex);
	xa_init(&dmirror->pt);

	ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
				0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
	if (ret) {
		kfree(dmirror);
		return ret;
	}

	filp->private_data = dmirror;
	return 0;
}

static int dmirror_fops_release(struct inode *inode, struct file *filp)
{
	struct dmirror *dmirror = filp->private_data;

	mmu_interval_notifier_remove(&dmirror->notifier);
	xa_destroy(&dmirror->pt);
	kfree(dmirror);
	return 0;
}

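/*
 * Every ZONE_DEVICE page carries a pointer to the dev_pagemap it was
 * created from; since each pagemap here is embedded in a dmirror_chunk,
 * the owning device falls out of a single container_of().
 */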
static struct dmirror_device *dmirror_page_to_device(struct page *page)
{
	return container_of(page->pgmap, struct dmirror_chunk,
			    pagemap)->mdevice;
}

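/*
 * Enter the pages found by hmm_range_fault() into the device page
 * table, tagging entries the device may write through. Called with
 * dmirror->mutex held and the notifier sequence already validated.
 */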
static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
{
	unsigned long *pfns = range->hmm_pfns;
	unsigned long pfn;

	for (pfn = (range->start >> PAGE_SHIFT);
	     pfn < (range->end >> PAGE_SHIFT);
	     pfn++, pfns++) {
		struct page *page;
		void *entry;

		/*
		 * Since we asked for hmm_range_fault() to populate pages,
		 * it shouldn't return an error entry on success.
		 */
		WARN_ON(*pfns & HMM_PFN_ERROR);
		WARN_ON(!(*pfns & HMM_PFN_VALID));

		page = hmm_pfn_to_page(*pfns);
		WARN_ON(!page);

		entry = page;
		if (*pfns & HMM_PFN_WRITE)
			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
		else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
			return -EFAULT;
		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
		if (xa_is_err(entry))
			return xa_err(entry);
	}

	return 0;
}

static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
			      unsigned long end)
{
	unsigned long pfn;
	void *entry;

	/*
	 * The XArray doesn't hold references to pages since it relies on
	 * the mmu notifier to clear page pointers when they become stale.
	 * Therefore, it is OK to just clear the entry.
	 */
	xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
			  end >> PAGE_SHIFT)
		xa_erase(&dmirror->pt, pfn);
}

static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
				const struct mmu_notifier_range *range,
				unsigned long cur_seq)
{
	struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);

	/*
	 * Ignore invalidation callbacks for device private pages since
	 * the invalidation is handled as part of the migration process.
	 */
	if (range->event == MMU_NOTIFY_MIGRATE &&
	    range->owner == dmirror->mdevice)
		return true;

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&dmirror->mutex);
	else if (!mutex_trylock(&dmirror->mutex))
		return false;

	mmu_interval_set_seq(mni, cur_seq);
	dmirror_do_update(dmirror, range->start, range->end);

	mutex_unlock(&dmirror->mutex);
	return true;
}

static const struct mmu_interval_notifier_ops dmirror_min_ops = {
	.invalidate = dmirror_interval_invalidate,
};

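/*
 * This is the canonical HMM fault loop: sample the interval notifier
 * sequence, fault the range under the mmap lock, then recheck the
 * sequence under the driver lock. If an invalidation raced with the
 * fault, the result is stale and the whole loop is retried.
 */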
static int dmirror_range_fault(struct dmirror *dmirror,
				struct hmm_range *range)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	unsigned long timeout =
		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
	int ret;

	while (true) {
		if (time_after(jiffies, timeout)) {
			ret = -EBUSY;
			goto out;
		}

		range->notifier_seq = mmu_interval_read_begin(range->notifier);
		mmap_read_lock(mm);
		ret = hmm_range_fault(range);
		mmap_read_unlock(mm);
		if (ret) {
			if (ret == -EBUSY)
				continue;
			goto out;
		}

		mutex_lock(&dmirror->mutex);
		if (mmu_interval_read_retry(range->notifier,
					    range->notifier_seq)) {
			mutex_unlock(&dmirror->mutex);
			continue;
		}
		break;
	}

	ret = dmirror_do_fault(dmirror, range);

	mutex_unlock(&dmirror->mutex);
out:
	return ret;
}

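/* Fault in and mirror a range, at most ARRAY_SIZE(pfns) pages per pass. */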
static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
			 unsigned long end, bool write)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	unsigned long addr;
	unsigned long pfns[64];
	struct hmm_range range = {
		.notifier = &dmirror->notifier,
		.hmm_pfns = pfns,
		.pfn_flags_mask = 0,
		.default_flags =
			HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
		.dev_private_owner = dmirror->mdevice,
	};
	int ret = 0;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return 0;

	for (addr = start; addr < end; addr = range.end) {
		range.start = addr;
		range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);

		ret = dmirror_range_fault(dmirror, &range);
		if (ret)
			break;
	}

	mmput(mm);
	return ret;
}

static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
			   unsigned long end, struct dmirror_bounce *bounce)
{
	unsigned long pfn;
	void *ptr;

	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
		void *entry;
		struct page *page;
		void *tmp;

		entry = xa_load(&dmirror->pt, pfn);
		page = xa_untag_pointer(entry);
		if (!page)
			return -ENOENT;

		tmp = kmap(page);
		memcpy(ptr, tmp, PAGE_SIZE);
		kunmap(page);

		ptr += PAGE_SIZE;
		bounce->cpages++;
	}

	return 0;
}

static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
{
	struct dmirror_bounce bounce;
	unsigned long start, end;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	int ret;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;

	while (1) {
		mutex_lock(&dmirror->mutex);
		ret = dmirror_do_read(dmirror, start, end, &bounce);
		mutex_unlock(&dmirror->mutex);
		if (ret != -ENOENT)
			break;

		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
		ret = dmirror_fault(dmirror, start, end, false);
		if (ret)
			break;
		cmd->faults++;
	}

	if (ret == 0) {
		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
				 bounce.size))
			ret = -EFAULT;
	}
	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;
}

static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
			    unsigned long end, struct dmirror_bounce *bounce)
{
	unsigned long pfn;
	void *ptr;

	ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
		void *entry;
		struct page *page;
		void *tmp;

		entry = xa_load(&dmirror->pt, pfn);
		page = xa_untag_pointer(entry);
		if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
			return -ENOENT;

		tmp = kmap(page);
		memcpy(tmp, ptr, PAGE_SIZE);
		kunmap(page);

		ptr += PAGE_SIZE;
		bounce->cpages++;
	}

	return 0;
}

static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
{
	struct dmirror_bounce bounce;
	unsigned long start, end;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	int ret;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;
	if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
			   bounce.size)) {
		ret = -EFAULT;
		goto fini;
	}

	while (1) {
		mutex_lock(&dmirror->mutex);
		ret = dmirror_do_write(dmirror, start, end, &bounce);
		mutex_unlock(&dmirror->mutex);
		if (ret != -ENOENT)
			break;

		start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
		ret = dmirror_fault(dmirror, start, end, true);
		if (ret)
			break;
		cmd->faults++;
	}

fini:
	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;
}

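/*
 * Carve a free slice out of the physical address space and turn it
 * into MEMORY_DEVICE_PRIVATE struct pages with memremap_pages().
 * These pages are never mapped directly; they merely stand in for
 * device memory, with the actual data living in ordinary pages hung
 * off zone_device_data.
 */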
static bool dmirror_allocate_chunk(struct dmirror_device *mdevice,
				   struct page **ppage)
{
	struct dmirror_chunk *devmem;
	struct resource *res;
	unsigned long pfn;
	unsigned long pfn_first;
	unsigned long pfn_last;
	void *ptr;

	devmem = kzalloc(sizeof(*devmem), GFP_KERNEL);
	if (!devmem)
		return false;

	res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
				      "hmm_dmirror");
	if (IS_ERR(res))
		goto err_devmem;

	devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
	devmem->pagemap.range.start = res->start;
	devmem->pagemap.range.end = res->end;
	devmem->pagemap.nr_range = 1;
	devmem->pagemap.ops = &dmirror_devmem_ops;
	devmem->pagemap.owner = mdevice;

	mutex_lock(&mdevice->devmem_lock);

	if (mdevice->devmem_count == mdevice->devmem_capacity) {
		struct dmirror_chunk **new_chunks;
		unsigned int new_capacity;

		new_capacity = mdevice->devmem_capacity +
				DEVMEM_CHUNKS_RESERVE;
		new_chunks = krealloc(mdevice->devmem_chunks,
				sizeof(new_chunks[0]) * new_capacity,
				GFP_KERNEL);
		if (!new_chunks)
			goto err_release;
		mdevice->devmem_capacity = new_capacity;
		mdevice->devmem_chunks = new_chunks;
	}

	ptr = memremap_pages(&devmem->pagemap, numa_node_id());
	if (IS_ERR(ptr))
		goto err_release;

	devmem->mdevice = mdevice;
	pfn_first = devmem->pagemap.range.start >> PAGE_SHIFT;
	pfn_last = pfn_first + (range_len(&devmem->pagemap.range) >> PAGE_SHIFT);
	mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;

	mutex_unlock(&mdevice->devmem_lock);

	pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
		DEVMEM_CHUNK_SIZE / (1024 * 1024),
		mdevice->devmem_count,
		mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
		pfn_first, pfn_last);

	spin_lock(&mdevice->lock);
	for (pfn = pfn_first; pfn < pfn_last; pfn++) {
		struct page *page = pfn_to_page(pfn);

		page->zone_device_data = mdevice->free_pages;
		mdevice->free_pages = page;
	}
	if (ppage) {
		*ppage = mdevice->free_pages;
		mdevice->free_pages = (*ppage)->zone_device_data;
		mdevice->calloc++;
	}
	spin_unlock(&mdevice->lock);

	return true;

err_release:
	mutex_unlock(&mdevice->devmem_lock);
	release_mem_region(devmem->pagemap.range.start, range_len(&devmem->pagemap.range));
err_devmem:
	kfree(devmem);

	return false;
}

static struct page *dmirror_devmem_alloc_page(struct dmirror_device *mdevice)
{
	struct page *dpage = NULL;
	struct page *rpage;

	/*
	 * This is a fake device so we alloc real system memory to store
	 * our device memory.
	 */
	rpage = alloc_page(GFP_HIGHUSER);
	if (!rpage)
		return NULL;

	spin_lock(&mdevice->lock);

	if (mdevice->free_pages) {
		dpage = mdevice->free_pages;
		mdevice->free_pages = dpage->zone_device_data;
		mdevice->calloc++;
		spin_unlock(&mdevice->lock);
	} else {
		spin_unlock(&mdevice->lock);
		if (!dmirror_allocate_chunk(mdevice, &dpage))
			goto error;
	}

	dpage->zone_device_data = rpage;
	get_page(dpage);
	lock_page(dpage);
	return dpage;

error:
	__free_page(rpage);
	return NULL;
}

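/*
 * The alloc-and-copy step of a migrate_vma_setup() / migrate_vma_pages() /
 * migrate_vma_finalize() sequence: for each source page selected by
 * migrate_vma_setup(), allocate a device page and copy the data into
 * its backing page.
 */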
static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
					   struct dmirror *dmirror)
{
	struct dmirror_device *mdevice = dmirror->mdevice;
	const unsigned long *src = args->src;
	unsigned long *dst = args->dst;
	unsigned long addr;

	for (addr = args->start; addr < args->end; addr += PAGE_SIZE,
						   src++, dst++) {
		struct page *spage;
		struct page *dpage;
		struct page *rpage;

		if (!(*src & MIGRATE_PFN_MIGRATE))
			continue;

		/*
		 * Note that spage might be NULL which is OK since it is an
		 * unallocated pte_none() or read-only zero page.
		 */
		spage = migrate_pfn_to_page(*src);

		dpage = dmirror_devmem_alloc_page(mdevice);
		if (!dpage)
			continue;

		rpage = dpage->zone_device_data;
		if (spage)
			copy_highpage(rpage, spage);
		else
			clear_highpage(rpage);

		/*
		 * Normally, a device would use the page->zone_device_data to
		 * point to the mirror but here we use it to hold the page for
		 * the simulated device memory and that page holds the pointer
		 * to the mirror.
		 */
		rpage->zone_device_data = dmirror;

		*dst = migrate_pfn(page_to_pfn(dpage)) |
			    MIGRATE_PFN_LOCKED;
		if ((*src & MIGRATE_PFN_WRITE) ||
		    (!spage && args->vma->vm_flags & VM_WRITE))
			*dst |= MIGRATE_PFN_WRITE;
	}
}

static int dmirror_check_atomic(struct dmirror *dmirror, unsigned long start,
			     unsigned long end)
{
	unsigned long pfn;

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
		void *entry;

		entry = xa_load(&dmirror->pt, pfn);
		if (xa_pointer_tag(entry) == DPT_XA_TAG_ATOMIC)
			return -EPERM;
	}

	return 0;
}

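/*
 * Record pages made device-exclusive by make_device_exclusive_range()
 * in the device page table. Returns the number of pages mapped.
 */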
static int dmirror_atomic_map(unsigned long start, unsigned long end,
			      struct page **pages, struct dmirror *dmirror)
{
	unsigned long pfn, mapped = 0;
	int i;

	/* Map the migrated pages into the device's page tables. */
	mutex_lock(&dmirror->mutex);

	for (i = 0, pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++, i++) {
		void *entry;

		if (!pages[i])
			continue;

		entry = pages[i];
		entry = xa_tag_pointer(entry, DPT_XA_TAG_ATOMIC);
		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
		if (xa_is_err(entry)) {
			mutex_unlock(&dmirror->mutex);
			return xa_err(entry);
		}

		mapped++;
	}

	mutex_unlock(&dmirror->mutex);
	return mapped;
}

static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
					    struct dmirror *dmirror)
{
	unsigned long start = args->start;
	unsigned long end = args->end;
	const unsigned long *src = args->src;
	const unsigned long *dst = args->dst;
	unsigned long pfn;

	/* Map the migrated pages into the device's page tables. */
	mutex_lock(&dmirror->mutex);

	for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++,
								src++, dst++) {
		struct page *dpage;
		void *entry;

		if (!(*src & MIGRATE_PFN_MIGRATE))
			continue;

		dpage = migrate_pfn_to_page(*dst);
		if (!dpage)
			continue;

		/*
		 * Store the page that holds the data so the page table
		 * doesn't have to deal with ZONE_DEVICE private pages.
		 */
		entry = dpage->zone_device_data;
		if (*dst & MIGRATE_PFN_WRITE)
			entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
		entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
		if (xa_is_err(entry)) {
			mutex_unlock(&dmirror->mutex);
			return xa_err(entry);
		}
	}

	mutex_unlock(&dmirror->mutex);
	return 0;
}

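/*
 * Mark a range for exclusive (atomic) device access and mirror it with
 * DPT_XA_TAG_ATOMIC entries. A later CPU touch of the range faults,
 * fires the interval notifier and so erases the atomic entries, which
 * is what HMM_DMIRROR_CHECK_EXCLUSIVE observes via dmirror_check_atomic().
 */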
static int dmirror_exclusive(struct dmirror *dmirror,
			     struct hmm_dmirror_cmd *cmd)
{
	unsigned long start, end, addr;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	struct mm_struct *mm = dmirror->notifier.mm;
	struct page *pages[64];
	struct dmirror_bounce bounce;
	unsigned long next;
	int ret;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	mmap_read_lock(mm);
	for (addr = start; addr < end; addr = next) {
		unsigned long mapped;
		int i;

		if (end < addr + (ARRAY_SIZE(pages) << PAGE_SHIFT))
			next = end;
		else
			next = addr + (ARRAY_SIZE(pages) << PAGE_SHIFT);

		ret = make_device_exclusive_range(mm, addr, next, pages, NULL);
		mapped = dmirror_atomic_map(addr, next, pages, dmirror);
		for (i = 0; i < ret; i++) {
			if (pages[i]) {
				unlock_page(pages[i]);
				put_page(pages[i]);
			}
		}

		if (addr + (mapped << PAGE_SHIFT) < next) {
			mmap_read_unlock(mm);
			mmput(mm);
			return -EBUSY;
		}
	}
	mmap_read_unlock(mm);
	mmput(mm);

	/* Return the migrated data for verification. */
	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;
	mutex_lock(&dmirror->mutex);
	ret = dmirror_do_read(dmirror, start, end, &bounce);
	mutex_unlock(&dmirror->mutex);
	if (ret == 0) {
		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
				 bounce.size))
			ret = -EFAULT;
	}

	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;
}

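/*
 * Migrate a range of the mirrored process to simulated device memory,
 * one VMA chunk of up to 64 pages at a time, then read the migrated
 * data back through the device page table for verification.
 */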
static int dmirror_migrate(struct dmirror *dmirror,
			   struct hmm_dmirror_cmd *cmd)
{
	unsigned long start, end, addr;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	struct mm_struct *mm = dmirror->notifier.mm;
	struct vm_area_struct *vma;
	unsigned long src_pfns[64];
	unsigned long dst_pfns[64];
	struct dmirror_bounce bounce;
	struct migrate_vma args;
	unsigned long next;
	int ret;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	mmap_read_lock(mm);
	for (addr = start; addr < end; addr = next) {
		vma = vma_lookup(mm, addr);
		if (!vma || !(vma->vm_flags & VM_READ)) {
			ret = -EINVAL;
			goto out;
		}
		next = min(end, addr + (ARRAY_SIZE(src_pfns) << PAGE_SHIFT));
		if (next > vma->vm_end)
			next = vma->vm_end;

		args.vma = vma;
		args.src = src_pfns;
		args.dst = dst_pfns;
		args.start = addr;
		args.end = next;
		args.pgmap_owner = dmirror->mdevice;
		args.flags = MIGRATE_VMA_SELECT_SYSTEM;
		ret = migrate_vma_setup(&args);
		if (ret)
			goto out;

		dmirror_migrate_alloc_and_copy(&args, dmirror);
		migrate_vma_pages(&args);
		dmirror_migrate_finalize_and_map(&args, dmirror);
		migrate_vma_finalize(&args);
	}
	mmap_read_unlock(mm);
	mmput(mm);

	/* Return the migrated data for verification. */
	ret = dmirror_bounce_init(&bounce, start, size);
	if (ret)
		return ret;
	mutex_lock(&dmirror->mutex);
	ret = dmirror_do_read(dmirror, start, end, &bounce);
	mutex_unlock(&dmirror->mutex);
	if (ret == 0) {
		if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
				 bounce.size))
			ret = -EFAULT;
	}
	cmd->cpages = bounce.cpages;
	dmirror_bounce_fini(&bounce);
	return ret;

out:
	mmap_read_unlock(mm);
	mmput(mm);
	return ret;
}

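/*
 * Translate one hmm_range_fault() result into the HMM_DMIRROR_PROT_*
 * byte that HMM_DMIRROR_SNAPSHOT reports to userspace.
 */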
static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
			    unsigned char *perm, unsigned long entry)
{
	struct page *page;

	if (entry & HMM_PFN_ERROR) {
		*perm = HMM_DMIRROR_PROT_ERROR;
		return;
	}
	if (!(entry & HMM_PFN_VALID)) {
		*perm = HMM_DMIRROR_PROT_NONE;
		return;
	}

	page = hmm_pfn_to_page(entry);
	if (is_device_private_page(page)) {
		/* Is the page migrated to this device or some other? */
		if (dmirror->mdevice == dmirror_page_to_device(page))
			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
		else
			*perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
	} else if (is_zero_pfn(page_to_pfn(page)))
		*perm = HMM_DMIRROR_PROT_ZERO;
	else
		*perm = HMM_DMIRROR_PROT_NONE;
	if (entry & HMM_PFN_WRITE)
		*perm |= HMM_DMIRROR_PROT_WRITE;
	else
		*perm |= HMM_DMIRROR_PROT_READ;
	if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PMD_SHIFT)
		*perm |= HMM_DMIRROR_PROT_PMD;
	else if (hmm_pfn_to_map_order(entry) + PAGE_SHIFT == PUD_SHIFT)
		*perm |= HMM_DMIRROR_PROT_PUD;
}

static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
				const struct mmu_notifier_range *range,
				unsigned long cur_seq)
{
	struct dmirror_interval *dmi =
		container_of(mni, struct dmirror_interval, notifier);
	struct dmirror *dmirror = dmi->dmirror;

	if (mmu_notifier_range_blockable(range))
		mutex_lock(&dmirror->mutex);
	else if (!mutex_trylock(&dmirror->mutex))
		return false;

	/*
	 * Snapshots only need to set the sequence number since any
	 * invalidation in the interval invalidates the whole snapshot.
	 */
	mmu_interval_set_seq(mni, cur_seq);

	mutex_unlock(&dmirror->mutex);
	return true;
}

static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
	.invalidate = dmirror_snapshot_invalidate,
};

static int dmirror_range_snapshot(struct dmirror *dmirror,
				  struct hmm_range *range,
				  unsigned char *perm)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	struct dmirror_interval notifier;
	unsigned long timeout =
		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
	unsigned long i;
	unsigned long n;
	int ret = 0;

	notifier.dmirror = dmirror;
	range->notifier = &notifier.notifier;

	ret = mmu_interval_notifier_insert(range->notifier, mm,
			range->start, range->end - range->start,
			&dmirror_mrn_ops);
	if (ret)
		return ret;

	while (true) {
		if (time_after(jiffies, timeout)) {
			ret = -EBUSY;
			goto out;
		}

		range->notifier_seq = mmu_interval_read_begin(range->notifier);

		mmap_read_lock(mm);
		ret = hmm_range_fault(range);
		mmap_read_unlock(mm);
		if (ret) {
			if (ret == -EBUSY)
				continue;
			goto out;
		}

		mutex_lock(&dmirror->mutex);
		if (mmu_interval_read_retry(range->notifier,
					    range->notifier_seq)) {
			mutex_unlock(&dmirror->mutex);
			continue;
		}
		break;
	}

	n = (range->end - range->start) >> PAGE_SHIFT;
	for (i = 0; i < n; i++)
		dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);

	mutex_unlock(&dmirror->mutex);
out:
	mmu_interval_notifier_remove(range->notifier);
	return ret;
}

static int dmirror_snapshot(struct dmirror *dmirror,
			    struct hmm_dmirror_cmd *cmd)
{
	struct mm_struct *mm = dmirror->notifier.mm;
	unsigned long start, end;
	unsigned long size = cmd->npages << PAGE_SHIFT;
	unsigned long addr;
	unsigned long next;
	unsigned long pfns[64];
	unsigned char perm[64];
	char __user *uptr;
	struct hmm_range range = {
		.hmm_pfns = pfns,
		.dev_private_owner = dmirror->mdevice,
	};
	int ret = 0;

	start = cmd->addr;
	end = start + size;
	if (end < start)
		return -EINVAL;

	/* Since the mm is for the mirrored process, get a reference first. */
	if (!mmget_not_zero(mm))
		return -EINVAL;

	/*
	 * Register a temporary notifier to detect invalidations even if it
	 * overlaps with other mmu_interval_notifiers.
	 */
	uptr = u64_to_user_ptr(cmd->ptr);
	for (addr = start; addr < end; addr = next) {
		unsigned long n;

		next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
		range.start = addr;
		range.end = next;

		ret = dmirror_range_snapshot(dmirror, &range, perm);
		if (ret)
			break;

		n = (range.end - range.start) >> PAGE_SHIFT;
		if (copy_to_user(uptr, perm, n)) {
			ret = -EFAULT;
			break;
		}

		cmd->cpages += n;
		uptr += n;
	}
	mmput(mm);

	return ret;
}

static long dmirror_fops_unlocked_ioctl(struct file *filp,
					unsigned int command,
					unsigned long arg)
{
	void __user *uarg = (void __user *)arg;
	struct hmm_dmirror_cmd cmd;
	struct dmirror *dmirror;
	int ret;

	dmirror = filp->private_data;
	if (!dmirror)
		return -EINVAL;

	if (copy_from_user(&cmd, uarg, sizeof(cmd)))
		return -EFAULT;

	if (cmd.addr & ~PAGE_MASK)
		return -EINVAL;
	if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
		return -EINVAL;

	cmd.cpages = 0;
	cmd.faults = 0;

	switch (command) {
	case HMM_DMIRROR_READ:
		ret = dmirror_read(dmirror, &cmd);
		break;

	case HMM_DMIRROR_WRITE:
		ret = dmirror_write(dmirror, &cmd);
		break;

	case HMM_DMIRROR_MIGRATE:
		ret = dmirror_migrate(dmirror, &cmd);
		break;

	case HMM_DMIRROR_EXCLUSIVE:
		ret = dmirror_exclusive(dmirror, &cmd);
		break;

	case HMM_DMIRROR_CHECK_EXCLUSIVE:
		ret = dmirror_check_atomic(dmirror, cmd.addr,
					cmd.addr + (cmd.npages << PAGE_SHIFT));
		break;

	case HMM_DMIRROR_SNAPSHOT:
		ret = dmirror_snapshot(dmirror, &cmd);
		break;

	default:
		return -EINVAL;
	}
	if (ret)
		return ret;

	if (copy_to_user(uarg, &cmd, sizeof(cmd)))
		return -EFAULT;

	return 0;
}

static const struct file_operations dmirror_fops = {
	.open		= dmirror_fops_open,
	.release	= dmirror_fops_release,
	.unlocked_ioctl = dmirror_fops_unlocked_ioctl,
	.llseek		= default_llseek,
	.owner		= THIS_MODULE,
};

static void dmirror_devmem_free(struct page *page)
{
	struct page *rpage = page->zone_device_data;
	struct dmirror_device *mdevice;

	if (rpage)
		__free_page(rpage);

	mdevice = dmirror_page_to_device(page);

	spin_lock(&mdevice->lock);
	mdevice->cfree++;
	page->zone_device_data = mdevice->free_pages;
	mdevice->free_pages = page;
	spin_unlock(&mdevice->lock);
}

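/*
 * The inverse of dmirror_migrate_alloc_and_copy(): on a CPU fault
 * against device-private memory, allocate ordinary pages and copy the
 * data back out of the simulated device memory.
 */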
static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
						      struct dmirror *dmirror)
{
	const unsigned long *src = args->src;
	unsigned long *dst = args->dst;
	unsigned long start = args->start;
	unsigned long end = args->end;
	unsigned long addr;

	for (addr = start; addr < end; addr += PAGE_SIZE,
				       src++, dst++) {
		struct page *dpage, *spage;

		spage = migrate_pfn_to_page(*src);
		if (!spage || !(*src & MIGRATE_PFN_MIGRATE))
			continue;
		spage = spage->zone_device_data;

		dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
		if (!dpage)
			continue;

		lock_page(dpage);
		xa_erase(&dmirror->pt, addr >> PAGE_SHIFT);
		copy_highpage(dpage, spage);
		*dst = migrate_pfn(page_to_pfn(dpage)) | MIGRATE_PFN_LOCKED;
		if (*src & MIGRATE_PFN_WRITE)
			*dst |= MIGRATE_PFN_WRITE;
	}
	return 0;
}

static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
{
	struct migrate_vma args;
	unsigned long src_pfns;
	unsigned long dst_pfns;
	struct page *rpage;
	struct dmirror *dmirror;
	vm_fault_t ret;

	/*
	 * Normally, a device would use the page->zone_device_data to point to
	 * the mirror but here we use it to hold the page for the simulated
	 * device memory and that page holds the pointer to the mirror.
	 */
	rpage = vmf->page->zone_device_data;
	dmirror = rpage->zone_device_data;

	/* FIXME demonstrate how we can adjust migrate range */
	args.vma = vmf->vma;
	args.start = vmf->address;
	args.end = args.start + PAGE_SIZE;
	args.src = &src_pfns;
	args.dst = &dst_pfns;
	args.pgmap_owner = dmirror->mdevice;
	args.flags = MIGRATE_VMA_SELECT_DEVICE_PRIVATE;

	if (migrate_vma_setup(&args))
		return VM_FAULT_SIGBUS;

	ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror);
	if (ret)
		return ret;
	migrate_vma_pages(&args);
	/*
	 * No device finalize step is needed since
	 * dmirror_devmem_fault_alloc_and_copy() will have already
	 * invalidated the device page table.
	 */
	migrate_vma_finalize(&args);
	return 0;
}

static const struct dev_pagemap_ops dmirror_devmem_ops = {
	.page_free	= dmirror_devmem_free,
	.migrate_to_ram	= dmirror_devmem_fault,
};

static int dmirror_device_init(struct dmirror_device *mdevice, int id)
{
	dev_t dev;
	int ret;

	dev = MKDEV(MAJOR(dmirror_dev), id);
	mutex_init(&mdevice->devmem_lock);
	spin_lock_init(&mdevice->lock);

	cdev_init(&mdevice->cdevice, &dmirror_fops);
	mdevice->cdevice.owner = THIS_MODULE;
	ret = cdev_add(&mdevice->cdevice, dev, 1);
	if (ret)
		return ret;

	/* Build a list of free ZONE_DEVICE private struct pages */
	dmirror_allocate_chunk(mdevice, NULL);

	return 0;
}

static void dmirror_device_remove(struct dmirror_device *mdevice)
{
	unsigned int i;

	if (mdevice->devmem_chunks) {
		for (i = 0; i < mdevice->devmem_count; i++) {
			struct dmirror_chunk *devmem =
				mdevice->devmem_chunks[i];

			memunmap_pages(&devmem->pagemap);
			release_mem_region(devmem->pagemap.range.start,
					   range_len(&devmem->pagemap.range));
			kfree(devmem);
		}
		kfree(mdevice->devmem_chunks);
	}

	cdev_del(&mdevice->cdevice);
}

static int __init hmm_dmirror_init(void)
{
	int ret;
	int id;

	ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
				  "HMM_DMIRROR");
	if (ret)
		goto err_unreg;

	for (id = 0; id < DMIRROR_NDEVICES; id++) {
		ret = dmirror_device_init(dmirror_devices + id, id);
		if (ret)
			goto err_chrdev;
	}

	pr_info("HMM test module loaded. This is only for testing HMM.\n");
	return 0;

err_chrdev:
	while (--id >= 0)
		dmirror_device_remove(dmirror_devices + id);
	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
err_unreg:
	return ret;
}

static void __exit hmm_dmirror_exit(void)
{
	int id;

	for (id = 0; id < DMIRROR_NDEVICES; id++)
		dmirror_device_remove(dmirror_devices + id);
	unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
}

module_init(hmm_dmirror_init);
module_exit(hmm_dmirror_exit);
MODULE_LICENSE("GPL");