linux/drivers/infiniband/core/umem_odp.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
   3 *
   4 * This software is available to you under a choice of one of two
   5 * licenses.  You may choose to be licensed under the terms of the GNU
   6 * General Public License (GPL) Version 2, available from the file
   7 * COPYING in the main directory of this source tree, or the
   8 * OpenIB.org BSD license below:
   9 *
  10 *     Redistribution and use in source and binary forms, with or
  11 *     without modification, are permitted provided that the following
  12 *     conditions are met:
  13 *
  14 *      - Redistributions of source code must retain the above
  15 *        copyright notice, this list of conditions and the following
  16 *        disclaimer.
  17 *
  18 *      - Redistributions in binary form must reproduce the above
  19 *        copyright notice, this list of conditions and the following
  20 *        disclaimer in the documentation and/or other materials
  21 *        provided with the distribution.
  22 *
  23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  30 * SOFTWARE.
  31 */
  32
  33#include <linux/types.h>
  34#include <linux/sched.h>
  35#include <linux/sched/mm.h>
  36#include <linux/sched/task.h>
  37#include <linux/pid.h>
  38#include <linux/slab.h>
  39#include <linux/export.h>
  40#include <linux/vmalloc.h>
  41#include <linux/hugetlb.h>
  42#include <linux/interval_tree_generic.h>
  43#include <linux/pagemap.h>
  44
  45#include <rdma/ib_verbs.h>
  46#include <rdma/ib_umem.h>
  47#include <rdma/ib_umem_odp.h>
  48
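/*
 * Each ib_ucontext_per_mm keeps an interval tree (umem_tree) of the ODP
 * umems for which the HW device requested notification when the related
 * memory mapping is changed.
 *
 * per_mm->umem_rwsem protects the tree: insertions and removals take it for
 * writing, the MMU notifier callbacks walk the tree with it held for reading.
 */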
  56
  57static u64 node_start(struct umem_odp_node *n)
  58{
  59        struct ib_umem_odp *umem_odp =
  60                        container_of(n, struct ib_umem_odp, interval_tree);
  61
  62        return ib_umem_start(umem_odp);
  63}
  64
  65/* Note that the representation of the intervals in the interval tree
  66 * considers the ending point as contained in the interval, while the
  67 * function ib_umem_end returns the first address which is not contained
  68 * in the umem.
  69 */
  70static u64 node_last(struct umem_odp_node *n)
  71{
  72        struct ib_umem_odp *umem_odp =
  73                        container_of(n, struct ib_umem_odp, interval_tree);
  74
  75        return ib_umem_end(umem_odp) - 1;
  76}
  77
  78INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
  79                     node_start, node_last, static, rbt_ib_umem)
  80
  81static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
  82{
  83        mutex_lock(&umem_odp->umem_mutex);
  84        if (umem_odp->notifiers_count++ == 0)
  85                /*
  86                 * Initialize the completion object for waiting on
  87                 * notifiers. Since notifier_count is zero, no one should be
  88                 * waiting right now.
  89                 */
  90                reinit_completion(&umem_odp->notifier_completion);
  91        mutex_unlock(&umem_odp->umem_mutex);
  92}
  93
  94static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
  95{
  96        mutex_lock(&umem_odp->umem_mutex);
  97        /*
  98         * This sequence increase will notify the QP page fault that the page
  99         * that is going to be mapped in the spte could have been freed.
 100         */
 101        ++umem_odp->notifiers_seq;
 102        if (--umem_odp->notifiers_count == 0)
 103                complete_all(&umem_odp->notifier_completion);
 104        mutex_unlock(&umem_odp->umem_mutex);
 105}
 106
 107static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
 108                                               u64 start, u64 end, void *cookie)
 109{
 110        /*
 111         * Increase the number of notifiers running, to
 112         * prevent any further fault handling on this MR.
 113         */
 114        ib_umem_notifier_start_account(umem_odp);
 115        complete_all(&umem_odp->notifier_completion);
 116        umem_odp->umem.context->invalidate_range(
 117                umem_odp, ib_umem_start(umem_odp), ib_umem_end(umem_odp));
 118        return 0;
 119}
 120
 121static void ib_umem_notifier_release(struct mmu_notifier *mn,
 122                                     struct mm_struct *mm)
 123{
 124        struct ib_ucontext_per_mm *per_mm =
 125                container_of(mn, struct ib_ucontext_per_mm, mn);
 126
 127        down_read(&per_mm->umem_rwsem);
 128        if (per_mm->active)
 129                rbt_ib_umem_for_each_in_range(
 130                        &per_mm->umem_tree, 0, ULLONG_MAX,
 131                        ib_umem_notifier_release_trampoline, true, NULL);
 132        up_read(&per_mm->umem_rwsem);
 133}
 134
 135static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
 136                                             u64 start, u64 end, void *cookie)
 137{
 138        ib_umem_notifier_start_account(item);
 139        item->umem.context->invalidate_range(item, start, end);
 140        return 0;
 141}
 142
 143static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
 144                                const struct mmu_notifier_range *range)
 145{
 146        struct ib_ucontext_per_mm *per_mm =
 147                container_of(mn, struct ib_ucontext_per_mm, mn);
 148        int rc;
 149
 150        if (mmu_notifier_range_blockable(range))
 151                down_read(&per_mm->umem_rwsem);
 152        else if (!down_read_trylock(&per_mm->umem_rwsem))
 153                return -EAGAIN;
 154
 155        if (!per_mm->active) {
 156                up_read(&per_mm->umem_rwsem);
 157                /*
 158                 * At this point active is permanently set and visible to this
 159                 * CPU without a lock, that fact is relied on to skip the unlock
 160                 * in range_end.
 161                 */
 162                return 0;
 163        }
 164
 165        rc = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
 166                                           range->end,
 167                                           invalidate_range_start_trampoline,
 168                                           mmu_notifier_range_blockable(range),
 169                                           NULL);
 170        if (rc)
 171                up_read(&per_mm->umem_rwsem);
 172        return rc;
 173}
 174
 175static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
 176                                           u64 end, void *cookie)
 177{
 178        ib_umem_notifier_end_account(item);
 179        return 0;
 180}
 181
 182static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 183                                const struct mmu_notifier_range *range)
 184{
 185        struct ib_ucontext_per_mm *per_mm =
 186                container_of(mn, struct ib_ucontext_per_mm, mn);
 187
 188        if (unlikely(!per_mm->active))
 189                return;
 190
 191        rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
 192                                      range->end,
 193                                      invalidate_range_end_trampoline, true, NULL);
 194        up_read(&per_mm->umem_rwsem);
 195}
 196
 197static const struct mmu_notifier_ops ib_umem_notifiers = {
 198        .release                    = ib_umem_notifier_release,
 199        .invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
 200        .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 201};
 202
 203static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
 204{
 205        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 206
 207        down_write(&per_mm->umem_rwsem);
 208        if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
 209                rbt_ib_umem_insert(&umem_odp->interval_tree,
 210                                   &per_mm->umem_tree);
 211        up_write(&per_mm->umem_rwsem);
 212}
 213
 214static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 215{
 216        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 217
 218        down_write(&per_mm->umem_rwsem);
 219        if (likely(ib_umem_start(umem_odp) != ib_umem_end(umem_odp)))
 220                rbt_ib_umem_remove(&umem_odp->interval_tree,
 221                                   &per_mm->umem_tree);
 222        complete_all(&umem_odp->notifier_completion);
 223
 224        up_write(&per_mm->umem_rwsem);
 225}
 226
 227static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
 228                                               struct mm_struct *mm)
 229{
 230        struct ib_ucontext_per_mm *per_mm;
 231        int ret;
 232
 233        per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
 234        if (!per_mm)
 235                return ERR_PTR(-ENOMEM);
 236
 237        per_mm->context = ctx;
 238        per_mm->mm = mm;
 239        per_mm->umem_tree = RB_ROOT_CACHED;
 240        init_rwsem(&per_mm->umem_rwsem);
 241        per_mm->active = true;
 242
 243        rcu_read_lock();
 244        per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
 245        rcu_read_unlock();
 246
 247        WARN_ON(mm != current->mm);
 248
 249        per_mm->mn.ops = &ib_umem_notifiers;
 250        ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
 251        if (ret) {
 252                dev_err(&ctx->device->dev,
 253                        "Failed to register mmu_notifier %d\n", ret);
 254                goto out_pid;
 255        }
 256
 257        list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
 258        return per_mm;
 259
 260out_pid:
 261        put_pid(per_mm->tgid);
 262        kfree(per_mm);
 263        return ERR_PTR(ret);
 264}
 265
 266static int get_per_mm(struct ib_umem_odp *umem_odp)
 267{
 268        struct ib_ucontext *ctx = umem_odp->umem.context;
 269        struct ib_ucontext_per_mm *per_mm;
 270
 271        /*
 272         * Generally speaking we expect only one or two per_mm in this list,
 273         * so no reason to optimize this search today.
 274         */
 275        mutex_lock(&ctx->per_mm_list_lock);
 276        list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
 277                if (per_mm->mm == umem_odp->umem.owning_mm)
 278                        goto found;
 279        }
 280
 281        per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
 282        if (IS_ERR(per_mm)) {
 283                mutex_unlock(&ctx->per_mm_list_lock);
 284                return PTR_ERR(per_mm);
 285        }
 286
 287found:
 288        umem_odp->per_mm = per_mm;
 289        per_mm->odp_mrs_count++;
 290        mutex_unlock(&ctx->per_mm_list_lock);
 291
 292        return 0;
 293}
 294
 295static void free_per_mm(struct rcu_head *rcu)
 296{
 297        kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
 298}
 299
 300static void put_per_mm(struct ib_umem_odp *umem_odp)
 301{
 302        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
 303        struct ib_ucontext *ctx = umem_odp->umem.context;
 304        bool need_free;
 305
 306        mutex_lock(&ctx->per_mm_list_lock);
 307        umem_odp->per_mm = NULL;
 308        per_mm->odp_mrs_count--;
 309        need_free = per_mm->odp_mrs_count == 0;
 310        if (need_free)
 311                list_del(&per_mm->ucontext_list);
 312        mutex_unlock(&ctx->per_mm_list_lock);
 313
 314        if (!need_free)
 315                return;
 316
 317        /*
 318         * NOTE! mmu_notifier_unregister() can happen between a start/end
 319         * callback, resulting in an start/end, and thus an unbalanced
 320         * lock. This doesn't really matter to us since we are about to kfree
 321         * the memory that holds the lock, however LOCKDEP doesn't like this.
 322         */
 323        down_write(&per_mm->umem_rwsem);
 324        per_mm->active = false;
 325        up_write(&per_mm->umem_rwsem);
 326
 327        WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
 328        mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
 329        put_pid(per_mm->tgid);
 330        mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
 331}
 332
 333struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
 334                                      unsigned long addr, size_t size)
 335{
 336        struct ib_ucontext_per_mm *per_mm = root->per_mm;
 337        struct ib_ucontext *ctx = per_mm->context;
 338        struct ib_umem_odp *odp_data;
 339        struct ib_umem *umem;
 340        int pages = size >> PAGE_SHIFT;
 341        int ret;
 342
 343        odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
 344        if (!odp_data)
 345                return ERR_PTR(-ENOMEM);
 346        umem = &odp_data->umem;
 347        umem->context    = ctx;
 348        umem->length     = size;
 349        umem->address    = addr;
 350        odp_data->page_shift = PAGE_SHIFT;
 351        umem->writable   = root->umem.writable;
 352        umem->is_odp = 1;
 353        odp_data->per_mm = per_mm;
 354        umem->owning_mm  = per_mm->mm;
 355        mmgrab(umem->owning_mm);
 356
 357        mutex_init(&odp_data->umem_mutex);
 358        init_completion(&odp_data->notifier_completion);
 359
 360        odp_data->page_list =
 361                vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
 362        if (!odp_data->page_list) {
 363                ret = -ENOMEM;
 364                goto out_odp_data;
 365        }
 366
 367        odp_data->dma_list =
 368                vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
 369        if (!odp_data->dma_list) {
 370                ret = -ENOMEM;
 371                goto out_page_list;
 372        }
 373
 374        /*
 375         * Caller must ensure that the umem_odp that the per_mm came from
 376         * cannot be freed during the call to ib_alloc_odp_umem.
 377         */
 378        mutex_lock(&ctx->per_mm_list_lock);
 379        per_mm->odp_mrs_count++;
 380        mutex_unlock(&ctx->per_mm_list_lock);
 381        add_umem_to_per_mm(odp_data);
 382
 383        return odp_data;
 384
 385out_page_list:
 386        vfree(odp_data->page_list);
 387out_odp_data:
 388        mmdrop(umem->owning_mm);
 389        kfree(odp_data);
 390        return ERR_PTR(ret);
 391}
 392EXPORT_SYMBOL(ib_alloc_odp_umem);
 393
 394int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 395{
 396        struct ib_umem *umem = &umem_odp->umem;
 397        /*
 398         * NOTE: This must called in a process context where umem->owning_mm
 399         * == current->mm
 400         */
 401        struct mm_struct *mm = umem->owning_mm;
 402        int ret_val;
 403
 404        umem_odp->page_shift = PAGE_SHIFT;
 405        if (access & IB_ACCESS_HUGETLB) {
 406                struct vm_area_struct *vma;
 407                struct hstate *h;
 408
 409                down_read(&mm->mmap_sem);
 410                vma = find_vma(mm, ib_umem_start(umem_odp));
 411                if (!vma || !is_vm_hugetlb_page(vma)) {
 412                        up_read(&mm->mmap_sem);
 413                        return -EINVAL;
 414                }
 415                h = hstate_vma(vma);
 416                umem_odp->page_shift = huge_page_shift(h);
 417                up_read(&mm->mmap_sem);
 418        }
 419
 420        mutex_init(&umem_odp->umem_mutex);
 421
 422        init_completion(&umem_odp->notifier_completion);
 423
 424        if (ib_umem_odp_num_pages(umem_odp)) {
 425                umem_odp->page_list =
 426                        vzalloc(array_size(sizeof(*umem_odp->page_list),
 427                                           ib_umem_odp_num_pages(umem_odp)));
 428                if (!umem_odp->page_list)
 429                        return -ENOMEM;
 430
 431                umem_odp->dma_list =
 432                        vzalloc(array_size(sizeof(*umem_odp->dma_list),
 433                                           ib_umem_odp_num_pages(umem_odp)));
 434                if (!umem_odp->dma_list) {
 435                        ret_val = -ENOMEM;
 436                        goto out_page_list;
 437                }
 438        }
 439
 440        ret_val = get_per_mm(umem_odp);
 441        if (ret_val)
 442                goto out_dma_list;
 443        add_umem_to_per_mm(umem_odp);
 444
 445        return 0;
 446
 447out_dma_list:
 448        vfree(umem_odp->dma_list);
 449out_page_list:
 450        vfree(umem_odp->page_list);
 451        return ret_val;
 452}
 453
 454void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 455{
 456        /*
 457         * Ensure that no more pages are mapped in the umem.
 458         *
 459         * It is the driver's responsibility to ensure, before calling us,
 460         * that the hardware will not attempt to access the MR any more.
 461         */
 462        ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
 463                                    ib_umem_end(umem_odp));
 464
 465        remove_umem_from_per_mm(umem_odp);
 466        put_per_mm(umem_odp);
 467        vfree(umem_odp->dma_list);
 468        vfree(umem_odp->page_list);
 469}
 470
 471/*
 472 * Map for DMA and insert a single page into the on-demand paging page tables.
 473 *
 474 * @umem: the umem to insert the page to.
 475 * @page_index: index in the umem to add the page to.
 476 * @page: the page struct to map and add.
 477 * @access_mask: access permissions needed for this page.
 478 * @current_seq: sequence number for synchronization with invalidations.
 479 *               the sequence number is taken from
 480 *               umem_odp->notifiers_seq.
 481 *
 482 * The function returns -EFAULT if the DMA mapping operation fails. It returns
 483 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 484 *
 485 * The page is released via put_user_page even if the operation failed. For
 486 * on-demand pinning, the page is released whenever it isn't stored in the
 487 * umem.
 488 */
 489static int ib_umem_odp_map_dma_single_page(
 490                struct ib_umem_odp *umem_odp,
 491                int page_index,
 492                struct page *page,
 493                u64 access_mask,
 494                unsigned long current_seq)
 495{
 496        struct ib_ucontext *context = umem_odp->umem.context;
 497        struct ib_device *dev = context->device;
 498        dma_addr_t dma_addr;
 499        int remove_existing_mapping = 0;
 500        int ret = 0;
 501
 502        /*
 503         * Note: we avoid writing if seq is different from the initial seq, to
 504         * handle case of a racing notifier. This check also allows us to bail
 505         * early if we have a notifier running in parallel with us.
 506         */
 507        if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
 508                ret = -EAGAIN;
 509                goto out;
 510        }
 511        if (!(umem_odp->dma_list[page_index])) {
 512                dma_addr =
 513                        ib_dma_map_page(dev, page, 0, BIT(umem_odp->page_shift),
 514                                        DMA_BIDIRECTIONAL);
 515                if (ib_dma_mapping_error(dev, dma_addr)) {
 516                        ret = -EFAULT;
 517                        goto out;
 518                }
 519                umem_odp->dma_list[page_index] = dma_addr | access_mask;
 520                umem_odp->page_list[page_index] = page;
 521                umem_odp->npages++;
 522        } else if (umem_odp->page_list[page_index] == page) {
 523                umem_odp->dma_list[page_index] |= access_mask;
 524        } else {
 525                pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
 526                       umem_odp->page_list[page_index], page);
 527                /* Better remove the mapping now, to prevent any further
 528                 * damage. */
 529                remove_existing_mapping = 1;
 530        }
 531
 532out:
 533        put_user_page(page);
 534
 535        if (remove_existing_mapping) {
 536                ib_umem_notifier_start_account(umem_odp);
 537                context->invalidate_range(
 538                        umem_odp,
 539                        ib_umem_start(umem_odp) +
 540                                (page_index << umem_odp->page_shift),
 541                        ib_umem_start(umem_odp) +
 542                                ((page_index + 1) << umem_odp->page_shift));
 543                ib_umem_notifier_end_account(umem_odp);
 544                ret = -EAGAIN;
 545        }
 546
 547        return ret;
 548}
 549
 550/**
 551 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 552 *
 553 * Pins the range of pages passed in the argument, and maps them to
 554 * DMA addresses. The DMA addresses of the mapped pages is updated in
 555 * umem_odp->dma_list.
 556 *
 557 * Returns the number of pages mapped in success, negative error code
 558 * for failure.
 559 * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
 560 * the function from completing its task.
 561 * An -ENOENT error code indicates that userspace process is being terminated
 562 * and mm was already destroyed.
 563 * @umem_odp: the umem to map and pin
 564 * @user_virt: the address from which we need to map.
 565 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 566 *        bigger due to alignment, and may also be smaller in case of an error
 567 *        pinning or mapping a page. The actual pages mapped is returned in
 568 *        the return value.
 569 * @access_mask: bit mask of the requested access permissions for the given
 570 *               range.
 571 * @current_seq: the MMU notifiers sequance value for synchronization with
 572 *               invalidations. the sequance number is read from
 573 *               umem_odp->notifiers_seq before calling this function
 574 */
 575int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
 576                              u64 bcnt, u64 access_mask,
 577                              unsigned long current_seq)
 578{
 579        struct task_struct *owning_process  = NULL;
 580        struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
 581        struct page       **local_page_list = NULL;
 582        u64 page_mask, off;
 583        int j, k, ret = 0, start_idx, npages = 0;
 584        unsigned int flags = 0, page_shift;
 585        phys_addr_t p = 0;
 586
 587        if (access_mask == 0)
 588                return -EINVAL;
 589
 590        if (user_virt < ib_umem_start(umem_odp) ||
 591            user_virt + bcnt > ib_umem_end(umem_odp))
 592                return -EFAULT;
 593
 594        local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
 595        if (!local_page_list)
 596                return -ENOMEM;
 597
 598        page_shift = umem_odp->page_shift;
 599        page_mask = ~(BIT(page_shift) - 1);
 600        off = user_virt & (~page_mask);
 601        user_virt = user_virt & page_mask;
 602        bcnt += off; /* Charge for the first page offset as well. */
 603
 604        /*
 605         * owning_process is allowed to be NULL, this means somehow the mm is
 606         * existing beyond the lifetime of the originating process.. Presumably
 607         * mmget_not_zero will fail in this case.
 608         */
 609        owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
 610        if (!owning_process || !mmget_not_zero(owning_mm)) {
 611                ret = -EINVAL;
 612                goto out_put_task;
 613        }
 614
 615        if (access_mask & ODP_WRITE_ALLOWED_BIT)
 616                flags |= FOLL_WRITE;
 617
 618        start_idx = (user_virt - ib_umem_start(umem_odp)) >> page_shift;
 619        k = start_idx;
 620
 621        while (bcnt > 0) {
 622                const size_t gup_num_pages = min_t(size_t,
 623                                (bcnt + BIT(page_shift) - 1) >> page_shift,
 624                                PAGE_SIZE / sizeof(struct page *));
 625
 626                down_read(&owning_mm->mmap_sem);
 627                /*
 628                 * Note: this might result in redundent page getting. We can
 629                 * avoid this by checking dma_list to be 0 before calling
 630                 * get_user_pages. However, this make the code much more
 631                 * complex (and doesn't gain us much performance in most use
 632                 * cases).
 633                 */
 634                npages = get_user_pages_remote(owning_process, owning_mm,
 635                                user_virt, gup_num_pages,
 636                                flags, local_page_list, NULL, NULL);
 637                up_read(&owning_mm->mmap_sem);
 638
 639                if (npages < 0) {
 640                        if (npages != -EAGAIN)
 641                                pr_warn("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
 642                        else
 643                                pr_debug("fail to get %zu user pages with error %d\n", gup_num_pages, npages);
 644                        break;
 645                }
 646
 647                bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
 648                mutex_lock(&umem_odp->umem_mutex);
 649                for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
 650                        if (user_virt & ~page_mask) {
 651                                p += PAGE_SIZE;
 652                                if (page_to_phys(local_page_list[j]) != p) {
 653                                        ret = -EFAULT;
 654                                        break;
 655                                }
 656                                put_user_page(local_page_list[j]);
 657                                continue;
 658                        }
 659
 660                        ret = ib_umem_odp_map_dma_single_page(
 661                                        umem_odp, k, local_page_list[j],
 662                                        access_mask, current_seq);
 663                        if (ret < 0) {
 664                                if (ret != -EAGAIN)
 665                                        pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
 666                                else
 667                                        pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
 668                                break;
 669                        }
 670
 671                        p = page_to_phys(local_page_list[j]);
 672                        k++;
 673                }
 674                mutex_unlock(&umem_odp->umem_mutex);
 675
 676                if (ret < 0) {
 677                        /*
 678                         * Release pages, remembering that the first page
 679                         * to hit an error was already released by
 680                         * ib_umem_odp_map_dma_single_page().
 681                         */
 682                        if (npages - (j + 1) > 0)
 683                                put_user_pages(&local_page_list[j+1],
 684                                               npages - (j + 1));
 685                        break;
 686                }
 687        }
 688
 689        if (ret >= 0) {
 690                if (npages < 0 && k == start_idx)
 691                        ret = npages;
 692                else
 693                        ret = k - start_idx;
 694        }
 695
 696        mmput(owning_mm);
 697out_put_task:
 698        if (owning_process)
 699                put_task_struct(owning_process);
 700        free_page((unsigned long)local_page_list);
 701        return ret;
 702}
 703EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
 704
 705void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
 706                                 u64 bound)
 707{
 708        int idx;
 709        u64 addr;
 710        struct ib_device *dev = umem_odp->umem.context->device;
 711
 712        virt = max_t(u64, virt, ib_umem_start(umem_odp));
 713        bound = min_t(u64, bound, ib_umem_end(umem_odp));
 714        /* Note that during the run of this function, the
 715         * notifiers_count of the MR is > 0, preventing any racing
 716         * faults from completion. We might be racing with other
 717         * invalidations, so we must make sure we free each page only
 718         * once. */
 719        mutex_lock(&umem_odp->umem_mutex);
 720        for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
 721                idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
 722                if (umem_odp->page_list[idx]) {
 723                        struct page *page = umem_odp->page_list[idx];
 724                        dma_addr_t dma = umem_odp->dma_list[idx];
 725                        dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;
 726
 727                        WARN_ON(!dma_addr);
 728
 729                        ib_dma_unmap_page(dev, dma_addr,
 730                                          BIT(umem_odp->page_shift),
 731                                          DMA_BIDIRECTIONAL);
 732                        if (dma & ODP_WRITE_ALLOWED_BIT) {
 733                                struct page *head_page = compound_head(page);
 734                                /*
 735                                 * set_page_dirty prefers being called with
 736                                 * the page lock. However, MMU notifiers are
 737                                 * called sometimes with and sometimes without
 738                                 * the lock. We rely on the umem_mutex instead
 739                                 * to prevent other mmu notifiers from
 740                                 * continuing and allowing the page mapping to
 741                                 * be removed.
 742                                 */
 743                                set_page_dirty(head_page);
 744                        }
 745                        umem_odp->page_list[idx] = NULL;
 746                        umem_odp->dma_list[idx] = 0;
 747                        umem_odp->npages--;
 748                }
 749        }
 750        mutex_unlock(&umem_odp->umem_mutex);
 751}
 752EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
 753
 754/* @last is not a part of the interval. See comment for function
 755 * node_last.
 756 */
 757int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 758                                  u64 start, u64 last,
 759                                  umem_call_back cb,
 760                                  bool blockable,
 761                                  void *cookie)
 762{
 763        int ret_val = 0;
 764        struct umem_odp_node *node, *next;
 765        struct ib_umem_odp *umem;
 766
 767        if (unlikely(start == last))
 768                return ret_val;
 769
 770        for (node = rbt_ib_umem_iter_first(root, start, last - 1);
 771                        node; node = next) {
 772                /* TODO move the blockable decision up to the callback */
 773                if (!blockable)
 774                        return -EAGAIN;
 775                next = rbt_ib_umem_iter_next(node, start, last - 1);
 776                umem = container_of(node, struct ib_umem_odp, interval_tree);
 777                ret_val = cb(umem, start, last, cookie) || ret_val;
 778        }
 779
 780        return ret_val;
 781}
 782EXPORT_SYMBOL(rbt_ib_umem_for_each_in_range);
 783
 784struct ib_umem_odp *rbt_ib_umem_lookup(struct rb_root_cached *root,
 785                                       u64 addr, u64 length)
 786{
 787        struct umem_odp_node *node;
 788
 789        node = rbt_ib_umem_iter_first(root, addr, addr + length - 1);
 790        if (node)
 791                return container_of(node, struct ib_umem_odp, interval_tree);
 792        return NULL;
 793
 794}
 795EXPORT_SYMBOL(rbt_ib_umem_lookup);
 796