linux/drivers/infiniband/core/umem_odp.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
   3 *
   4 * This software is available to you under a choice of one of two
   5 * licenses.  You may choose to be licensed under the terms of the GNU
   6 * General Public License (GPL) Version 2, available from the file
   7 * COPYING in the main directory of this source tree, or the
   8 * OpenIB.org BSD license below:
   9 *
  10 *     Redistribution and use in source and binary forms, with or
  11 *     without modification, are permitted provided that the following
  12 *     conditions are met:
  13 *
  14 *      - Redistributions of source code must retain the above
  15 *        copyright notice, this list of conditions and the following
  16 *        disclaimer.
  17 *
  18 *      - Redistributions in binary form must reproduce the above
  19 *        copyright notice, this list of conditions and the following
  20 *        disclaimer in the documentation and/or other materials
  21 *        provided with the distribution.
  22 *
  23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  30 * SOFTWARE.
  31 */
  32
  33#include <linux/types.h>
  34#include <linux/sched.h>
  35#include <linux/sched/mm.h>
  36#include <linux/sched/task.h>
  37#include <linux/pid.h>
  38#include <linux/slab.h>
  39#include <linux/export.h>
  40#include <linux/vmalloc.h>
  41#include <linux/hugetlb.h>
  42#include <linux/interval_tree_generic.h>
  43
  44#include <rdma/ib_verbs.h>
  45#include <rdma/ib_umem.h>
  46#include <rdma/ib_umem_odp.h>
  47
  /*
 * Each ib_ucontext keeps its ODP umems in an interval tree
 * (context->umem_tree), so that the MMU notifier callbacks below can find
 * the umems whose memory mappings are affected by a change to the process
 * address space and ask the HW device driver to invalidate them.
 *
 * context->umem_rwsem protects the tree.
 */
  55
  56static u64 node_start(struct umem_odp_node *n)
  57{
  58        struct ib_umem_odp *umem_odp =
  59                        container_of(n, struct ib_umem_odp, interval_tree);
  60
  61        return ib_umem_start(umem_odp->umem);
  62}
  63
  64/* Note that the representation of the intervals in the interval tree
  65 * considers the ending point as contained in the interval, while the
  66 * function ib_umem_end returns the first address which is not contained
  67 * in the umem.
  68 */
  69static u64 node_last(struct umem_odp_node *n)
  70{
  71        struct ib_umem_odp *umem_odp =
  72                        container_of(n, struct ib_umem_odp, interval_tree);
  73
  74        return ib_umem_end(umem_odp->umem) - 1;
  75}
  76
  77INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
  78                     node_start, node_last, static, rbt_ib_umem)
  79
  80static void ib_umem_notifier_start_account(struct ib_umem *item)
  81{
  82        mutex_lock(&item->odp_data->umem_mutex);
  83
  84        /* Only update private counters for this umem if it has them.
  85         * Otherwise skip it. All page faults will be delayed for this umem. */
  86        if (item->odp_data->mn_counters_active) {
  87                int notifiers_count = item->odp_data->notifiers_count++;
  88
  89                if (notifiers_count == 0)
  90                        /* Initialize the completion object for waiting on
  91                         * notifiers. Since notifier_count is zero, no one
  92                         * should be waiting right now. */
  93                        reinit_completion(&item->odp_data->notifier_completion);
  94        }
  95        mutex_unlock(&item->odp_data->umem_mutex);
  96}
  97
  98static void ib_umem_notifier_end_account(struct ib_umem *item)
  99{
 100        mutex_lock(&item->odp_data->umem_mutex);
 101
 102        /* Only update private counters for this umem if it has them.
 103         * Otherwise skip it. All page faults will be delayed for this umem. */
 104        if (item->odp_data->mn_counters_active) {
 105                /*
 106                 * This sequence increase will notify the QP page fault that
 107                 * the page that is going to be mapped in the spte could have
 108                 * been freed.
 109                 */
 110                ++item->odp_data->notifiers_seq;
 111                if (--item->odp_data->notifiers_count == 0)
 112                        complete_all(&item->odp_data->notifier_completion);
 113        }
 114        mutex_unlock(&item->odp_data->umem_mutex);
 115}
 116
 117/* Account for a new mmu notifier in an ib_ucontext. */
 118static void ib_ucontext_notifier_start_account(struct ib_ucontext *context)
 119{
 120        atomic_inc(&context->notifier_count);
 121}
 122
 123/* Account for a terminating mmu notifier in an ib_ucontext.
 124 *
 125 * Must be called with the ib_ucontext->umem_rwsem semaphore unlocked, since
 126 * the function takes the semaphore itself. */
 127static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
 128{
 129        int zero_notifiers = atomic_dec_and_test(&context->notifier_count);
 130
 131        if (zero_notifiers &&
 132            !list_empty(&context->no_private_counters)) {
 133                /* No currently running mmu notifiers. Now is the chance to
 134                 * add private accounting to all previously added umems. */
 135                struct ib_umem_odp *odp_data, *next;
 136
 137                /* Prevent concurrent mmu notifiers from working on the
 138                 * no_private_counters list. */
 139                down_write(&context->umem_rwsem);
 140
 141                /* Read the notifier_count again, with the umem_rwsem
 142                 * semaphore taken for write. */
 143                if (!atomic_read(&context->notifier_count)) {
 144                        list_for_each_entry_safe(odp_data, next,
 145                                                 &context->no_private_counters,
 146                                                 no_private_counters) {
 147                                mutex_lock(&odp_data->umem_mutex);
 148                                odp_data->mn_counters_active = true;
 149                                list_del(&odp_data->no_private_counters);
 150                                complete_all(&odp_data->notifier_completion);
 151                                mutex_unlock(&odp_data->umem_mutex);
 152                        }
 153                }
 154
 155                up_write(&context->umem_rwsem);
 156        }
 157}
 158
 159static int ib_umem_notifier_release_trampoline(struct ib_umem *item, u64 start,
 160                                               u64 end, void *cookie) {
 161        /*
 162         * Increase the number of notifiers running, to
 163         * prevent any further fault handling on this MR.
 164         */
 165        ib_umem_notifier_start_account(item);
 166        item->odp_data->dying = 1;
 167        /* Make sure that the fact the umem is dying is out before we release
 168         * all pending page faults. */
 169        smp_wmb();
 170        complete_all(&item->odp_data->notifier_completion);
 171        item->context->invalidate_range(item, ib_umem_start(item),
 172                                        ib_umem_end(item));
 173        return 0;
 174}
 175
 176static void ib_umem_notifier_release(struct mmu_notifier *mn,
 177                                     struct mm_struct *mm)
 178{
 179        struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
 180
 181        if (!context->invalidate_range)
 182                return;
 183
 184        ib_ucontext_notifier_start_account(context);
 185        down_read(&context->umem_rwsem);
 186        rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
 187                                      ULLONG_MAX,
 188                                      ib_umem_notifier_release_trampoline,
 189                                      NULL);
 190        up_read(&context->umem_rwsem);
 191}
 192
 193static int invalidate_page_trampoline(struct ib_umem *item, u64 start,
 194                                      u64 end, void *cookie)
 195{
 196        ib_umem_notifier_start_account(item);
 197        item->context->invalidate_range(item, start, start + PAGE_SIZE);
 198        ib_umem_notifier_end_account(item);
 199        return 0;
 200}
 201
 202static int invalidate_range_start_trampoline(struct ib_umem *item, u64 start,
 203                                             u64 end, void *cookie)
 204{
 205        ib_umem_notifier_start_account(item);
 206        item->context->invalidate_range(item, start, end);
 207        return 0;
 208}
 209
 210static void ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
 211                                                    struct mm_struct *mm,
 212                                                    unsigned long start,
 213                                                    unsigned long end)
 214{
 215        struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
 216
 217        if (!context->invalidate_range)
 218                return;
 219
 220        ib_ucontext_notifier_start_account(context);
 221        down_read(&context->umem_rwsem);
 222        rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
 223                                      end,
 224                                      invalidate_range_start_trampoline, NULL);
 225        up_read(&context->umem_rwsem);
 226}
 227
 228static int invalidate_range_end_trampoline(struct ib_umem *item, u64 start,
 229                                           u64 end, void *cookie)
 230{
 231        ib_umem_notifier_end_account(item);
 232        return 0;
 233}
 234
 235static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 236                                                  struct mm_struct *mm,
 237                                                  unsigned long start,
 238                                                  unsigned long end)
 239{
 240        struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
 241
 242        if (!context->invalidate_range)
 243                return;
 244
 245        down_read(&context->umem_rwsem);
 246        rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
 247                                      end,
 248                                      invalidate_range_end_trampoline, NULL);
 249        up_read(&context->umem_rwsem);
 250        ib_ucontext_notifier_end_account(context);
 251}
 252
 253static const struct mmu_notifier_ops ib_umem_notifiers = {
 254        .release                    = ib_umem_notifier_release,
 255        .invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
 256        .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 257};
 258
 259struct ib_umem *ib_alloc_odp_umem(struct ib_ucontext *context,
 260                                  unsigned long addr,
 261                                  size_t size)
 262{
 263        struct ib_umem *umem;
 264        struct ib_umem_odp *odp_data;
 265        int pages = size >> PAGE_SHIFT;
 266        int ret;
 267
 268        umem = kzalloc(sizeof(*umem), GFP_KERNEL);
 269        if (!umem)
 270                return ERR_PTR(-ENOMEM);
 271
 272        umem->context    = context;
 273        umem->length     = size;
 274        umem->address    = addr;
 275        umem->page_shift = PAGE_SHIFT;
 276        umem->writable   = 1;
 277
 278        odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
 279        if (!odp_data) {
 280                ret = -ENOMEM;
 281                goto out_umem;
 282        }
 283        odp_data->umem = umem;
 284
 285        mutex_init(&odp_data->umem_mutex);
 286        init_completion(&odp_data->notifier_completion);
 287
 288        odp_data->page_list = vzalloc(pages * sizeof(*odp_data->page_list));
 289        if (!odp_data->page_list) {
 290                ret = -ENOMEM;
 291                goto out_odp_data;
 292        }
 293
 294        odp_data->dma_list = vzalloc(pages * sizeof(*odp_data->dma_list));
 295        if (!odp_data->dma_list) {
 296                ret = -ENOMEM;
 297                goto out_page_list;
 298        }
 299
 300        down_write(&context->umem_rwsem);
 301        context->odp_mrs_count++;
 302        rbt_ib_umem_insert(&odp_data->interval_tree, &context->umem_tree);
 303        if (likely(!atomic_read(&context->notifier_count)))
 304                odp_data->mn_counters_active = true;
 305        else
 306                list_add(&odp_data->no_private_counters,
 307                         &context->no_private_counters);
 308        up_write(&context->umem_rwsem);
 309
 310        umem->odp_data = odp_data;
 311
 312        return umem;
 313
 314out_page_list:
 315        vfree(odp_data->page_list);
 316out_odp_data:
 317        kfree(odp_data);
 318out_umem:
 319        kfree(umem);
 320        return ERR_PTR(ret);
 321}
 322EXPORT_SYMBOL(ib_alloc_odp_umem);
 323
 324int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
 325                    int access)
 326{
 327        int ret_val;
 328        struct pid *our_pid;
 329        struct mm_struct *mm = get_task_mm(current);
 330
 331        if (!mm)
 332                return -EINVAL;
 333
 334        if (access & IB_ACCESS_HUGETLB) {
 335                struct vm_area_struct *vma;
 336                struct hstate *h;
 337
 338                down_read(&mm->mmap_sem);
 339                vma = find_vma(mm, ib_umem_start(umem));
 340                if (!vma || !is_vm_hugetlb_page(vma)) {
 341                        up_read(&mm->mmap_sem);
 342                        return -EINVAL;
 343                }
 344                h = hstate_vma(vma);
 345                umem->page_shift = huge_page_shift(h);
 346                up_read(&mm->mmap_sem);
 347                umem->hugetlb = 1;
 348        } else {
 349                umem->hugetlb = 0;
 350        }
 351
 352        /* Prevent creating ODP MRs in child processes */
 353        rcu_read_lock();
 354        our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
 355        rcu_read_unlock();
 356        put_pid(our_pid);
 357        if (context->tgid != our_pid) {
 358                ret_val = -EINVAL;
 359                goto out_mm;
 360        }
 361
 362        umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
 363        if (!umem->odp_data) {
 364                ret_val = -ENOMEM;
 365                goto out_mm;
 366        }
 367        umem->odp_data->umem = umem;
 368
 369        mutex_init(&umem->odp_data->umem_mutex);
 370
 371        init_completion(&umem->odp_data->notifier_completion);
 372
 373        if (ib_umem_num_pages(umem)) {
 374                umem->odp_data->page_list = vzalloc(ib_umem_num_pages(umem) *
 375                                            sizeof(*umem->odp_data->page_list));
 376                if (!umem->odp_data->page_list) {
 377                        ret_val = -ENOMEM;
 378                        goto out_odp_data;
 379                }
 380
 381                umem->odp_data->dma_list = vzalloc(ib_umem_num_pages(umem) *
 382                                          sizeof(*umem->odp_data->dma_list));
 383                if (!umem->odp_data->dma_list) {
 384                        ret_val = -ENOMEM;
 385                        goto out_page_list;
 386                }
 387        }
 388
 389        /*
 390         * When using MMU notifiers, we will get a
 391         * notification before the "current" task (and MM) is
 392         * destroyed. We use the umem_rwsem semaphore to synchronize.
 393         */
 394        down_write(&context->umem_rwsem);
 395        context->odp_mrs_count++;
 396        if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
 397                rbt_ib_umem_insert(&umem->odp_data->interval_tree,
 398                                   &context->umem_tree);
 399        if (likely(!atomic_read(&context->notifier_count)) ||
 400            context->odp_mrs_count == 1)
 401                umem->odp_data->mn_counters_active = true;
 402        else
 403                list_add(&umem->odp_data->no_private_counters,
 404                         &context->no_private_counters);
 405        downgrade_write(&context->umem_rwsem);
 406
 407        if (context->odp_mrs_count == 1) {
 408                /*
 409                 * Note that at this point, no MMU notifier is running
 410                 * for this context!
 411                 */
 412                atomic_set(&context->notifier_count, 0);
 413                INIT_HLIST_NODE(&context->mn.hlist);
 414                context->mn.ops = &ib_umem_notifiers;
 415                /*
 416                 * Lock-dep detects a false positive for mmap_sem vs.
 417                 * umem_rwsem, due to not grasping downgrade_write correctly.
 418                 */
 419                lockdep_off();
 420                ret_val = mmu_notifier_register(&context->mn, mm);
 421                lockdep_on();
 422                if (ret_val) {
 423                        pr_err("Failed to register mmu_notifier %d\n", ret_val);
 424                        ret_val = -EBUSY;
 425                        goto out_mutex;
 426                }
 427        }
 428
 429        up_read(&context->umem_rwsem);
 430
 431        /*
 432         * Note that doing an mmput can cause a notifier for the relevant mm.
 433         * If the notifier is called while we hold the umem_rwsem, this will
 434         * cause a deadlock. Therefore, we release the reference only after we
 435         * released the semaphore.
 436         */
 437        mmput(mm);
 438        return 0;
 439
 440out_mutex:
 441        up_read(&context->umem_rwsem);
 442        vfree(umem->odp_data->dma_list);
 443out_page_list:
 444        vfree(umem->odp_data->page_list);
 445out_odp_data:
 446        kfree(umem->odp_data);
 447out_mm:
 448        mmput(mm);
 449        return ret_val;
 450}
 451
 452void ib_umem_odp_release(struct ib_umem *umem)
 453{
 454        struct ib_ucontext *context = umem->context;
 455
 456        /*
 457         * Ensure that no more pages are mapped in the umem.
 458         *
 459         * It is the driver's responsibility to ensure, before calling us,
 460         * that the hardware will not attempt to access the MR any more.
 461         */
 462        ib_umem_odp_unmap_dma_pages(umem, ib_umem_start(umem),
 463                                    ib_umem_end(umem));
 464
 465        down_write(&context->umem_rwsem);
 466        if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
 467                rbt_ib_umem_remove(&umem->odp_data->interval_tree,
 468                                   &context->umem_tree);
 469        context->odp_mrs_count--;
 470        if (!umem->odp_data->mn_counters_active) {
 471                list_del(&umem->odp_data->no_private_counters);
 472                complete_all(&umem->odp_data->notifier_completion);
 473        }
 474
 475        /*
 476         * Downgrade the lock to a read lock. This ensures that the notifiers
 477         * (who lock the mutex for reading) will be able to finish, and we
 478         * will be able to enventually obtain the mmu notifiers SRCU. Note
 479         * that since we are doing it atomically, no other user could register
 480         * and unregister while we do the check.
 481         */
 482        downgrade_write(&context->umem_rwsem);
 483        if (!context->odp_mrs_count) {
 484                struct task_struct *owning_process = NULL;
 485                struct mm_struct *owning_mm        = NULL;
 486
 487                owning_process = get_pid_task(context->tgid,
 488                                              PIDTYPE_PID);
 489                if (owning_process == NULL)
 490                        /*
 491                         * The process is already dead, notifier were removed
 492                         * already.
 493                         */
 494                        goto out;
 495
 496                owning_mm = get_task_mm(owning_process);
 497                if (owning_mm == NULL)
 498                        /*
 499                         * The process' mm is already dead, notifier were
 500                         * removed already.
 501                         */
 502                        goto out_put_task;
 503                mmu_notifier_unregister(&context->mn, owning_mm);
 504
 505                mmput(owning_mm);
 506
 507out_put_task:
 508                put_task_struct(owning_process);
 509        }
 510out:
 511        up_read(&context->umem_rwsem);
 512
 513        vfree(umem->odp_data->dma_list);
 514        vfree(umem->odp_data->page_list);
 515        kfree(umem->odp_data);
 516        kfree(umem);
 517}
 518
 519/*
 520 * Map for DMA and insert a single page into the on-demand paging page tables.
 521 *
 522 * @umem: the umem to insert the page to.
 523 * @page_index: index in the umem to add the page to.
 524 * @page: the page struct to map and add.
 525 * @access_mask: access permissions needed for this page.
 526 * @current_seq: sequence number for synchronization with invalidations.
 527 *               the sequence number is taken from
 528 *               umem->odp_data->notifiers_seq.
 529 *
 530 * The function returns -EFAULT if the DMA mapping operation fails. It returns
 531 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 532 *
 533 * The page is released via put_page even if the operation failed. For
 534 * on-demand pinning, the page is released whenever it isn't stored in the
 535 * umem.
 536 */
 537static int ib_umem_odp_map_dma_single_page(
 538                struct ib_umem *umem,
 539                int page_index,
 540                struct page *page,
 541                u64 access_mask,
 542                unsigned long current_seq)
 543{
 544        struct ib_device *dev = umem->context->device;
 545        dma_addr_t dma_addr;
 546        int stored_page = 0;
 547        int remove_existing_mapping = 0;
 548        int ret = 0;
 549
 550        /*
 551         * Note: we avoid writing if seq is different from the initial seq, to
 552         * handle case of a racing notifier. This check also allows us to bail
 553         * early if we have a notifier running in parallel with us.
 554         */
 555        if (ib_umem_mmu_notifier_retry(umem, current_seq)) {
 556                ret = -EAGAIN;
 557                goto out;
 558        }
 559        if (!(umem->odp_data->dma_list[page_index])) {
 560                dma_addr = ib_dma_map_page(dev,
 561                                           page,
 562                                           0, BIT(umem->page_shift),
 563                                           DMA_BIDIRECTIONAL);
 564                if (ib_dma_mapping_error(dev, dma_addr)) {
 565                        ret = -EFAULT;
 566                        goto out;
 567                }
 568                umem->odp_data->dma_list[page_index] = dma_addr | access_mask;
 569                umem->odp_data->page_list[page_index] = page;
 570                umem->npages++;
 571                stored_page = 1;
 572        } else if (umem->odp_data->page_list[page_index] == page) {
 573                umem->odp_data->dma_list[page_index] |= access_mask;
 574        } else {
 575                pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
 576                       umem->odp_data->page_list[page_index], page);
 577                /* Better remove the mapping now, to prevent any further
 578                 * damage. */
 579                remove_existing_mapping = 1;
 580        }
 581
 582out:
 583        /* On Demand Paging - avoid pinning the page */
 584        if (umem->context->invalidate_range || !stored_page)
 585                put_page(page);
 586
 587        if (remove_existing_mapping && umem->context->invalidate_range) {
 588                invalidate_page_trampoline(
 589                        umem,
 590                        ib_umem_start(umem) + (page_index >> umem->page_shift),
 591                        ib_umem_start(umem) + ((page_index + 1) >>
 592                                               umem->page_shift),
 593                        NULL);
 594                ret = -EAGAIN;
 595        }
 596
 597        return ret;
 598}
 599
 600/**
 601 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 602 *
 603 * Pins the range of pages passed in the argument, and maps them to
 604 * DMA addresses. The DMA addresses of the mapped pages is updated in
 605 * umem->odp_data->dma_list.
 606 *
 607 * Returns the number of pages mapped in success, negative error code
 608 * for failure.
 609 * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
 610 * the function from completing its task.
 611 * An -ENOENT error code indicates that userspace process is being terminated
 612 * and mm was already destroyed.
 613 * @umem: the umem to map and pin
 614 * @user_virt: the address from which we need to map.
 615 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 616 *        bigger due to alignment, and may also be smaller in case of an error
 617 *        pinning or mapping a page. The actual pages mapped is returned in
 618 *        the return value.
 619 * @access_mask: bit mask of the requested access permissions for the given
 620 *               range.
 621 * @current_seq: the MMU notifiers sequance value for synchronization with
 622 *               invalidations. the sequance number is read from
 623 *               umem->odp_data->notifiers_seq before calling this function
 624 */
 625int ib_umem_odp_map_dma_pages(struct ib_umem *umem, u64 user_virt, u64 bcnt,
 626                              u64 access_mask, unsigned long current_seq)
 627{
 628        struct task_struct *owning_process  = NULL;
 629        struct mm_struct   *owning_mm       = NULL;
 630        struct page       **local_page_list = NULL;
 631        u64 page_mask, off;
 632        int j, k, ret = 0, start_idx, npages = 0, page_shift;
 633        unsigned int flags = 0;
 634        phys_addr_t p = 0;
 635
 636        if (access_mask == 0)
 637                return -EINVAL;
 638
 639        if (user_virt < ib_umem_start(umem) ||
 640            user_virt + bcnt > ib_umem_end(umem))
 641                return -EFAULT;
 642
 643        local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
 644        if (!local_page_list)
 645                return -ENOMEM;
 646
 647        page_shift = umem->page_shift;
 648        page_mask = ~(BIT(page_shift) - 1);
 649        off = user_virt & (~page_mask);
 650        user_virt = user_virt & page_mask;
 651        bcnt += off; /* Charge for the first page offset as well. */
 652
 653        owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID);
 654        if (owning_process == NULL) {
 655                ret = -EINVAL;
 656                goto out_no_task;
 657        }
 658
 659        owning_mm = get_task_mm(owning_process);
 660        if (owning_mm == NULL) {
 661                ret = -ENOENT;
 662                goto out_put_task;
 663        }
 664
 665        if (access_mask & ODP_WRITE_ALLOWED_BIT)
 666                flags |= FOLL_WRITE;
 667
 668        start_idx = (user_virt - ib_umem_start(umem)) >> page_shift;
 669        k = start_idx;
 670
 671        while (bcnt > 0) {
 672                const size_t gup_num_pages = min_t(size_t,
 673                                (bcnt + BIT(page_shift) - 1) >> page_shift,
 674                                PAGE_SIZE / sizeof(struct page *));
 675
 676                down_read(&owning_mm->mmap_sem);
 677                /*
 678                 * Note: this might result in redundent page getting. We can
 679                 * avoid this by checking dma_list to be 0 before calling
 680                 * get_user_pages. However, this make the code much more
 681                 * complex (and doesn't gain us much performance in most use
 682                 * cases).
 683                 */
 684                npages = get_user_pages_remote(owning_process, owning_mm,
 685                                user_virt, gup_num_pages,
 686                                flags, local_page_list, NULL, NULL);
 687                up_read(&owning_mm->mmap_sem);
 688
 689                if (npages < 0)
 690                        break;
 691
 692                bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
 693                mutex_lock(&umem->odp_data->umem_mutex);
 694                for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
 695                        if (user_virt & ~page_mask) {
 696                                p += PAGE_SIZE;
 697                                if (page_to_phys(local_page_list[j]) != p) {
 698                                        ret = -EFAULT;
 699                                        break;
 700                                }
 701                                put_page(local_page_list[j]);
 702                                continue;
 703                        }
 704
 705                        ret = ib_umem_odp_map_dma_single_page(
 706                                        umem, k, local_page_list[j],
 707                                        access_mask, current_seq);
 708                        if (ret < 0)
 709                                break;
 710
 711                        p = page_to_phys(local_page_list[j]);
 712                        k++;
 713                }
 714                mutex_unlock(&umem->odp_data->umem_mutex);
 715
 716                if (ret < 0) {
 717                        /* Release left over pages when handling errors. */
 718                        for (++j; j < npages; ++j)
 719                                put_page(local_page_list[j]);
 720                        break;
 721                }
 722        }
 723
 724        if (ret >= 0) {
 725                if (npages < 0 && k == start_idx)
 726                        ret = npages;
 727                else
 728                        ret = k - start_idx;
 729        }
 730
 731        mmput(owning_mm);
 732out_put_task:
 733        put_task_struct(owning_process);
 734out_no_task:
 735        free_page((unsigned long)local_page_list);
 736        return ret;
 737}
 738EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
 739
 740void ib_umem_odp_unmap_dma_pages(struct ib_umem *umem, u64 virt,
 741                                 u64 bound)
 742{
 743        int idx;
 744        u64 addr;
 745        struct ib_device *dev = umem->context->device;
 746
 747        virt  = max_t(u64, virt,  ib_umem_start(umem));
 748        bound = min_t(u64, bound, ib_umem_end(umem));
 749        /* Note that during the run of this function, the
 750         * notifiers_count of the MR is > 0, preventing any racing
 751         * faults from completion. We might be racing with other
 752         * invalidations, so we must make sure we free each page only
 753         * once. */
 754        mutex_lock(&umem->odp_data->umem_mutex);
 755        for (addr = virt; addr < bound; addr += BIT(umem->page_shift)) {
 756                idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
 757                if (umem->odp_data->page_list[idx]) {
 758                        struct page *page = umem->odp_data->page_list[idx];
 759                        dma_addr_t dma = umem->odp_data->dma_list[idx];
 760                        dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;
 761
 762                        WARN_ON(!dma_addr);
 763
 764                        ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
 765                                          DMA_BIDIRECTIONAL);
 766                        if (dma & ODP_WRITE_ALLOWED_BIT) {
 767                                struct page *head_page = compound_head(page);
 768                                /*
 769                                 * set_page_dirty prefers being called with
 770                                 * the page lock. However, MMU notifiers are
 771                                 * called sometimes with and sometimes without
 772                                 * the lock. We rely on the umem_mutex instead
 773                                 * to prevent other mmu notifiers from
 774                                 * continuing and allowing the page mapping to
 775                                 * be removed.
 776                                 */
 777                                set_page_dirty(head_page);
 778                        }
 779                        /* on demand pinning support */
 780                        if (!umem->context->invalidate_range)
 781                                put_page(page);
 782                        umem->odp_data->page_list[idx] = NULL;
 783                        umem->odp_data->dma_list[idx] = 0;
 784                        umem->npages--;
 785                }
 786        }
 787        mutex_unlock(&umem->odp_data->umem_mutex);
 788}
 789EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
 790
 791/* @last is not a part of the interval. See comment for function
 792 * node_last.
 793 */
 794int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 795                                  u64 start, u64 last,
 796                                  umem_call_back cb,
 797                                  void *cookie)
 798{
 799        int ret_val = 0;
 800        struct umem_odp_node *node, *next;
 801        struct ib_umem_odp *umem;
 802
 803        if (unlikely(start == last))
 804                return ret_val;
 805
 806        for (node = rbt_ib_umem_iter_first(root, start, last - 1);
 807                        node; node = next) {
 808                next = rbt_ib_umem_iter_next(node, start, last - 1);
 809                umem = container_of(node, struct ib_umem_odp, interval_tree);
 810                ret_val = cb(umem->umem, start, last, cookie) || ret_val;
 811        }
 812
 813        return ret_val;
 814}
 815EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
 816
 817struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
 818                                       u64 addr, u64 length)
 819{
 820        struct umem_odp_node *node;
 821
 822        node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
 823        if (node)
 824                return container_of(node, struct ib_umem_odp, interval_tree);
 825        return NULL;
 826
 827}
 828EXPORT_SYMBOL(rbt_ib_umem_lookup);
 829