/* linux/drivers/infiniband/core/umem_odp.c */
/*
 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
  32
  33#include <linux/types.h>
  34#include <linux/sched.h>
  35#include <linux/sched/mm.h>
  36#include <linux/sched/task.h>
  37#include <linux/pid.h>
  38#include <linux/slab.h>
  39#include <linux/export.h>
  40#include <linux/vmalloc.h>
  41#include <linux/hugetlb.h>
  42#include <linux/interval_tree_generic.h>
  43
  44#include <rdma/ib_verbs.h>
  45#include <rdma/ib_umem.h>
  46#include <rdma/ib_umem_odp.h>
  47
/*
 * The per-mm interval tree (per_mm->umem_tree) keeps track of the memory
 * regions for which the HW device requested to receive a notification when
 * the related memory mapping changes.
 *
 * per_mm->umem_rwsem protects the tree.
 */
  55
  56static u64 node_start(struct umem_odp_node *n)
  57{
  58        struct ib_umem_odp *umem_odp =
  59                        container_of(n, struct ib_umem_odp, interval_tree);
  60
  61        return ib_umem_start(&umem_odp->umem);
  62}
  63
  64/* Note that the representation of the intervals in the interval tree
  65 * considers the ending point as contained in the interval, while the
  66 * function ib_umem_end returns the first address which is not contained
  67 * in the umem.
  68 */
  69static u64 node_last(struct umem_odp_node *n)
  70{
  71        struct ib_umem_odp *umem_odp =
  72                        container_of(n, struct ib_umem_odp, interval_tree);
  73
  74        return ib_umem_end(&umem_odp->umem) - 1;
  75}
  76
  77INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
  78                     node_start, node_last, static, rbt_ib_umem)
  79
  80static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
  81{
  82        mutex_lock(&umem_odp->umem_mutex);
  83        if (umem_odp->notifiers_count++ == 0)
  84                /*
  85                 * Initialize the completion object for waiting on
  86                 * notifiers. Since notifier_count is zero, no one should be
  87                 * waiting right now.
  88                 */
  89                reinit_completion(&umem_odp->notifier_completion);
  90        mutex_unlock(&umem_odp->umem_mutex);
  91}
  92
  93static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
  94{
  95        mutex_lock(&umem_odp->umem_mutex);
  96        /*
  97         * This sequence increase will notify the QP page fault that the page
  98         * that is going to be mapped in the spte could have been freed.
  99         */
 100        ++umem_odp->notifiers_seq;
 101        if (--umem_odp->notifiers_count == 0)
 102                complete_all(&umem_odp->notifier_completion);
 103        mutex_unlock(&umem_odp->umem_mutex);
 104}
 105
 106static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
 107                                               u64 start, u64 end, void *cookie)
 108{
 109        struct ib_umem *umem = &umem_odp->umem;
 110
 111        /*
 112         * Increase the number of notifiers running, to
 113         * prevent any further fault handling on this MR.
 114         */
 115        ib_umem_notifier_start_account(umem_odp);
 116        umem_odp->dying = 1;
 117        /* Make sure that the fact the umem is dying is out before we release
 118         * all pending page faults. */
 119        smp_wmb();
 120        complete_all(&umem_odp->notifier_completion);
 121        umem->context->invalidate_range(umem_odp, ib_umem_start(umem),
 122                                        ib_umem_end(umem));
 123        return 0;
 124}
 125
 126static void ib_umem_notifier_release(struct mmu_notifier *mn,
 127                                     struct mm_struct *mm)
 128{
 129        struct ib_ucontext_per_mm *per_mm =
 130                container_of(mn, struct ib_ucontext_per_mm, mn);
 131
 132        down_read(&per_mm->umem_rwsem);
 133        if (per_mm->active)
 134                rbt_ib_umem_for_each_in_range(
 135                        &per_mm->umem_tree, 0, ULLONG_MAX,
 136                        ib_umem_notifier_release_trampoline, NULL);
 137        up_read(&per_mm->umem_rwsem);
 138}
 139
 140static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
 141                                             u64 start, u64 end, void *cookie)
 142{
 143        ib_umem_notifier_start_account(item);
 144        item->umem.context->invalidate_range(item, start, end);
 145        return 0;
 146}
 147
 148static void ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
 149                                                    struct mm_struct *mm,
 150                                                    unsigned long start,
 151                                                    unsigned long end)
 152{
 153        struct ib_ucontext_per_mm *per_mm =
 154                container_of(mn, struct ib_ucontext_per_mm, mn);
 155
 156        down_read(&per_mm->umem_rwsem);
 157
 158        if (!per_mm->active) {
 159                up_read(&per_mm->umem_rwsem);
 160                /*
 161                 * At this point active is permanently set and visible to this
 162                 * CPU without a lock, that fact is relied on to skip the unlock
 163                 * in range_end.
 164                 */
 165                return;
 166        }
 167
 168        rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
 169                                      end,
 170                                      invalidate_range_start_trampoline, NULL);
 171}
 172
 173static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
 174                                           u64 end, void *cookie)
 175{
 176        ib_umem_notifier_end_account(item);
 177        return 0;
 178}
 179
 180static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 181                                                  struct mm_struct *mm,
 182                                                  unsigned long start,
 183                                                  unsigned long end)
 184{
 185        struct ib_ucontext_per_mm *per_mm =
 186                container_of(mn, struct ib_ucontext_per_mm, mn);
 187
 188        if (unlikely(!per_mm->active))
 189                return;
 190
 191        rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, start,
 192                                      end,
 193                                      invalidate_range_end_trampoline, NULL);
 194        up_read(&per_mm->umem_rwsem);
 195}
 196
 197static const struct mmu_notifier_ops ib_umem_notifiers = {
 198        .release                    = ib_umem_notifier_release,
 199        .invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
 200        .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 201};
 202
 203static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
 204{
 205        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 206        struct ib_umem *umem = &umem_odp->umem;
 207
 208        down_write(&per_mm->umem_rwsem);
 209        if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
 210                rbt_ib_umem_insert(&umem_odp->interval_tree,
 211                                   &per_mm->umem_tree);
 212        up_write(&per_mm->umem_rwsem);
 213}
 214
 215static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 216{
 217        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 218        struct ib_umem *umem = &umem_odp->umem;
 219
 220        down_write(&per_mm->umem_rwsem);
 221        if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
 222                rbt_ib_umem_remove(&umem_odp->interval_tree,
 223                                   &per_mm->umem_tree);
 224        complete_all(&umem_odp->notifier_completion);
 225
 226        up_write(&per_mm->umem_rwsem);
 227}
 228
 229static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
 230                                               struct mm_struct *mm)
 231{
 232        struct ib_ucontext_per_mm *per_mm;
 233        int ret;
 234
 235        per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
 236        if (!per_mm)
 237                return ERR_PTR(-ENOMEM);
 238
 239        per_mm->context = ctx;
 240        per_mm->mm = mm;
 241        per_mm->umem_tree = RB_ROOT_CACHED;
 242        init_rwsem(&per_mm->umem_rwsem);
 243        per_mm->active = ctx->invalidate_range;
 244
 245        rcu_read_lock();
 246        per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
 247        rcu_read_unlock();
 248
 249        WARN_ON(mm != current->mm);
 250
 251        per_mm->mn.ops = &ib_umem_notifiers;
 252        ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
 253        if (ret) {
 254                dev_err(&ctx->device->dev,
 255                        "Failed to register mmu_notifier %d\n", ret);
 256                goto out_pid;
 257        }
 258
 259        list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
 260        return per_mm;
 261
 262out_pid:
 263        put_pid(per_mm->tgid);
 264        kfree(per_mm);
 265        return ERR_PTR(ret);
 266}
 267
 268static int get_per_mm(struct ib_umem_odp *umem_odp)
 269{
 270        struct ib_ucontext *ctx = umem_odp->umem.context;
 271        struct ib_ucontext_per_mm *per_mm;
 272
 273        /*
 274         * Generally speaking we expect only one or two per_mm in this list,
 275         * so no reason to optimize this search today.
 276         */
 277        mutex_lock(&ctx->per_mm_list_lock);
 278        list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
 279                if (per_mm->mm == umem_odp->umem.owning_mm)
 280                        goto found;
 281        }
 282
 283        per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
 284        if (IS_ERR(per_mm)) {
 285                mutex_unlock(&ctx->per_mm_list_lock);
 286                return PTR_ERR(per_mm);
 287        }
 288
 289found:
 290        umem_odp->per_mm = per_mm;
 291        per_mm->odp_mrs_count++;
 292        mutex_unlock(&ctx->per_mm_list_lock);
 293
 294        return 0;
 295}
 296
 297static void free_per_mm(struct rcu_head *rcu)
 298{
 299        kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
 300}
 301
 302static void put_per_mm(struct ib_umem_odp *umem_odp)
 303{
 304        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 305        struct ib_ucontext *ctx = umem_odp->umem.context;
 306        bool need_free;
 307
 308        mutex_lock(&ctx->per_mm_list_lock);
 309        umem_odp->per_mm = NULL;
 310        per_mm->odp_mrs_count--;
 311        need_free = per_mm->odp_mrs_count == 0;
 312        if (need_free)
 313                list_del(&per_mm->ucontext_list);
 314        mutex_unlock(&ctx->per_mm_list_lock);
 315
 316        if (!need_free)
 317                return;
 318
 319        /*
 320         * NOTE! mmu_notifier_unregister() can happen between a start/end
 321         * callback, resulting in an start/end, and thus an unbalanced
 322         * lock. This doesn't really matter to us since we are about to kfree
 323         * the memory that holds the lock, however LOCKDEP doesn't like this.
 324         */
 325        down_write(&per_mm->umem_rwsem);
 326        per_mm->active = false;
 327        up_write(&per_mm->umem_rwsem);
 328
 329        WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
 330        mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
 331        put_pid(per_mm->tgid);
 332        mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
 333}
 334
 335struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
 336                                      unsigned long addr, size_t size)
 337{
 338        struct ib_ucontext_per_mm *per_mm = root->per_mm;
 339        struct ib_ucontext *ctx = per_mm->context;
 340        struct ib_umem_odp *odp_data;
 341        struct ib_umem *umem;
 342        int pages = size >> PAGE_SHIFT;
 343        int ret;
 344
 345        odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
 346        if (!odp_data)
 347                return ERR_PTR(-ENOMEM);
 348        umem = &odp_data->umem;
 349        umem->context    = ctx;
 350        umem->length     = size;
 351        umem->address    = addr;
 352        umem->page_shift = PAGE_SHIFT;
 353        umem->writable   = root->umem.writable;
 354        umem->is_odp = 1;
 355        odp_data->per_mm = per_mm;
 356        umem->owning_mm  = per_mm->mm;
 357        mmgrab(umem->owning_mm);
 358
 359        mutex_init(&odp_data->umem_mutex);
 360        init_completion(&odp_data->notifier_completion);
 361
 362        odp_data->page_list =
 363                vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
 364        if (!odp_data->page_list) {
 365                ret = -ENOMEM;
 366                goto out_odp_data;
 367        }
 368
 369        odp_data->dma_list =
 370                vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
 371        if (!odp_data->dma_list) {
 372                ret = -ENOMEM;
 373                goto out_page_list;
 374        }
 375
 376        /*
 377         * Caller must ensure that the umem_odp that the per_mm came from
 378         * cannot be freed during the call to ib_alloc_odp_umem.
 379         */
 380        mutex_lock(&ctx->per_mm_list_lock);
 381        per_mm->odp_mrs_count++;
 382        mutex_unlock(&ctx->per_mm_list_lock);
 383        add_umem_to_per_mm(odp_data);
 384
 385        return odp_data;
 386
 387out_page_list:
 388        vfree(odp_data->page_list);
 389out_odp_data:
 390        mmdrop(umem->owning_mm);
 391        kfree(odp_data);
 392        return ERR_PTR(ret);
 393}
 394EXPORT_SYMBOL(ib_alloc_odp_umem);
 395
 396int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 397{
 398        struct ib_umem *umem = &umem_odp->umem;
 399        /*
 400         * NOTE: This must called in a process context where umem->owning_mm
 401         * == current->mm
 402         */
 403        struct mm_struct *mm = umem->owning_mm;
 404        int ret_val;
 405
 406        if (access & IB_ACCESS_HUGETLB) {
 407                struct vm_area_struct *vma;
 408                struct hstate *h;
 409
 410                down_read(&mm->mmap_sem);
 411                vma = find_vma(mm, ib_umem_start(umem));
 412                if (!vma || !is_vm_hugetlb_page(vma)) {
 413                        up_read(&mm->mmap_sem);
 414                        return -EINVAL;
 415                }
 416                h = hstate_vma(vma);
 417                umem->page_shift = huge_page_shift(h);
 418                up_read(&mm->mmap_sem);
 419                umem->hugetlb = 1;
 420        } else {
 421                umem->hugetlb = 0;
 422        }
 423
 424        mutex_init(&umem_odp->umem_mutex);
 425
 426        init_completion(&umem_odp->notifier_completion);
 427
 428        if (ib_umem_num_pages(umem)) {
 429                umem_odp->page_list =
 430                        vzalloc(array_size(sizeof(*umem_odp->page_list),
 431                                           ib_umem_num_pages(umem)));
 432                if (!umem_odp->page_list)
 433                        return -ENOMEM;
 434
 435                umem_odp->dma_list =
 436                        vzalloc(array_size(sizeof(*umem_odp->dma_list),
 437                                           ib_umem_num_pages(umem)));
 438                if (!umem_odp->dma_list) {
 439                        ret_val = -ENOMEM;
 440                        goto out_page_list;
 441                }
 442        }
 443
 444        ret_val = get_per_mm(umem_odp);
 445        if (ret_val)
 446                goto out_dma_list;
 447        add_umem_to_per_mm(umem_odp);
 448
 449        return 0;
 450
 451out_dma_list:
 452        vfree(umem_odp->dma_list);
 453out_page_list:
 454        vfree(umem_odp->page_list);
 455        return ret_val;
 456}
 457
 458void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 459{
 460        struct ib_umem *umem = &umem_odp->umem;
 461
 462        /*
 463         * Ensure that no more pages are mapped in the umem.
 464         *
 465         * It is the driver's responsibility to ensure, before calling us,
 466         * that the hardware will not attempt to access the MR any more.
 467         */
 468        ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
 469                                    ib_umem_end(umem));
 470
 471        remove_umem_from_per_mm(umem_odp);
 472        put_per_mm(umem_odp);
 473        vfree(umem_odp->dma_list);
 474        vfree(umem_odp->page_list);
 475}
 476
 477/*
 478 * Map for DMA and insert a single page into the on-demand paging page tables.
 479 *
 480 * @umem: the umem to insert the page to.
 481 * @page_index: index in the umem to add the page to.
 482 * @page: the page struct to map and add.
 483 * @access_mask: access permissions needed for this page.
 484 * @current_seq: sequence number for synchronization with invalidations.
 485 *               the sequence number is taken from
 486 *               umem_odp->notifiers_seq.
 487 *
 488 * The function returns -EFAULT if the DMA mapping operation fails. It returns
 489 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 490 *
 491 * The page is released via put_page even if the operation failed. For
 492 * on-demand pinning, the page is released whenever it isn't stored in the
 493 * umem.
 494 */
 495static int ib_umem_odp_map_dma_single_page(
 496                struct ib_umem_odp *umem_odp,
 497                int page_index,
 498                struct page *page,
 499                u64 access_mask,
 500                unsigned long current_seq)
 501{
 502        struct ib_umem *umem = &umem_odp->umem;
 503        struct ib_device *dev = umem->context->device;
 504        dma_addr_t dma_addr;
 505        int stored_page = 0;
 506        int remove_existing_mapping = 0;
 507        int ret = 0;
 508
 509        /*
 510         * Note: we avoid writing if seq is different from the initial seq, to
 511         * handle case of a racing notifier. This check also allows us to bail
 512         * early if we have a notifier running in parallel with us.
 513         */
 514        if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
 515                ret = -EAGAIN;
 516                goto out;
 517        }
 518        if (!(umem_odp->dma_list[page_index])) {
 519                dma_addr = ib_dma_map_page(dev,
 520                                           page,
 521                                           0, BIT(umem->page_shift),
 522                                           DMA_BIDIRECTIONAL);
 523                if (ib_dma_mapping_error(dev, dma_addr)) {
 524                        ret = -EFAULT;
 525                        goto out;
 526                }
 527                umem_odp->dma_list[page_index] = dma_addr | access_mask;
 528                umem_odp->page_list[page_index] = page;
 529                umem->npages++;
 530                stored_page = 1;
 531        } else if (umem_odp->page_list[page_index] == page) {
 532                umem_odp->dma_list[page_index] |= access_mask;
 533        } else {
 534                pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
 535                       umem_odp->page_list[page_index], page);
 536                /* Better remove the mapping now, to prevent any further
 537                 * damage. */
 538                remove_existing_mapping = 1;
 539        }
 540
 541out:
 542        /* On Demand Paging - avoid pinning the page */
 543        if (umem->context->invalidate_range || !stored_page)
 544                put_page(page);
 545
 546        if (remove_existing_mapping && umem->context->invalidate_range) {
 547                ib_umem_notifier_start_account(umem_odp);
 548                umem->context->invalidate_range(
 549                        umem_odp,
 550                        ib_umem_start(umem) + (page_index << umem->page_shift),
 551                        ib_umem_start(umem) +
 552                                ((page_index + 1) << umem->page_shift));
 553                ib_umem_notifier_end_account(umem_odp);
 554                ret = -EAGAIN;
 555        }
 556
 557        return ret;
 558}
 559
 560/**
 561 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 562 *
 563 * Pins the range of pages passed in the argument, and maps them to
 564 * DMA addresses. The DMA addresses of the mapped pages is updated in
 565 * umem_odp->dma_list.
 566 *
 567 * Returns the number of pages mapped in success, negative error code
 568 * for failure.
 569 * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
 570 * the function from completing its task.
 571 * An -ENOENT error code indicates that userspace process is being terminated
 572 * and mm was already destroyed.
 573 * @umem_odp: the umem to map and pin
 574 * @user_virt: the address from which we need to map.
 575 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 576 *        bigger due to alignment, and may also be smaller in case of an error
 577 *        pinning or mapping a page. The actual pages mapped is returned in
 578 *        the return value.
 579 * @access_mask: bit mask of the requested access permissions for the given
 580 *               range.
 581 * @current_seq: the MMU notifiers sequance value for synchronization with
 582 *               invalidations. the sequance number is read from
 583 *               umem_odp->notifiers_seq before calling this function
 584 */
 585int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
 586                              u64 bcnt, u64 access_mask,
 587                              unsigned long current_seq)
 588{
 589        struct ib_umem *umem = &umem_odp->umem;
 590        struct task_struct *owning_process  = NULL;
 591        struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
 592        struct page       **local_page_list = NULL;
 593        u64 page_mask, off;
 594        int j, k, ret = 0, start_idx, npages = 0, page_shift;
 595        unsigned int flags = 0;
 596        phys_addr_t p = 0;
 597
 598        if (access_mask == 0)
 599                return -EINVAL;
 600
 601        if (user_virt < ib_umem_start(umem) ||
 602            user_virt + bcnt > ib_umem_end(umem))
 603                return -EFAULT;
 604
 605        local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
 606        if (!local_page_list)
 607                return -ENOMEM;
 608
 609        page_shift = umem->page_shift;
 610        page_mask = ~(BIT(page_shift) - 1);
 611        off = user_virt & (~page_mask);
 612        user_virt = user_virt & page_mask;
 613        bcnt += off; /* Charge for the first page offset as well. */
 614
 615        /*
 616         * owning_process is allowed to be NULL, this means somehow the mm is
 617         * existing beyond the lifetime of the originating process.. Presumably
 618         * mmget_not_zero will fail in this case.
 619         */
 620        owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
 621        if (!owning_process || !mmget_not_zero(owning_mm)) {
 622                ret = -EINVAL;
 623                goto out_put_task;
 624        }
 625
 626        if (access_mask & ODP_WRITE_ALLOWED_BIT)
 627                flags |= FOLL_WRITE;
 628
 629        start_idx = (user_virt - ib_umem_start(umem)) >> page_shift;
 630        k = start_idx;
 631
 632        while (bcnt > 0) {
 633                const size_t gup_num_pages = min_t(size_t,
 634                                (bcnt + BIT(page_shift) - 1) >> page_shift,
 635                                PAGE_SIZE / sizeof(struct page *));
 636
 637                down_read(&owning_mm->mmap_sem);
 638                /*
 639                 * Note: this might result in redundent page getting. We can
 640                 * avoid this by checking dma_list to be 0 before calling
 641                 * get_user_pages. However, this make the code much more
 642                 * complex (and doesn't gain us much performance in most use
 643                 * cases).
 644                 */
 645                npages = get_user_pages_remote(owning_process, owning_mm,
 646                                user_virt, gup_num_pages,
 647                                flags, local_page_list, NULL, NULL);
 648                up_read(&owning_mm->mmap_sem);
 649
 650                if (npages < 0) {
 651                        if (npages != -EAGAIN)
 652                                pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
 653                        else
 654                                pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
 655                        break;
 656                }
 657
 658                bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
 659                mutex_lock(&umem_odp->umem_mutex);
 660                for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
 661                        if (user_virt & ~page_mask) {
 662                                p += PAGE_SIZE;
 663                                if (page_to_phys(local_page_list[j]) != p) {
 664                                        ret = -EFAULT;
 665                                        break;
 666                                }
 667                                put_page(local_page_list[j]);
 668                                continue;
 669                        }
 670
 671                        ret = ib_umem_odp_map_dma_single_page(
 672                                        umem_odp, k, local_page_list[j],
 673                                        access_mask, current_seq);
 674                        if (ret < 0) {
 675                                if (ret != -EAGAIN)
 676                                        pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
 677                                else
 678                                        pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
 679                                break;
 680                        }
 681
 682                        p = page_to_phys(local_page_list[j]);
 683                        k++;
 684                }
 685                mutex_unlock(&umem_odp->umem_mutex);
 686
 687                if (ret < 0) {
 688                        /* Release left over pages when handling errors. */
 689                        for (++j; j < npages; ++j)
 690                                put_page(local_page_list[j]);
 691                        break;
 692                }
 693        }
 694
 695        if (ret >= 0) {
 696                if (npages < 0 && k == start_idx)
 697                        ret = npages;
 698                else
 699                        ret = k - start_idx;
 700        }
 701
 702        mmput(owning_mm);
 703out_put_task:
 704        if (owning_process)
 705                put_task_struct(owning_process);
 706        free_page((unsigned long)local_page_list);
 707        return ret;
 708}
 709EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
 710
 711void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
 712                                 u64 bound)
 713{
 714        struct ib_umem *umem = &umem_odp->umem;
 715        int idx;
 716        u64 addr;
 717        struct ib_device *dev = umem->context->device;
 718
 719        virt  = max_t(u64, virt,  ib_umem_start(umem));
 720        bound = min_t(u64, bound, ib_umem_end(umem));
 721        /* Note that during the run of this function, the
 722         * notifiers_count of the MR is > 0, preventing any racing
 723         * faults from completion. We might be racing with other
 724         * invalidations, so we must make sure we free each page only
 725         * once. */
 726        mutex_lock(&umem_odp->umem_mutex);
 727        for (addr = virt; addr < bound; addr += BIT(umem->page_shift)) {
 728                idx = (addr - ib_umem_start(umem)) >> umem->page_shift;
 729                if (umem_odp->page_list[idx]) {
 730                        struct page *page = umem_odp->page_list[idx];
 731                        dma_addr_t dma = umem_odp->dma_list[idx];
 732                        dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;
 733
 734                        WARN_ON(!dma_addr);
 735
 736                        ib_dma_unmap_page(dev, dma_addr, PAGE_SIZE,
 737                                          DMA_BIDIRECTIONAL);
 738                        if (dma & ODP_WRITE_ALLOWED_BIT) {
 739                                struct page *head_page = compound_head(page);
 740                                /*
 741                                 * set_page_dirty prefers being called with
 742                                 * the page lock. However, MMU notifiers are
 743                                 * called sometimes with and sometimes without
 744                                 * the lock. We rely on the umem_mutex instead
 745                                 * to prevent other mmu notifiers from
 746                                 * continuing and allowing the page mapping to
 747                                 * be removed.
 748                                 */
 749                                set_page_dirty(head_page);
 750                        }
 751                        /* on demand pinning support */
 752                        if (!umem->context->invalidate_range)
 753                                put_page(page);
 754                        umem_odp->page_list[idx] = NULL;
 755                        umem_odp->dma_list[idx] = 0;
 756                        umem->npages--;
 757                }
 758        }
 759        mutex_unlock(&umem_odp->umem_mutex);
 760}
 761EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
 762
 763/* @last is not a part of the interval. See comment for function
 764 * node_last.
 765 */
 766int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 767                                  u64 start, u64 last,
 768                                  umem_call_back cb,
 769                                  void *cookie)
 770{
 771        int ret_val = 0;
 772        struct umem_odp_node *node, *next;
 773        struct ib_umem_odp *umem;
 774
 775        if (unlikely(start == last))
 776                return ret_val;
 777
 778        for (node = rbt_ib_umem_iter_first(root, start, last - 1);
 779                        node; node = next) {
 780                next = rbt_ib_umem_iter_next(node, start, last - 1);
 781                umem = container_of(node, struct ib_umem_odp, interval_tree);
 782                ret_val = cb(umem, start, last, cookie) || ret_val;
 783        }
 784
 785        return ret_val;
 786}
 787EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
 788
 789struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
 790                                       u64 addr, u64 length)
 791{
 792        struct umem_odp_node *node;
 793
 794        node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
 795        if (node)
 796                return container_of(node, struct ib_umem_odp, interval_tree);
 797        return NULL;
 798
 799}
 800EXPORT_SYMBOL(rbt_ib_umem_lookup);
 801