linux/drivers/infiniband/hw/mlx5/odp.c
   1/*
   2 * Copyright (c) 2013-2015, Mellanox Technologies. All rights reserved.
   3 *
   4 * This software is available to you under a choice of one of two
   5 * licenses.  You may choose to be licensed under the terms of the GNU
   6 * General Public License (GPL) Version 2, available from the file
   7 * COPYING in the main directory of this source tree, or the
   8 * OpenIB.org BSD license below:
   9 *
  10 *     Redistribution and use in source and binary forms, with or
  11 *     without modification, are permitted provided that the following
  12 *     conditions are met:
  13 *
  14 *      - Redistributions of source code must retain the above
  15 *        copyright notice, this list of conditions and the following
  16 *        disclaimer.
  17 *
  18 *      - Redistributions in binary form must reproduce the above
  19 *        copyright notice, this list of conditions and the following
  20 *        disclaimer in the documentation and/or other materials
  21 *        provided with the distribution.
  22 *
  23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  30 * SOFTWARE.
  31 */
  32
  33#include <rdma/ib_umem.h>
  34#include <rdma/ib_umem_odp.h>
  35#include <linux/kernel.h>
  36
  37#include "mlx5_ib.h"
  38#include "cmd.h"
  39
  40#include <linux/mlx5/eq.h>
  41
  42/* Contains the details of a pagefault. */
  43struct mlx5_pagefault {
  44        u32                     bytes_committed;
  45        u32                     token;
  46        u8                      event_subtype;
  47        u8                      type;
  48        union {
  49                /* Initiator or send message responder pagefault details. */
  50                struct {
  51                        /* Received packet size, only valid for responders. */
  52                        u32     packet_size;
  53                        /*
   54                         * Number of the resource holding the WQE (depends on the fault type).
  55                         */
  56                        u32     wq_num;
  57                        /*
  58                         * WQE index. Refers to either the send queue or
  59                         * receive queue, according to event_subtype.
  60                         */
  61                        u16     wqe_index;
  62                } wqe;
  63                /* RDMA responder pagefault details */
  64                struct {
  65                        u32     r_key;
  66                        /*
   67                         * Received packet size; the minimal page fault
   68                         * resolution required for forward progress.
  69                         */
  70                        u32     packet_size;
  71                        u32     rdma_op_len;
  72                        u64     rdma_va;
  73                } rdma;
  74        };
  75
  76        struct mlx5_ib_pf_eq    *eq;
  77        struct work_struct      work;
  78};
  79
  80#define MAX_PREFETCH_LEN (4*1024*1024U)
  81
  82/* Timeout in ms to wait for an active mmu notifier to complete when handling
  83 * a pagefault. */
  84#define MMU_NOTIFIER_TIMEOUT 1000
  85
  86#define MLX5_IMR_MTT_BITS (30 - PAGE_SHIFT)
  87#define MLX5_IMR_MTT_SHIFT (MLX5_IMR_MTT_BITS + PAGE_SHIFT)
  88#define MLX5_IMR_MTT_ENTRIES BIT_ULL(MLX5_IMR_MTT_BITS)
  89#define MLX5_IMR_MTT_SIZE BIT_ULL(MLX5_IMR_MTT_SHIFT)
  90#define MLX5_IMR_MTT_MASK (~(MLX5_IMR_MTT_SIZE - 1))
  91
  92#define MLX5_KSM_PAGE_SHIFT MLX5_IMR_MTT_SHIFT
  93
  94static u64 mlx5_imr_ksm_entries;
  95
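     /*
      * Return true if @odp is a live child of the implicit MR @parent,
      * i.e. it still points back at @parent and has not been marked dying.
      */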
  96static int check_parent(struct ib_umem_odp *odp,
  97                               struct mlx5_ib_mr *parent)
  98{
  99        struct mlx5_ib_mr *mr = odp->private;
 100
 101        return mr && mr->parent == parent && !odp->dying;
 102}
 103
 104static struct ib_ucontext_per_mm *mr_to_per_mm(struct mlx5_ib_mr *mr)
 105{
 106        if (WARN_ON(!mr || !is_odp_mr(mr)))
 107                return NULL;
 108
 109        return to_ib_umem_odp(mr->umem)->per_mm;
 110}
 111
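     /*
      * Advance to the next umem in the per-mm interval tree that belongs to
      * the same implicit parent MR as @odp, or return NULL if there is none.
      */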
 112static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
 113{
 114        struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
 115        struct ib_ucontext_per_mm *per_mm = odp->per_mm;
 116        struct rb_node *rb;
 117
 118        down_read(&per_mm->umem_rwsem);
 119        while (1) {
 120                rb = rb_next(&odp->interval_tree.rb);
 121                if (!rb)
 122                        goto not_found;
 123                odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
 124                if (check_parent(odp, parent))
 125                        goto end;
 126        }
 127not_found:
 128        odp = NULL;
 129end:
 130        up_read(&per_mm->umem_rwsem);
 131        return odp;
 132}
 133
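     /*
      * Find the first child umem of @parent that overlaps the range
      * [start, start + length) in the per-mm interval tree, or NULL.
      */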
 134static struct ib_umem_odp *odp_lookup(u64 start, u64 length,
 135                                      struct mlx5_ib_mr *parent)
 136{
 137        struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(parent);
 138        struct ib_umem_odp *odp;
 139        struct rb_node *rb;
 140
 141        down_read(&per_mm->umem_rwsem);
 142        odp = rbt_ib_umem_lookup(&per_mm->umem_tree, start, length);
 143        if (!odp)
 144                goto end;
 145
 146        while (1) {
 147                if (check_parent(odp, parent))
 148                        goto end;
 149                rb = rb_next(&odp->interval_tree.rb);
 150                if (!rb)
 151                        goto not_found;
 152                odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
 153                if (ib_umem_start(odp) > start + length)
 154                        goto not_found;
 155        }
 156not_found:
 157        odp = NULL;
 158end:
 159        up_read(&per_mm->umem_rwsem);
 160        return odp;
 161}
 162
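     /*
      * Fill @nentries KLM entries, starting at @offset, for the KSM that
      * backs an implicit MR.  A ZAP request points every entry at the null
      * mkey; otherwise slots covered by an existing child MTT MR use that
      * child's lkey and the remaining slots fall back to the null mkey.
      */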
 163void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
 164                           size_t nentries, struct mlx5_ib_mr *mr, int flags)
 165{
 166        struct ib_pd *pd = mr->ibmr.pd;
 167        struct mlx5_ib_dev *dev = to_mdev(pd->device);
 168        struct ib_umem_odp *odp;
 169        unsigned long va;
 170        int i;
 171
 172        if (flags & MLX5_IB_UPD_XLT_ZAP) {
 173                for (i = 0; i < nentries; i++, pklm++) {
 174                        pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
 175                        pklm->key = cpu_to_be32(dev->null_mkey);
 176                        pklm->va = 0;
 177                }
 178                return;
 179        }
 180
 181        odp = odp_lookup(offset * MLX5_IMR_MTT_SIZE,
 182                         nentries * MLX5_IMR_MTT_SIZE, mr);
 183
 184        for (i = 0; i < nentries; i++, pklm++) {
 185                pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
 186                va = (offset + i) * MLX5_IMR_MTT_SIZE;
 187                if (odp && odp->umem.address == va) {
 188                        struct mlx5_ib_mr *mtt = odp->private;
 189
 190                        pklm->key = cpu_to_be32(mtt->ibmr.lkey);
 191                        odp = odp_next(odp);
 192                } else {
 193                        pklm->key = cpu_to_be32(dev->null_mkey);
 194                }
 195                mlx5_ib_dbg(dev, "[%d] va %lx key %x\n",
 196                            i, va, be32_to_cpu(pklm->key));
 197        }
 198}
 199
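     /*
      * Deferred teardown of an implicit-MR leaf: detach it from its parent,
      * wait for in-flight page-fault handlers via SRCU, release the umem,
      * refresh the parent's KSM entry for this slot (if the parent is still
      * live), return the MTT MR to the cache and wake the parent once the
      * last leaf is gone.
      */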
 200static void mr_leaf_free_action(struct work_struct *work)
 201{
 202        struct ib_umem_odp *odp = container_of(work, struct ib_umem_odp, work);
 203        int idx = ib_umem_start(odp) >> MLX5_IMR_MTT_SHIFT;
 204        struct mlx5_ib_mr *mr = odp->private, *imr = mr->parent;
 205
 206        mr->parent = NULL;
 207        synchronize_srcu(&mr->dev->mr_srcu);
 208
 209        ib_umem_release(&odp->umem);
 210        if (imr->live)
 211                mlx5_ib_update_xlt(imr, idx, 1, 0,
 212                                   MLX5_IB_UPD_XLT_INDIRECT |
 213                                   MLX5_IB_UPD_XLT_ATOMIC);
 214        mlx5_mr_cache_free(mr->dev, mr);
 215
 216        if (atomic_dec_and_test(&imr->num_leaf_free))
 217                wake_up(&imr->q_leaf_free);
 218}
 219
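     /*
      * MMU-notifier invalidation of an ODP umem: zap the HW MTTs that cover
      * [start, end) in UMR-sized chunks, unmap the DMA pages, and schedule
      * the deferred teardown of an implicit-MR leaf that became empty.
      */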
 220void mlx5_ib_invalidate_range(struct ib_umem_odp *umem_odp, unsigned long start,
 221                              unsigned long end)
 222{
 223        struct mlx5_ib_mr *mr;
 224        const u64 umr_block_mask = (MLX5_UMR_MTT_ALIGNMENT /
 225                                    sizeof(struct mlx5_mtt)) - 1;
 226        u64 idx = 0, blk_start_idx = 0;
 227        int in_block = 0;
 228        u64 addr;
 229
 230        if (!umem_odp) {
 231                pr_err("invalidation called on NULL umem or non-ODP umem\n");
 232                return;
 233        }
 234
 235        mr = umem_odp->private;
 236
 237        if (!mr || !mr->ibmr.pd)
 238                return;
 239
 240        start = max_t(u64, ib_umem_start(umem_odp), start);
 241        end = min_t(u64, ib_umem_end(umem_odp), end);
 242
 243        /*
 244         * Iteration one - zap the HW's MTTs. The notifiers_count ensures that
 245         * while we are doing the invalidation, no page fault will attempt to
  246         * overwrite the same MTTs.  Concurrent invalidations might race us,
 247         * but they will write 0s as well, so no difference in the end result.
 248         */
 249        mutex_lock(&umem_odp->umem_mutex);
 250        for (addr = start; addr < end; addr += BIT(umem_odp->page_shift)) {
 251                idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
 252                /*
 253                 * Strive to write the MTTs in chunks, but avoid overwriting
  254                 * non-existing MTTs. The heuristic here can be improved to
  255                 * estimate the cost of another UMR vs. the cost of a bigger
  256                 * UMR.
 257                 */
 258                if (umem_odp->dma_list[idx] &
 259                    (ODP_READ_ALLOWED_BIT | ODP_WRITE_ALLOWED_BIT)) {
 260                        if (!in_block) {
 261                                blk_start_idx = idx;
 262                                in_block = 1;
 263                        }
 264                } else {
 265                        u64 umr_offset = idx & umr_block_mask;
 266
 267                        if (in_block && umr_offset == 0) {
 268                                mlx5_ib_update_xlt(mr, blk_start_idx,
 269                                                   idx - blk_start_idx, 0,
 270                                                   MLX5_IB_UPD_XLT_ZAP |
 271                                                   MLX5_IB_UPD_XLT_ATOMIC);
 272                                in_block = 0;
 273                        }
 274                }
 275        }
 276        if (in_block)
 277                mlx5_ib_update_xlt(mr, blk_start_idx,
 278                                   idx - blk_start_idx + 1, 0,
 279                                   MLX5_IB_UPD_XLT_ZAP |
 280                                   MLX5_IB_UPD_XLT_ATOMIC);
 281        mutex_unlock(&umem_odp->umem_mutex);
 282        /*
 283         * We are now sure that the device will not access the
 284         * memory. We can safely unmap it, and mark it as dirty if
 285         * needed.
 286         */
 287
 288        ib_umem_odp_unmap_dma_pages(umem_odp, start, end);
 289
 290        if (unlikely(!umem_odp->npages && mr->parent &&
 291                     !umem_odp->dying)) {
 292                WRITE_ONCE(umem_odp->dying, 1);
 293                atomic_inc(&mr->parent->num_leaf_free);
 294                schedule_work(&umem_odp->work);
 295        }
 296}
 297
 298void mlx5_ib_internal_fill_odp_caps(struct mlx5_ib_dev *dev)
 299{
 300        struct ib_odp_caps *caps = &dev->odp_caps;
 301
 302        memset(caps, 0, sizeof(*caps));
 303
 304        if (!MLX5_CAP_GEN(dev->mdev, pg) ||
 305            !mlx5_ib_can_use_umr(dev, true))
 306                return;
 307
 308        caps->general_caps = IB_ODP_SUPPORT;
 309
 310        if (MLX5_CAP_GEN(dev->mdev, umr_extended_translation_offset))
 311                dev->odp_max_size = U64_MAX;
 312        else
 313                dev->odp_max_size = BIT_ULL(MLX5_MAX_UMR_SHIFT + PAGE_SHIFT);
 314
 315        if (MLX5_CAP_ODP(dev->mdev, ud_odp_caps.send))
 316                caps->per_transport_caps.ud_odp_caps |= IB_ODP_SUPPORT_SEND;
 317
 318        if (MLX5_CAP_ODP(dev->mdev, ud_odp_caps.srq_receive))
 319                caps->per_transport_caps.ud_odp_caps |= IB_ODP_SUPPORT_SRQ_RECV;
 320
 321        if (MLX5_CAP_ODP(dev->mdev, rc_odp_caps.send))
 322                caps->per_transport_caps.rc_odp_caps |= IB_ODP_SUPPORT_SEND;
 323
 324        if (MLX5_CAP_ODP(dev->mdev, rc_odp_caps.receive))
 325                caps->per_transport_caps.rc_odp_caps |= IB_ODP_SUPPORT_RECV;
 326
 327        if (MLX5_CAP_ODP(dev->mdev, rc_odp_caps.write))
 328                caps->per_transport_caps.rc_odp_caps |= IB_ODP_SUPPORT_WRITE;
 329
 330        if (MLX5_CAP_ODP(dev->mdev, rc_odp_caps.read))
 331                caps->per_transport_caps.rc_odp_caps |= IB_ODP_SUPPORT_READ;
 332
 333        if (MLX5_CAP_ODP(dev->mdev, rc_odp_caps.atomic))
 334                caps->per_transport_caps.rc_odp_caps |= IB_ODP_SUPPORT_ATOMIC;
 335
 336        if (MLX5_CAP_ODP(dev->mdev, rc_odp_caps.srq_receive))
 337                caps->per_transport_caps.rc_odp_caps |= IB_ODP_SUPPORT_SRQ_RECV;
 338
 339        if (MLX5_CAP_ODP(dev->mdev, xrc_odp_caps.send))
 340                caps->per_transport_caps.xrc_odp_caps |= IB_ODP_SUPPORT_SEND;
 341
 342        if (MLX5_CAP_ODP(dev->mdev, xrc_odp_caps.receive))
 343                caps->per_transport_caps.xrc_odp_caps |= IB_ODP_SUPPORT_RECV;
 344
 345        if (MLX5_CAP_ODP(dev->mdev, xrc_odp_caps.write))
 346                caps->per_transport_caps.xrc_odp_caps |= IB_ODP_SUPPORT_WRITE;
 347
 348        if (MLX5_CAP_ODP(dev->mdev, xrc_odp_caps.read))
 349                caps->per_transport_caps.xrc_odp_caps |= IB_ODP_SUPPORT_READ;
 350
 351        if (MLX5_CAP_ODP(dev->mdev, xrc_odp_caps.atomic))
 352                caps->per_transport_caps.xrc_odp_caps |= IB_ODP_SUPPORT_ATOMIC;
 353
 354        if (MLX5_CAP_ODP(dev->mdev, xrc_odp_caps.srq_receive))
 355                caps->per_transport_caps.xrc_odp_caps |= IB_ODP_SUPPORT_SRQ_RECV;
 356
 357        if (MLX5_CAP_GEN(dev->mdev, fixed_buffer_size) &&
 358            MLX5_CAP_GEN(dev->mdev, null_mkey) &&
 359            MLX5_CAP_GEN(dev->mdev, umr_extended_translation_offset) &&
 360            !MLX5_CAP_GEN(dev->mdev, umr_indirect_mkey_disabled))
 361                caps->general_caps |= IB_ODP_SUPPORT_IMPLICIT;
 362
 363        return;
 364}
 365
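     /*
      * Issue PAGE_FAULT_RESUME so the HW resumes the WQ that triggered the
      * fault, optionally telling it to complete the operation with an error.
      */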
 366static void mlx5_ib_page_fault_resume(struct mlx5_ib_dev *dev,
 367                                      struct mlx5_pagefault *pfault,
 368                                      int error)
 369{
 370        int wq_num = pfault->event_subtype == MLX5_PFAULT_SUBTYPE_WQE ?
 371                     pfault->wqe.wq_num : pfault->token;
 372        u32 out[MLX5_ST_SZ_DW(page_fault_resume_out)] = { };
 373        u32 in[MLX5_ST_SZ_DW(page_fault_resume_in)]   = { };
 374        int err;
 375
 376        MLX5_SET(page_fault_resume_in, in, opcode, MLX5_CMD_OP_PAGE_FAULT_RESUME);
 377        MLX5_SET(page_fault_resume_in, in, page_fault_type, pfault->type);
 378        MLX5_SET(page_fault_resume_in, in, token, pfault->token);
 379        MLX5_SET(page_fault_resume_in, in, wq_number, wq_num);
 380        MLX5_SET(page_fault_resume_in, in, error, !!error);
 381
 382        err = mlx5_cmd_exec(dev->mdev, in, sizeof(in), out, sizeof(out));
 383        if (err)
 384                mlx5_ib_err(dev, "Failed to resolve the page fault on WQ 0x%x err %d\n",
 385                            wq_num, err);
 386}
 387
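     /*
      * Allocate an MR from the cache for implicit ODP: either the top-level
      * KSM mkey (ksm == true, one entry per MLX5_IMR_MTT_SIZE slot) or a leaf
      * MTT mkey covering a single slot.  The translation table is written
      * zapped and the mkey is enabled before the MR is returned.
      */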
 388static struct mlx5_ib_mr *implicit_mr_alloc(struct ib_pd *pd,
 389                                            struct ib_umem *umem,
 390                                            bool ksm, int access_flags)
 391{
 392        struct mlx5_ib_dev *dev = to_mdev(pd->device);
 393        struct mlx5_ib_mr *mr;
 394        int err;
 395
 396        mr = mlx5_mr_cache_alloc(dev, ksm ? MLX5_IMR_KSM_CACHE_ENTRY :
 397                                            MLX5_IMR_MTT_CACHE_ENTRY);
 398
 399        if (IS_ERR(mr))
 400                return mr;
 401
 402        mr->ibmr.pd = pd;
 403
 404        mr->dev = dev;
 405        mr->access_flags = access_flags;
 406        mr->mmkey.iova = 0;
 407        mr->umem = umem;
 408
 409        if (ksm) {
 410                err = mlx5_ib_update_xlt(mr, 0,
 411                                         mlx5_imr_ksm_entries,
 412                                         MLX5_KSM_PAGE_SHIFT,
 413                                         MLX5_IB_UPD_XLT_INDIRECT |
 414                                         MLX5_IB_UPD_XLT_ZAP |
 415                                         MLX5_IB_UPD_XLT_ENABLE);
 416
 417        } else {
 418                err = mlx5_ib_update_xlt(mr, 0,
 419                                         MLX5_IMR_MTT_ENTRIES,
 420                                         PAGE_SHIFT,
 421                                         MLX5_IB_UPD_XLT_ZAP |
 422                                         MLX5_IB_UPD_XLT_ENABLE |
 423                                         MLX5_IB_UPD_XLT_ATOMIC);
 424        }
 425
 426        if (err)
 427                goto fail;
 428
 429        mr->ibmr.lkey = mr->mmkey.key;
 430        mr->ibmr.rkey = mr->mmkey.key;
 431
 432        mr->live = 1;
 433
 434        mlx5_ib_dbg(dev, "key %x dev %p mr %p\n",
 435                    mr->mmkey.key, dev->mdev, mr);
 436
 437        return mr;
 438
 439fail:
 440        mlx5_ib_err(dev, "Failed to register MKEY %d\n", err);
 441        mlx5_mr_cache_free(dev, mr);
 442
 443        return ERR_PTR(err);
 444}
 445
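     /*
      * Make sure every MLX5_IMR_MTT_SIZE slot that covers
      * [io_virt, io_virt + bcnt) has a child umem and MTT MR, allocating any
      * missing leaves, and update the parent's KSM with the new entries.
      * Returns the first child umem of the range.
      */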
 446static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
 447                                                u64 io_virt, size_t bcnt)
 448{
 449        struct mlx5_ib_dev *dev = to_mdev(mr->ibmr.pd->device);
 450        struct ib_umem_odp *odp, *result = NULL;
 451        struct ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem);
 452        u64 addr = io_virt & MLX5_IMR_MTT_MASK;
 453        int nentries = 0, start_idx = 0, ret;
 454        struct mlx5_ib_mr *mtt;
 455
 456        mutex_lock(&odp_mr->umem_mutex);
 457        odp = odp_lookup(addr, 1, mr);
 458
 459        mlx5_ib_dbg(dev, "io_virt:%llx bcnt:%zx addr:%llx odp:%p\n",
 460                    io_virt, bcnt, addr, odp);
 461
 462next_mr:
 463        if (likely(odp)) {
 464                if (nentries)
 465                        nentries++;
 466        } else {
 467                odp = ib_alloc_odp_umem(odp_mr, addr,
 468                                        MLX5_IMR_MTT_SIZE);
 469                if (IS_ERR(odp)) {
 470                        mutex_unlock(&odp_mr->umem_mutex);
 471                        return ERR_CAST(odp);
 472                }
 473
 474                mtt = implicit_mr_alloc(mr->ibmr.pd, &odp->umem, 0,
 475                                        mr->access_flags);
 476                if (IS_ERR(mtt)) {
 477                        mutex_unlock(&odp_mr->umem_mutex);
 478                        ib_umem_release(&odp->umem);
 479                        return ERR_CAST(mtt);
 480                }
 481
 482                odp->private = mtt;
 483                mtt->umem = &odp->umem;
 484                mtt->mmkey.iova = addr;
 485                mtt->parent = mr;
 486                INIT_WORK(&odp->work, mr_leaf_free_action);
 487
 488                if (!nentries)
 489                        start_idx = addr >> MLX5_IMR_MTT_SHIFT;
 490                nentries++;
 491        }
 492
  493        /* Return the first ODP umem if the region is not covered by a single one */
 494        if (likely(!result))
 495                result = odp;
 496
 497        addr += MLX5_IMR_MTT_SIZE;
 498        if (unlikely(addr < io_virt + bcnt)) {
 499                odp = odp_next(odp);
 500                if (odp && odp->umem.address != addr)
 501                        odp = NULL;
 502                goto next_mr;
 503        }
 504
 505        if (unlikely(nentries)) {
 506                ret = mlx5_ib_update_xlt(mr, start_idx, nentries, 0,
 507                                         MLX5_IB_UPD_XLT_INDIRECT |
 508                                         MLX5_IB_UPD_XLT_ATOMIC);
 509                if (ret) {
 510                        mlx5_ib_err(dev, "Failed to update PAS\n");
 511                        result = ERR_PTR(ret);
 512                }
 513        }
 514
 515        mutex_unlock(&odp_mr->umem_mutex);
 516        return result;
 517}
 518
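     /*
      * Create the parent MR of an implicit ODP registration: an ODP umem with
      * no backing range plus a KSM indirect mkey whose entries are initially
      * zapped; leaves are added on demand by the page-fault handler.
      */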
 519struct mlx5_ib_mr *mlx5_ib_alloc_implicit_mr(struct mlx5_ib_pd *pd,
 520                                             struct ib_udata *udata,
 521                                             int access_flags)
 522{
 523        struct mlx5_ib_mr *imr;
 524        struct ib_umem *umem;
 525
 526        umem = ib_umem_get(udata, 0, 0, access_flags, 0);
 527        if (IS_ERR(umem))
 528                return ERR_CAST(umem);
 529
 530        imr = implicit_mr_alloc(&pd->ibpd, umem, 1, access_flags);
 531        if (IS_ERR(imr)) {
 532                ib_umem_release(umem);
 533                return ERR_CAST(imr);
 534        }
 535
 536        imr->umem = umem;
 537        init_waitqueue_head(&imr->q_leaf_free);
 538        atomic_set(&imr->num_leaf_free, 0);
 539        atomic_set(&imr->num_pending_prefetch, 0);
 540
 541        return imr;
 542}
 543
 544static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end,
 545                        void *cookie)
 546{
 547        struct mlx5_ib_mr *mr = umem_odp->private, *imr = cookie;
 548
 549        if (mr->parent != imr)
 550                return 0;
 551
 552        ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
 553                                    ib_umem_end(umem_odp));
 554
 555        if (umem_odp->dying)
 556                return 0;
 557
 558        WRITE_ONCE(umem_odp->dying, 1);
 559        atomic_inc(&imr->num_leaf_free);
 560        schedule_work(&umem_odp->work);
 561
 562        return 0;
 563}
 564
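     /*
      * Tear down all leaves of an implicit MR: mr_leaf_free() unmaps each
      * child and schedules its deferred teardown, then wait here until the
      * last mr_leaf_free_action() has completed.
      */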
 565void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
 566{
 567        struct ib_ucontext_per_mm *per_mm = mr_to_per_mm(imr);
 568
 569        down_read(&per_mm->umem_rwsem);
 570        rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, 0, ULLONG_MAX,
 571                                      mr_leaf_free, true, imr);
 572        up_read(&per_mm->umem_rwsem);
 573
 574        wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free));
 575}
 576
 577#define MLX5_PF_FLAGS_PREFETCH  BIT(0)
 578#define MLX5_PF_FLAGS_DOWNGRADE BIT(1)
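     /*
      * Resolve a page fault on [io_virt, io_virt + bcnt) of @mr: map the user
      * pages through ODP and update the mkey's translation entries, walking
      * across implicit-MR leaves as needed.  Returns the number of pages
      * mapped, or a negative error code (-EAGAIN when racing with an
      * invalidation).
      */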
 579static int pagefault_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
 580                        u64 io_virt, size_t bcnt, u32 *bytes_mapped,
 581                        u32 flags)
 582{
 583        int npages = 0, current_seq, page_shift, ret, np;
 584        struct ib_umem_odp *odp_mr = to_ib_umem_odp(mr->umem);
 585        bool downgrade = flags & MLX5_PF_FLAGS_DOWNGRADE;
 586        bool prefetch = flags & MLX5_PF_FLAGS_PREFETCH;
 587        u64 access_mask;
 588        u64 start_idx, page_mask;
 589        struct ib_umem_odp *odp;
 590        size_t size;
 591
 592        if (!odp_mr->page_list) {
 593                odp = implicit_mr_get_data(mr, io_virt, bcnt);
 594
 595                if (IS_ERR(odp))
 596                        return PTR_ERR(odp);
 597                mr = odp->private;
 598        } else {
 599                odp = odp_mr;
 600        }
 601
 602next_mr:
 603        size = min_t(size_t, bcnt, ib_umem_end(odp) - io_virt);
 604
 605        page_shift = odp->page_shift;
 606        page_mask = ~(BIT(page_shift) - 1);
 607        start_idx = (io_virt - (mr->mmkey.iova & page_mask)) >> page_shift;
 608        access_mask = ODP_READ_ALLOWED_BIT;
 609
 610        if (prefetch && !downgrade && !mr->umem->writable) {
  611                /* A prefetch that requests write access must
  612                 * be backed by a writable MR.
  613                 */
 614                ret = -EINVAL;
 615                goto out;
 616        }
 617
 618        if (mr->umem->writable && !downgrade)
 619                access_mask |= ODP_WRITE_ALLOWED_BIT;
 620
 621        current_seq = READ_ONCE(odp->notifiers_seq);
 622        /*
 623         * Ensure the sequence number is valid for some time before we call
 624         * gup.
 625         */
 626        smp_rmb();
 627
 628        ret = ib_umem_odp_map_dma_pages(to_ib_umem_odp(mr->umem), io_virt, size,
 629                                        access_mask, current_seq);
 630
 631        if (ret < 0)
 632                goto out;
 633
 634        np = ret;
 635
 636        mutex_lock(&odp->umem_mutex);
 637        if (!ib_umem_mmu_notifier_retry(to_ib_umem_odp(mr->umem),
 638                                        current_seq)) {
 639                /*
 640                 * No need to check whether the MTTs really belong to
 641                 * this MR, since ib_umem_odp_map_dma_pages already
 642                 * checks this.
 643                 */
 644                ret = mlx5_ib_update_xlt(mr, start_idx, np,
 645                                         page_shift, MLX5_IB_UPD_XLT_ATOMIC);
 646        } else {
 647                ret = -EAGAIN;
 648        }
 649        mutex_unlock(&odp->umem_mutex);
 650
 651        if (ret < 0) {
 652                if (ret != -EAGAIN)
 653                        mlx5_ib_err(dev, "Failed to update mkey page tables\n");
 654                goto out;
 655        }
 656
 657        if (bytes_mapped) {
 658                u32 new_mappings = (np << page_shift) -
 659                        (io_virt - round_down(io_virt, 1 << page_shift));
 660                *bytes_mapped += min_t(u32, new_mappings, size);
 661        }
 662
 663        npages += np << (page_shift - PAGE_SHIFT);
 664        bcnt -= size;
 665
 666        if (unlikely(bcnt)) {
 667                struct ib_umem_odp *next;
 668
 669                io_virt += size;
 670                next = odp_next(odp);
 671                if (unlikely(!next || next->umem.address != io_virt)) {
 672                        mlx5_ib_dbg(dev, "next implicit leaf removed at 0x%llx. got %p\n",
 673                                    io_virt, next);
 674                        return -EAGAIN;
 675                }
 676                odp = next;
 677                mr = odp->private;
 678                goto next_mr;
 679        }
 680
 681        return npages;
 682
 683out:
 684        if (ret == -EAGAIN) {
 685                unsigned long timeout = msecs_to_jiffies(MMU_NOTIFIER_TIMEOUT);
 686
 687                if (!wait_for_completion_timeout(&odp->notifier_completion,
 688                                                 timeout)) {
 689                        mlx5_ib_warn(
 690                                dev,
 691                                "timeout waiting for mmu notifier. seq %d against %d. notifiers_count=%d\n",
 692                                current_seq, odp->notifiers_seq,
 693                                odp->notifiers_count);
 694                }
 695        }
 696
 697        return ret;
 698}
 699
 700struct pf_frame {
 701        struct pf_frame *next;
 702        u32 key;
 703        u64 io_virt;
 704        size_t bcnt;
 705        int depth;
 706};
 707
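     /*
      * Memory windows are matched on the base mkey index only, since the low
      * eight key bits may differ; any other mkey must match the key exactly.
      */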
 708static bool mkey_is_eq(struct mlx5_core_mkey *mmkey, u32 key)
 709{
 710        if (!mmkey)
 711                return false;
 712        if (mmkey->type == MLX5_MKEY_MW)
 713                return mlx5_base_mkey(mmkey->key) == mlx5_base_mkey(key);
 714        return mmkey->key == key;
 715}
 716
 717static int get_indirect_num_descs(struct mlx5_core_mkey *mmkey)
 718{
 719        struct mlx5_ib_mw *mw;
 720        struct mlx5_ib_devx_mr *devx_mr;
 721
 722        if (mmkey->type == MLX5_MKEY_MW) {
 723                mw = container_of(mmkey, struct mlx5_ib_mw, mmkey);
 724                return mw->ndescs;
 725        }
 726
 727        devx_mr = container_of(mmkey, struct mlx5_ib_devx_mr,
 728                               mmkey);
 729        return devx_mr->ndescs;
 730}
 731
 732/*
 733 * Handle a single data segment in a page-fault WQE or RDMA region.
 734 *
  735 * Returns the number of OS pages retrieved on success. The caller may continue to
 736 * the next data segment.
 737 * Can return the following error codes:
 738 * -EAGAIN to designate a temporary error. The caller will abort handling the
 739 *  page fault and resolve it.
 740 * -EFAULT when there's an error mapping the requested pages. The caller will
 741 *  abort the page fault handling.
 742 */
 743static int pagefault_single_data_segment(struct mlx5_ib_dev *dev,
 744                                         struct ib_pd *pd, u32 key,
 745                                         u64 io_virt, size_t bcnt,
 746                                         u32 *bytes_committed,
 747                                         u32 *bytes_mapped, u32 flags)
 748{
 749        int npages = 0, srcu_key, ret, i, outlen, cur_outlen = 0, depth = 0;
 750        bool prefetch = flags & MLX5_PF_FLAGS_PREFETCH;
 751        struct pf_frame *head = NULL, *frame;
 752        struct mlx5_core_mkey *mmkey;
 753        struct mlx5_ib_mr *mr;
 754        struct mlx5_klm *pklm;
 755        u32 *out = NULL;
 756        size_t offset;
 757        int ndescs;
 758
 759        srcu_key = srcu_read_lock(&dev->mr_srcu);
 760
 761        io_virt += *bytes_committed;
 762        bcnt -= *bytes_committed;
 763
 764next_mr:
 765        mmkey = xa_load(&dev->mdev->priv.mkey_table, mlx5_base_mkey(key));
 766        if (!mkey_is_eq(mmkey, key)) {
 767                mlx5_ib_dbg(dev, "failed to find mkey %x\n", key);
 768                ret = -EFAULT;
 769                goto srcu_unlock;
 770        }
 771
 772        if (prefetch && mmkey->type != MLX5_MKEY_MR) {
 773                mlx5_ib_dbg(dev, "prefetch is allowed only for MR\n");
 774                ret = -EINVAL;
 775                goto srcu_unlock;
 776        }
 777
 778        switch (mmkey->type) {
 779        case MLX5_MKEY_MR:
 780                mr = container_of(mmkey, struct mlx5_ib_mr, mmkey);
 781                if (!mr->live || !mr->ibmr.pd) {
 782                        mlx5_ib_dbg(dev, "got dead MR\n");
 783                        ret = -EFAULT;
 784                        goto srcu_unlock;
 785                }
 786
 787                if (prefetch) {
 788                        if (!is_odp_mr(mr) ||
 789                            mr->ibmr.pd != pd) {
 790                                mlx5_ib_dbg(dev, "Invalid prefetch request: %s\n",
 791                                            is_odp_mr(mr) ?  "MR is not ODP" :
 792                                            "PD is not of the MR");
 793                                ret = -EINVAL;
 794                                goto srcu_unlock;
 795                        }
 796                }
 797
 798                if (!is_odp_mr(mr)) {
 799                        mlx5_ib_dbg(dev, "skipping non ODP MR (lkey=0x%06x) in page fault handler.\n",
 800                                    key);
 801                        if (bytes_mapped)
 802                                *bytes_mapped += bcnt;
 803                        ret = 0;
 804                        goto srcu_unlock;
 805                }
 806
 807                ret = pagefault_mr(dev, mr, io_virt, bcnt, bytes_mapped, flags);
 808                if (ret < 0)
 809                        goto srcu_unlock;
 810
 811                npages += ret;
 812                ret = 0;
 813                break;
 814
 815        case MLX5_MKEY_MW:
 816        case MLX5_MKEY_INDIRECT_DEVX:
 817                ndescs = get_indirect_num_descs(mmkey);
 818
 819                if (depth >= MLX5_CAP_GEN(dev->mdev, max_indirection)) {
 820                        mlx5_ib_dbg(dev, "indirection level exceeded\n");
 821                        ret = -EFAULT;
 822                        goto srcu_unlock;
 823                }
 824
 825                outlen = MLX5_ST_SZ_BYTES(query_mkey_out) +
 826                        sizeof(*pklm) * (ndescs - 2);
 827
 828                if (outlen > cur_outlen) {
 829                        kfree(out);
 830                        out = kzalloc(outlen, GFP_KERNEL);
 831                        if (!out) {
 832                                ret = -ENOMEM;
 833                                goto srcu_unlock;
 834                        }
 835                        cur_outlen = outlen;
 836                }
 837
 838                pklm = (struct mlx5_klm *)MLX5_ADDR_OF(query_mkey_out, out,
 839                                                       bsf0_klm0_pas_mtt0_1);
 840
 841                ret = mlx5_core_query_mkey(dev->mdev, mmkey, out, outlen);
 842                if (ret)
 843                        goto srcu_unlock;
 844
 845                offset = io_virt - MLX5_GET64(query_mkey_out, out,
 846                                              memory_key_mkey_entry.start_addr);
 847
 848                for (i = 0; bcnt && i < ndescs; i++, pklm++) {
 849                        if (offset >= be32_to_cpu(pklm->bcount)) {
 850                                offset -= be32_to_cpu(pklm->bcount);
 851                                continue;
 852                        }
 853
 854                        frame = kzalloc(sizeof(*frame), GFP_KERNEL);
 855                        if (!frame) {
 856                                ret = -ENOMEM;
 857                                goto srcu_unlock;
 858                        }
 859
 860                        frame->key = be32_to_cpu(pklm->key);
 861                        frame->io_virt = be64_to_cpu(pklm->va) + offset;
 862                        frame->bcnt = min_t(size_t, bcnt,
 863                                            be32_to_cpu(pklm->bcount) - offset);
 864                        frame->depth = depth + 1;
 865                        frame->next = head;
 866                        head = frame;
 867
 868                        bcnt -= frame->bcnt;
 869                        offset = 0;
 870                }
 871                break;
 872
 873        default:
 874                mlx5_ib_dbg(dev, "wrong mkey type %d\n", mmkey->type);
 875                ret = -EFAULT;
 876                goto srcu_unlock;
 877        }
 878
 879        if (head) {
 880                frame = head;
 881                head = frame->next;
 882
 883                key = frame->key;
 884                io_virt = frame->io_virt;
 885                bcnt = frame->bcnt;
 886                depth = frame->depth;
 887                kfree(frame);
 888
 889                goto next_mr;
 890        }
 891
 892srcu_unlock:
 893        while (head) {
 894                frame = head;
 895                head = frame->next;
 896                kfree(frame);
 897        }
 898        kfree(out);
 899
 900        srcu_read_unlock(&dev->mr_srcu, srcu_key);
 901        *bytes_committed = 0;
 902        return ret ? ret : npages;
 903}
 904
 905/**
 906 * Parse a series of data segments for page fault handling.
 907 *
 908 * @pfault contains page fault information.
 909 * @wqe points at the first data segment in the WQE.
 910 * @wqe_end points after the end of the WQE.
 911 * @bytes_mapped receives the number of bytes that the function was able to
 912 *               map. This allows the caller to decide intelligently whether
 913 *               enough memory was mapped to resolve the page fault
 914 *               successfully (e.g. enough for the next MTU, or the entire
 915 *               WQE).
 916 * @total_wqe_bytes receives the total data size of this WQE in bytes (minus
 917 *                  the committed bytes).
 918 *
 919 * Returns the number of pages loaded if positive, zero for an empty WQE, or a
 920 * negative error code.
 921 */
 922static int pagefault_data_segments(struct mlx5_ib_dev *dev,
 923                                   struct mlx5_pagefault *pfault,
 924                                   void *wqe,
 925                                   void *wqe_end, u32 *bytes_mapped,
 926                                   u32 *total_wqe_bytes, bool receive_queue)
 927{
 928        int ret = 0, npages = 0;
 929        u64 io_virt;
 930        u32 key;
 931        u32 byte_count;
 932        size_t bcnt;
 933        int inline_segment;
 934
 935        if (bytes_mapped)
 936                *bytes_mapped = 0;
 937        if (total_wqe_bytes)
 938                *total_wqe_bytes = 0;
 939
 940        while (wqe < wqe_end) {
 941                struct mlx5_wqe_data_seg *dseg = wqe;
 942
 943                io_virt = be64_to_cpu(dseg->addr);
 944                key = be32_to_cpu(dseg->lkey);
 945                byte_count = be32_to_cpu(dseg->byte_count);
 946                inline_segment = !!(byte_count &  MLX5_INLINE_SEG);
 947                bcnt           = byte_count & ~MLX5_INLINE_SEG;
 948
 949                if (inline_segment) {
 950                        bcnt = bcnt & MLX5_WQE_INLINE_SEG_BYTE_COUNT_MASK;
 951                        wqe += ALIGN(sizeof(struct mlx5_wqe_inline_seg) + bcnt,
 952                                     16);
 953                } else {
 954                        wqe += sizeof(*dseg);
 955                }
 956
  957                /* End of the scatter/gather list in a receive WQE. */
 958                if (receive_queue && bcnt == 0 && key == MLX5_INVALID_LKEY &&
 959                    io_virt == 0)
 960                        break;
 961
 962                if (!inline_segment && total_wqe_bytes) {
 963                        *total_wqe_bytes += bcnt - min_t(size_t, bcnt,
 964                                        pfault->bytes_committed);
 965                }
 966
 967                /* A zero length data segment designates a length of 2GB. */
 968                if (bcnt == 0)
 969                        bcnt = 1U << 31;
 970
 971                if (inline_segment || bcnt <= pfault->bytes_committed) {
 972                        pfault->bytes_committed -=
 973                                min_t(size_t, bcnt,
 974                                      pfault->bytes_committed);
 975                        continue;
 976                }
 977
 978                ret = pagefault_single_data_segment(dev, NULL, key,
 979                                                    io_virt, bcnt,
 980                                                    &pfault->bytes_committed,
 981                                                    bytes_mapped, 0);
 982                if (ret < 0)
 983                        break;
 984                npages += ret;
 985        }
 986
 987        return ret < 0 ? ret : npages;
 988}
 989
 990static const u32 mlx5_ib_odp_opcode_cap[] = {
 991        [MLX5_OPCODE_SEND]             = IB_ODP_SUPPORT_SEND,
 992        [MLX5_OPCODE_SEND_IMM]         = IB_ODP_SUPPORT_SEND,
 993        [MLX5_OPCODE_SEND_INVAL]       = IB_ODP_SUPPORT_SEND,
 994        [MLX5_OPCODE_RDMA_WRITE]       = IB_ODP_SUPPORT_WRITE,
 995        [MLX5_OPCODE_RDMA_WRITE_IMM]   = IB_ODP_SUPPORT_WRITE,
 996        [MLX5_OPCODE_RDMA_READ]        = IB_ODP_SUPPORT_READ,
 997        [MLX5_OPCODE_ATOMIC_CS]        = IB_ODP_SUPPORT_ATOMIC,
 998        [MLX5_OPCODE_ATOMIC_FA]        = IB_ODP_SUPPORT_ATOMIC,
 999};
1000
1001/*
 1002 * Parse an initiator WQE. Advances the wqe pointer to point at the
 1003 * scatter-gather list and sets wqe_end to the end of the WQE.
1004 */
1005static int mlx5_ib_mr_initiator_pfault_handler(
1006        struct mlx5_ib_dev *dev, struct mlx5_pagefault *pfault,
1007        struct mlx5_ib_qp *qp, void **wqe, void **wqe_end, int wqe_length)
1008{
1009        struct mlx5_wqe_ctrl_seg *ctrl = *wqe;
1010        u16 wqe_index = pfault->wqe.wqe_index;
1011        u32 transport_caps;
1012        struct mlx5_base_av *av;
1013        unsigned ds, opcode;
1014#if defined(DEBUG)
1015        u32 ctrl_wqe_index, ctrl_qpn;
1016#endif
1017        u32 qpn = qp->trans_qp.base.mqp.qpn;
1018
1019        ds = be32_to_cpu(ctrl->qpn_ds) & MLX5_WQE_CTRL_DS_MASK;
1020        if (ds * MLX5_WQE_DS_UNITS > wqe_length) {
1021                mlx5_ib_err(dev, "Unable to read the complete WQE. ds = 0x%x, ret = 0x%x\n",
1022                            ds, wqe_length);
1023                return -EFAULT;
1024        }
1025
1026        if (ds == 0) {
1027                mlx5_ib_err(dev, "Got WQE with zero DS. wqe_index=%x, qpn=%x\n",
1028                            wqe_index, qpn);
1029                return -EFAULT;
1030        }
1031
1032#if defined(DEBUG)
1033        ctrl_wqe_index = (be32_to_cpu(ctrl->opmod_idx_opcode) &
1034                        MLX5_WQE_CTRL_WQE_INDEX_MASK) >>
1035                        MLX5_WQE_CTRL_WQE_INDEX_SHIFT;
1036        if (wqe_index != ctrl_wqe_index) {
1037                mlx5_ib_err(dev, "Got WQE with invalid wqe_index. wqe_index=0x%x, qpn=0x%x ctrl->wqe_index=0x%x\n",
1038                            wqe_index, qpn,
1039                            ctrl_wqe_index);
1040                return -EFAULT;
1041        }
1042
1043        ctrl_qpn = (be32_to_cpu(ctrl->qpn_ds) & MLX5_WQE_CTRL_QPN_MASK) >>
1044                MLX5_WQE_CTRL_QPN_SHIFT;
1045        if (qpn != ctrl_qpn) {
1046                mlx5_ib_err(dev, "Got WQE with incorrect QP number. wqe_index=0x%x, qpn=0x%x ctrl->qpn=0x%x\n",
1047                            wqe_index, qpn,
1048                            ctrl_qpn);
1049                return -EFAULT;
1050        }
1051#endif /* DEBUG */
1052
1053        *wqe_end = *wqe + ds * MLX5_WQE_DS_UNITS;
1054        *wqe += sizeof(*ctrl);
1055
1056        opcode = be32_to_cpu(ctrl->opmod_idx_opcode) &
1057                 MLX5_WQE_CTRL_OPCODE_MASK;
1058
1059        switch (qp->ibqp.qp_type) {
1060        case IB_QPT_XRC_INI:
1061                *wqe += sizeof(struct mlx5_wqe_xrc_seg);
1062                transport_caps = dev->odp_caps.per_transport_caps.xrc_odp_caps;
1063                break;
1064        case IB_QPT_RC:
1065                transport_caps = dev->odp_caps.per_transport_caps.rc_odp_caps;
1066                break;
1067        case IB_QPT_UD:
1068                transport_caps = dev->odp_caps.per_transport_caps.ud_odp_caps;
1069                break;
1070        default:
1071                mlx5_ib_err(dev, "ODP fault on QP of an unsupported transport 0x%x\n",
1072                            qp->ibqp.qp_type);
1073                return -EFAULT;
1074        }
1075
1076        if (unlikely(opcode >= ARRAY_SIZE(mlx5_ib_odp_opcode_cap) ||
1077                     !(transport_caps & mlx5_ib_odp_opcode_cap[opcode]))) {
1078                mlx5_ib_err(dev, "ODP fault on QP of an unsupported opcode 0x%x\n",
1079                            opcode);
1080                return -EFAULT;
1081        }
1082
1083        if (qp->ibqp.qp_type == IB_QPT_UD) {
1084                av = *wqe;
1085                if (av->dqp_dct & cpu_to_be32(MLX5_EXTENDED_UD_AV))
1086                        *wqe += sizeof(struct mlx5_av);
1087                else
1088                        *wqe += sizeof(struct mlx5_base_av);
1089        }
1090
1091        switch (opcode) {
1092        case MLX5_OPCODE_RDMA_WRITE:
1093        case MLX5_OPCODE_RDMA_WRITE_IMM:
1094        case MLX5_OPCODE_RDMA_READ:
1095                *wqe += sizeof(struct mlx5_wqe_raddr_seg);
1096                break;
1097        case MLX5_OPCODE_ATOMIC_CS:
1098        case MLX5_OPCODE_ATOMIC_FA:
1099                *wqe += sizeof(struct mlx5_wqe_raddr_seg);
1100                *wqe += sizeof(struct mlx5_wqe_atomic_seg);
1101                break;
1102        }
1103
1104        return 0;
1105}
1106
1107/*
1108 * Parse responder WQE and set wqe_end to the end of the WQE.
1109 */
1110static int mlx5_ib_mr_responder_pfault_handler_srq(struct mlx5_ib_dev *dev,
1111                                                   struct mlx5_ib_srq *srq,
1112                                                   void **wqe, void **wqe_end,
1113                                                   int wqe_length)
1114{
1115        int wqe_size = 1 << srq->msrq.wqe_shift;
1116
1117        if (wqe_size > wqe_length) {
1118                mlx5_ib_err(dev, "Couldn't read all of the receive WQE's content\n");
1119                return -EFAULT;
1120        }
1121
1122        *wqe_end = *wqe + wqe_size;
1123        *wqe += sizeof(struct mlx5_wqe_srq_next_seg);
1124
1125        return 0;
1126}
1127
1128static int mlx5_ib_mr_responder_pfault_handler_rq(struct mlx5_ib_dev *dev,
1129                                                  struct mlx5_ib_qp *qp,
1130                                                  void *wqe, void **wqe_end,
1131                                                  int wqe_length)
1132{
1133        struct mlx5_ib_wq *wq = &qp->rq;
1134        int wqe_size = 1 << wq->wqe_shift;
1135
1136        if (qp->wq_sig) {
1137                mlx5_ib_err(dev, "ODP fault with WQE signatures is not supported\n");
1138                return -EFAULT;
1139        }
1140
1141        if (wqe_size > wqe_length) {
1142                mlx5_ib_err(dev, "Couldn't read all of the receive WQE's content\n");
1143                return -EFAULT;
1144        }
1145
1146        switch (qp->ibqp.qp_type) {
1147        case IB_QPT_RC:
1148                if (!(dev->odp_caps.per_transport_caps.rc_odp_caps &
1149                      IB_ODP_SUPPORT_RECV))
1150                        goto invalid_transport_or_opcode;
1151                break;
1152        default:
1153invalid_transport_or_opcode:
1154                mlx5_ib_err(dev, "ODP fault on QP of an unsupported transport. transport: 0x%x\n",
1155                            qp->ibqp.qp_type);
1156                return -EFAULT;
1157        }
1158
1159        *wqe_end = wqe + wqe_size;
1160
1161        return 0;
1162}
1163
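     /*
      * Look up the SRQ or QP that owns the faulting WQ and take a hold on it
      * so it cannot be destroyed while the fault is handled; the caller
      * releases it with mlx5_core_res_put().
      */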
1164static inline struct mlx5_core_rsc_common *odp_get_rsc(struct mlx5_ib_dev *dev,
1165                                                       u32 wq_num, int pf_type)
1166{
1167        struct mlx5_core_rsc_common *common = NULL;
1168        struct mlx5_core_srq *srq;
1169
1170        switch (pf_type) {
1171        case MLX5_WQE_PF_TYPE_RMP:
1172                srq = mlx5_cmd_get_srq(dev, wq_num);
1173                if (srq)
1174                        common = &srq->common;
1175                break;
1176        case MLX5_WQE_PF_TYPE_REQ_SEND_OR_WRITE:
1177        case MLX5_WQE_PF_TYPE_RESP:
1178        case MLX5_WQE_PF_TYPE_REQ_READ_OR_ATOMIC:
1179                common = mlx5_core_res_hold(dev->mdev, wq_num, MLX5_RES_QP);
1180                break;
1181        default:
1182                break;
1183        }
1184
1185        return common;
1186}
1187
1188static inline struct mlx5_ib_qp *res_to_qp(struct mlx5_core_rsc_common *res)
1189{
1190        struct mlx5_core_qp *mqp = (struct mlx5_core_qp *)res;
1191
1192        return to_mibqp(mqp);
1193}
1194
1195static inline struct mlx5_ib_srq *res_to_srq(struct mlx5_core_rsc_common *res)
1196{
1197        struct mlx5_core_srq *msrq =
1198                container_of(res, struct mlx5_core_srq, common);
1199
1200        return to_mibsrq(msrq);
1201}
1202
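     /*
      * Handle a WQE page fault: copy the faulting WQE from user memory, parse
      * it to find the data segments, fault in the pages they reference and
      * finally resume the WQ, flagging an error if anything went wrong.
      */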
1203static void mlx5_ib_mr_wqe_pfault_handler(struct mlx5_ib_dev *dev,
1204                                          struct mlx5_pagefault *pfault)
1205{
1206        bool sq = pfault->type & MLX5_PFAULT_REQUESTOR;
1207        u16 wqe_index = pfault->wqe.wqe_index;
1208        void *wqe = NULL, *wqe_end = NULL;
1209        u32 bytes_mapped, total_wqe_bytes;
1210        struct mlx5_core_rsc_common *res;
1211        int resume_with_error = 1;
1212        struct mlx5_ib_qp *qp;
1213        size_t bytes_copied;
1214        int ret = 0;
1215
1216        res = odp_get_rsc(dev, pfault->wqe.wq_num, pfault->type);
1217        if (!res) {
1218                mlx5_ib_dbg(dev, "wqe page fault for missing resource %d\n", pfault->wqe.wq_num);
1219                return;
1220        }
1221
1222        if (res->res != MLX5_RES_QP && res->res != MLX5_RES_SRQ &&
1223            res->res != MLX5_RES_XSRQ) {
1224                mlx5_ib_err(dev, "wqe page fault for unsupported type %d\n",
1225                            pfault->type);
1226                goto resolve_page_fault;
1227        }
1228
1229        wqe = (void *)__get_free_page(GFP_KERNEL);
1230        if (!wqe) {
1231                mlx5_ib_err(dev, "Error allocating memory for IO page fault handling.\n");
1232                goto resolve_page_fault;
1233        }
1234
1235        qp = (res->res == MLX5_RES_QP) ? res_to_qp(res) : NULL;
1236        if (qp && sq) {
1237                ret = mlx5_ib_read_user_wqe_sq(qp, wqe_index, wqe, PAGE_SIZE,
1238                                               &bytes_copied);
1239                if (ret)
1240                        goto read_user;
1241                ret = mlx5_ib_mr_initiator_pfault_handler(
1242                        dev, pfault, qp, &wqe, &wqe_end, bytes_copied);
1243        } else if (qp && !sq) {
1244                ret = mlx5_ib_read_user_wqe_rq(qp, wqe_index, wqe, PAGE_SIZE,
1245                                               &bytes_copied);
1246                if (ret)
1247                        goto read_user;
1248                ret = mlx5_ib_mr_responder_pfault_handler_rq(
1249                        dev, qp, wqe, &wqe_end, bytes_copied);
1250        } else if (!qp) {
1251                struct mlx5_ib_srq *srq = res_to_srq(res);
1252
1253                ret = mlx5_ib_read_user_wqe_srq(srq, wqe_index, wqe, PAGE_SIZE,
1254                                                &bytes_copied);
1255                if (ret)
1256                        goto read_user;
1257                ret = mlx5_ib_mr_responder_pfault_handler_srq(
1258                        dev, srq, &wqe, &wqe_end, bytes_copied);
1259        }
1260
1261        if (ret < 0 || wqe >= wqe_end)
1262                goto resolve_page_fault;
1263
1264        ret = pagefault_data_segments(dev, pfault, wqe, wqe_end, &bytes_mapped,
1265                                      &total_wqe_bytes, !sq);
1266        if (ret == -EAGAIN)
1267                goto out;
1268
1269        if (ret < 0 || total_wqe_bytes > bytes_mapped)
1270                goto resolve_page_fault;
1271
1272out:
1273        ret = 0;
1274        resume_with_error = 0;
1275
1276read_user:
1277        if (ret)
1278                mlx5_ib_err(
1279                        dev,
1280                        "Failed reading a WQE following page fault, error %d, wqe_index %x, qpn %x\n",
1281                        ret, wqe_index, pfault->token);
1282
1283resolve_page_fault:
1284        mlx5_ib_page_fault_resume(dev, pfault, resume_with_error);
1285        mlx5_ib_dbg(dev, "PAGE FAULT completed. QP 0x%x resume_with_error=%d, type: 0x%x\n",
1286                    pfault->wqe.wq_num, resume_with_error,
1287                    pfault->type);
1288        mlx5_core_res_put(res);
1289        free_page((unsigned long)wqe);
1290}
1291
1292static int pages_in_range(u64 address, u32 length)
1293{
1294        return (ALIGN(address + length, PAGE_SIZE) -
1295                (address & PAGE_MASK)) >> PAGE_SHIFT;
1296}
1297
1298static void mlx5_ib_mr_rdma_pfault_handler(struct mlx5_ib_dev *dev,
1299                                           struct mlx5_pagefault *pfault)
1300{
1301        u64 address;
1302        u32 length;
1303        u32 prefetch_len = pfault->bytes_committed;
1304        int prefetch_activated = 0;
1305        u32 rkey = pfault->rdma.r_key;
1306        int ret;
1307
1308        /* The RDMA responder handler handles the page fault in two parts.
1309         * First it brings the necessary pages for the current packet
1310         * (and uses the pfault context), and then (after resuming the QP)
1311         * prefetches more pages. The second operation cannot use the pfault
1312         * context and therefore uses the dummy_pfault context allocated on
1313         * the stack */
1314        pfault->rdma.rdma_va += pfault->bytes_committed;
1315        pfault->rdma.rdma_op_len -= min(pfault->bytes_committed,
1316                                         pfault->rdma.rdma_op_len);
1317        pfault->bytes_committed = 0;
1318
1319        address = pfault->rdma.rdma_va;
1320        length  = pfault->rdma.rdma_op_len;
1321
1322        /* For some operations, the hardware cannot tell the exact message
1323         * length, and in those cases it reports zero. Use prefetch
1324         * logic. */
1325        if (length == 0) {
1326                prefetch_activated = 1;
1327                length = pfault->rdma.packet_size;
1328                prefetch_len = min(MAX_PREFETCH_LEN, prefetch_len);
1329        }
1330
1331        ret = pagefault_single_data_segment(dev, NULL, rkey, address, length,
1332                                            &pfault->bytes_committed, NULL,
1333                                            0);
1334        if (ret == -EAGAIN) {
1335                /* We're racing with an invalidation, don't prefetch */
1336                prefetch_activated = 0;
1337        } else if (ret < 0 || pages_in_range(address, length) > ret) {
1338                mlx5_ib_page_fault_resume(dev, pfault, 1);
1339                if (ret != -ENOENT)
1340                        mlx5_ib_dbg(dev, "PAGE FAULT error %d. QP 0x%x, type: 0x%x\n",
1341                                    ret, pfault->token, pfault->type);
1342                return;
1343        }
1344
1345        mlx5_ib_page_fault_resume(dev, pfault, 0);
1346        mlx5_ib_dbg(dev, "PAGE FAULT completed. QP 0x%x, type: 0x%x, prefetch_activated: %d\n",
1347                    pfault->token, pfault->type,
1348                    prefetch_activated);
1349
 1350        /* At this point, there might be a new pagefault already arriving in
 1351         * the eq, so switch to the dummy pagefault for the rest of the
1352         * processing. We're still OK with the objects being alive as the
1353         * work-queue is being fenced. */
1354
1355        if (prefetch_activated) {
1356                u32 bytes_committed = 0;
1357
1358                ret = pagefault_single_data_segment(dev, NULL, rkey, address,
1359                                                    prefetch_len,
1360                                                    &bytes_committed, NULL,
1361                                                    0);
1362                if (ret < 0 && ret != -EAGAIN) {
1363                        mlx5_ib_dbg(dev, "Prefetch failed. ret: %d, QP 0x%x, address: 0x%.16llx, length = 0x%.16x\n",
1364                                    ret, pfault->token, address, prefetch_len);
1365                }
1366        }
1367}
1368
1369static void mlx5_ib_pfault(struct mlx5_ib_dev *dev, struct mlx5_pagefault *pfault)
1370{
1371        u8 event_subtype = pfault->event_subtype;
1372
1373        switch (event_subtype) {
1374        case MLX5_PFAULT_SUBTYPE_WQE:
1375                mlx5_ib_mr_wqe_pfault_handler(dev, pfault);
1376                break;
1377        case MLX5_PFAULT_SUBTYPE_RDMA:
1378                mlx5_ib_mr_rdma_pfault_handler(dev, pfault);
1379                break;
1380        default:
1381                mlx5_ib_err(dev, "Invalid page fault event subtype: 0x%x\n",
1382                            event_subtype);
1383                mlx5_ib_page_fault_resume(dev, pfault, 1);
1384        }
1385}
1386
1387static void mlx5_ib_eqe_pf_action(struct work_struct *work)
1388{
1389        struct mlx5_pagefault *pfault = container_of(work,
1390                                                     struct mlx5_pagefault,
1391                                                     work);
1392        struct mlx5_ib_pf_eq *eq = pfault->eq;
1393
1394        mlx5_ib_pfault(eq->dev, pfault);
1395        mempool_free(pfault, eq->pool);
1396}
1397
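     /*
      * Drain the page-fault EQ: decode each EQE into a struct mlx5_pagefault
      * taken from the EQ's mempool and hand it to mlx5_ib_eqe_pf_action() for
      * deferred processing; if the mempool is empty, re-queue the EQ work and
      * retry later.
      */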
1398static void mlx5_ib_eq_pf_process(struct mlx5_ib_pf_eq *eq)
1399{
1400        struct mlx5_eqe_page_fault *pf_eqe;
1401        struct mlx5_pagefault *pfault;
1402        struct mlx5_eqe *eqe;
1403        int cc = 0;
1404
1405        while ((eqe = mlx5_eq_get_eqe(eq->core, cc))) {
1406                pfault = mempool_alloc(eq->pool, GFP_ATOMIC);
1407                if (!pfault) {
1408                        schedule_work(&eq->work);
1409                        break;
1410                }
1411
1412                pf_eqe = &eqe->data.page_fault;
1413                pfault->event_subtype = eqe->sub_type;
1414                pfault->bytes_committed = be32_to_cpu(pf_eqe->bytes_committed);
1415
1416                mlx5_ib_dbg(eq->dev,
1417                            "PAGE_FAULT: subtype: 0x%02x, bytes_committed: 0x%06x\n",
1418                            eqe->sub_type, pfault->bytes_committed);
1419
1420                switch (eqe->sub_type) {
1421                case MLX5_PFAULT_SUBTYPE_RDMA:
1422                        /* RDMA based event */
1423                        pfault->type =
1424                                be32_to_cpu(pf_eqe->rdma.pftype_token) >> 24;
1425                        pfault->token =
1426                                be32_to_cpu(pf_eqe->rdma.pftype_token) &
1427                                MLX5_24BIT_MASK;
1428                        pfault->rdma.r_key =
1429                                be32_to_cpu(pf_eqe->rdma.r_key);
1430                        pfault->rdma.packet_size =
1431                                be16_to_cpu(pf_eqe->rdma.packet_length);
1432                        pfault->rdma.rdma_op_len =
1433                                be32_to_cpu(pf_eqe->rdma.rdma_op_len);
1434                        pfault->rdma.rdma_va =
1435                                be64_to_cpu(pf_eqe->rdma.rdma_va);
1436                        mlx5_ib_dbg(eq->dev,
1437                                    "PAGE_FAULT: type:0x%x, token: 0x%06x, r_key: 0x%08x\n",
1438                                    pfault->type, pfault->token,
1439                                    pfault->rdma.r_key);
1440                        mlx5_ib_dbg(eq->dev,
1441                                    "PAGE_FAULT: rdma_op_len: 0x%08x, rdma_va: 0x%016llx\n",
1442                                    pfault->rdma.rdma_op_len,
1443                                    pfault->rdma.rdma_va);
1444                        break;
1445
1446                case MLX5_PFAULT_SUBTYPE_WQE:
1447                        /* WQE based event */
1448                        pfault->type =
1449                                (be32_to_cpu(pf_eqe->wqe.pftype_wq) >> 24) & 0x7;
1450                        pfault->token =
1451                                be32_to_cpu(pf_eqe->wqe.token);
1452                        pfault->wqe.wq_num =
1453                                be32_to_cpu(pf_eqe->wqe.pftype_wq) &
1454                                MLX5_24BIT_MASK;
1455                        pfault->wqe.wqe_index =
1456                                be16_to_cpu(pf_eqe->wqe.wqe_index);
1457                        pfault->wqe.packet_size =
1458                                be16_to_cpu(pf_eqe->wqe.packet_length);
1459                        mlx5_ib_dbg(eq->dev,
1460                                    "PAGE_FAULT: type:0x%x, token: 0x%06x, wq_num: 0x%06x, wqe_index: 0x%04x\n",
1461                                    pfault->type, pfault->token,
1462                                    pfault->wqe.wq_num,
1463                                    pfault->wqe.wqe_index);
1464                        break;
1465
1466                default:
1467                        mlx5_ib_warn(eq->dev,
1468                                     "Unsupported page fault event sub-type: 0x%02hhx\n",
1469                                     eqe->sub_type);
1470                        /* Unsupported page faults should still be
1471                         * resolved by the page fault handler
1472                         */
1473                }
1474
1475                pfault->eq = eq;
1476                INIT_WORK(&pfault->work, mlx5_ib_eqe_pf_action);
1477                queue_work(eq->wq, &pfault->work);
1478
1479                cc = mlx5_eq_update_cc(eq->core, ++cc);
1480        }
1481
1482        mlx5_eq_update_ci(eq->core, cc, 1);
1483}
1484
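    /*
     * EQ interrupt notifier: try to drain the EQ directly; if another
     * context already holds the lock, defer to the work item instead of
     * spinning here.
     */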
1485static int mlx5_ib_eq_pf_int(struct notifier_block *nb, unsigned long type,
1486                             void *data)
1487{
1488        struct mlx5_ib_pf_eq *eq =
1489                container_of(nb, struct mlx5_ib_pf_eq, irq_nb);
1490        unsigned long flags;
1491
1492        if (spin_trylock_irqsave(&eq->lock, flags)) {
1493                mlx5_ib_eq_pf_process(eq);
1494                spin_unlock_irqrestore(&eq->lock, flags);
1495        } else {
1496                schedule_work(&eq->work);
1497        }
1498
1499        return IRQ_HANDLED;
1500}
1501
1502/* A mempool_refill() API was proposed upstream but wasn't accepted:
1503 * http://lkml.iu.edu/hypermail/linux/kernel/1512.1/05073.html
1504 * This is a cheap local workaround.
1505 */
1506static void mempool_refill(mempool_t *pool)
1507{
1508        while (pool->curr_nr < pool->min_nr)
1509                mempool_free(mempool_alloc(pool, GFP_KERNEL), pool);
1510}
1511
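    /*
     * Deferred EQ processing: refill the pagefault mempool, then drain the
     * EQ under the lock.  Scheduled when the interrupt notifier lost the
     * trylock or when mlx5_ib_eq_pf_process() ran out of mempool entries.
     */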
1512static void mlx5_ib_eq_pf_action(struct work_struct *work)
1513{
1514        struct mlx5_ib_pf_eq *eq =
1515                container_of(work, struct mlx5_ib_pf_eq, work);
1516
1517        mempool_refill(eq->pool);
1518
1519        spin_lock_irq(&eq->lock);
1520        mlx5_ib_eq_pf_process(eq);
1521        spin_unlock_irq(&eq->lock);
1522}
1523
1524enum {
1525        MLX5_IB_NUM_PF_EQE      = 0x1000,
1526        MLX5_IB_NUM_PF_DRAIN    = 64,
1527};
1528
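    /*
     * Create the ODP page-fault EQ: a mempool of pagefault descriptors, a
     * high-priority workqueue for fault resolution, and a generic EQ that
     * listens only for MLX5_EVENT_TYPE_PAGE_FAULT, enabled with the
     * interrupt notifier above.
     */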
1529static int
1530mlx5_ib_create_pf_eq(struct mlx5_ib_dev *dev, struct mlx5_ib_pf_eq *eq)
1531{
1532        struct mlx5_eq_param param = {};
1533        int err;
1534
1535        INIT_WORK(&eq->work, mlx5_ib_eq_pf_action);
1536        spin_lock_init(&eq->lock);
1537        eq->dev = dev;
1538
1539        eq->pool = mempool_create_kmalloc_pool(MLX5_IB_NUM_PF_DRAIN,
1540                                               sizeof(struct mlx5_pagefault));
1541        if (!eq->pool)
1542                return -ENOMEM;
1543
1544        eq->wq = alloc_workqueue("mlx5_ib_page_fault",
1545                                 WQ_HIGHPRI | WQ_UNBOUND | WQ_MEM_RECLAIM,
1546                                 MLX5_NUM_CMD_EQE);
1547        if (!eq->wq) {
1548                err = -ENOMEM;
1549                goto err_mempool;
1550        }
1551
1552        eq->irq_nb.notifier_call = mlx5_ib_eq_pf_int;
1553        param = (struct mlx5_eq_param) {
1554                .irq_index = 0,
1555                .nent = MLX5_IB_NUM_PF_EQE,
1556        };
1557        param.mask[0] = 1ull << MLX5_EVENT_TYPE_PAGE_FAULT;
1558        eq->core = mlx5_eq_create_generic(dev->mdev, &param);
1559        if (IS_ERR(eq->core)) {
1560                err = PTR_ERR(eq->core);
1561                goto err_wq;
1562        }
1563        err = mlx5_eq_enable(dev->mdev, eq->core, &eq->irq_nb);
1564        if (err) {
1565                mlx5_ib_err(dev, "failed to enable odp EQ %d\n", err);
1566                goto err_eq;
1567        }
1568
1569        return 0;
1570err_eq:
1571        mlx5_eq_destroy_generic(dev->mdev, eq->core);
1572err_wq:
1573        destroy_workqueue(eq->wq);
1574err_mempool:
1575        mempool_destroy(eq->pool);
1576        return err;
1577}
1578
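    /*
     * Tear down the page-fault EQ in reverse order: disable and destroy
     * the EQ so no new faults arrive, flush the deferred-drain work, then
     * destroy the fault workqueue and the mempool.
     */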
1579static int
1580mlx5_ib_destroy_pf_eq(struct mlx5_ib_dev *dev, struct mlx5_ib_pf_eq *eq)
1581{
1582        int err;
1583
1584        mlx5_eq_disable(dev->mdev, eq->core, &eq->irq_nb);
1585        err = mlx5_eq_destroy_generic(dev->mdev, eq->core);
1586        cancel_work_sync(&eq->work);
1587        destroy_workqueue(eq->wq);
1588        mempool_destroy(eq->pool);
1589
1590        return err;
1591}
1592
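    /*
     * Configure the MR cache entries used by implicit ODP: MTT-based child
     * MRs and the KSM-based parent.  Both are given a limit of 0, i.e. the
     * cache does not pre-fill them.
     */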
1593void mlx5_odp_init_mr_cache_entry(struct mlx5_cache_ent *ent)
1594{
1595        if (!(ent->dev->odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
1596                return;
1597
1598        switch (ent->order - 2) {
1599        case MLX5_IMR_MTT_CACHE_ENTRY:
1600                ent->page = PAGE_SHIFT;
1601                ent->xlt = MLX5_IMR_MTT_ENTRIES *
1602                           sizeof(struct mlx5_mtt) /
1603                           MLX5_IB_UMR_OCTOWORD;
1604                ent->access_mode = MLX5_MKC_ACCESS_MODE_MTT;
1605                ent->limit = 0;
1606                break;
1607
1608        case MLX5_IMR_KSM_CACHE_ENTRY:
1609                ent->page = MLX5_KSM_PAGE_SHIFT;
1610                ent->xlt = mlx5_imr_ksm_entries *
1611                           sizeof(struct mlx5_klm) /
1612                           MLX5_IB_UMR_OCTOWORD;
1613                ent->access_mode = MLX5_MKC_ACCESS_MODE_KSM;
1614                ent->limit = 0;
1615                break;
1616        }
1617}
1618
1619static const struct ib_device_ops mlx5_ib_dev_odp_ops = {
1620        .advise_mr = mlx5_ib_advise_mr,
1621};
1622
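    /*
     * Per-device ODP setup: register the advise_mr op, fetch the null mkey
     * used by implicit ODP (when supported), and create the page-fault EQ.
     */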
1623int mlx5_ib_odp_init_one(struct mlx5_ib_dev *dev)
1624{
1625        int ret = 0;
1626
1627        if (!(dev->odp_caps.general_caps & IB_ODP_SUPPORT))
1628                return ret;
1629
1630        ib_set_device_ops(&dev->ib_dev, &mlx5_ib_dev_odp_ops);
1631
1632        if (dev->odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT) {
1633                ret = mlx5_cmd_null_mkey(dev->mdev, &dev->null_mkey);
1634                if (ret) {
1635                        mlx5_ib_err(dev, "Error getting null_mkey %d\n", ret);
1636                        return ret;
1637                }
1638        }
1639
1640        ret = mlx5_ib_create_pf_eq(dev, &dev->odp_pf_eq);
1641
1642        return ret;
1643}
1644
1645void mlx5_ib_odp_cleanup_one(struct mlx5_ib_dev *dev)
1646{
1647        if (!(dev->odp_caps.general_caps & IB_ODP_SUPPORT))
1648                return;
1649
1650        mlx5_ib_destroy_pf_eq(dev, &dev->odp_pf_eq);
1651}
1652
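    /*
     * Module-level ODP init: size the implicit-MR KSM table so that its
     * MTT-sized children cover the whole process address space (TASK_SIZE).
     */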
1653int mlx5_ib_odp_init(void)
1654{
1655        mlx5_imr_ksm_entries = BIT_ULL(get_order(TASK_SIZE) -
1656                                       MLX5_IMR_MTT_BITS);
1657
1658        return 0;
1659}
1660
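    /*
     * Context for a prefetch request executed asynchronously on
     * system_unbound_wq; sg_list is a flexible array sized with
     * struct_size() at allocation time.
     */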
1661struct prefetch_mr_work {
1662        struct work_struct work;
1663        struct ib_pd *pd;
1664        u32 pf_flags;
1665        u32 num_sge;
1666        struct ib_sge sg_list[];
1667};
1668
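    /*
     * Drop the num_pending_prefetch reference taken on the MR behind each
     * SGE in sg_list[from..num_sge), re-resolving the mkeys under the MR
     * SRCU read lock.
     */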
1669static void num_pending_prefetch_dec(struct mlx5_ib_dev *dev,
1670                                     struct ib_sge *sg_list, u32 num_sge,
1671                                     u32 from)
1672{
1673        u32 i;
1674        int srcu_key;
1675
1676        srcu_key = srcu_read_lock(&dev->mr_srcu);
1677
1678        for (i = from; i < num_sge; ++i) {
1679                struct mlx5_core_mkey *mmkey;
1680                struct mlx5_ib_mr *mr;
1681
1682                mmkey = xa_load(&dev->mdev->priv.mkey_table,
1683                                mlx5_base_mkey(sg_list[i].lkey));
1684                mr = container_of(mmkey, struct mlx5_ib_mr, mmkey);
1685                atomic_dec(&mr->num_pending_prefetch);
1686        }
1687
1688        srcu_read_unlock(&dev->mr_srcu, srcu_key);
1689}
1690
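    /*
     * Validate an SGE list for prefetch and take a num_pending_prefetch
     * reference on every MR: each lkey must resolve to a live MR belonging
     * to @pd.  On failure, drop the references taken so far and return
     * false.  Called under the MR SRCU read lock.
     */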
1691static bool num_pending_prefetch_inc(struct ib_pd *pd,
1692                                     struct ib_sge *sg_list, u32 num_sge)
1693{
1694        struct mlx5_ib_dev *dev = to_mdev(pd->device);
1695        bool ret = true;
1696        u32 i;
1697
1698        for (i = 0; i < num_sge; ++i) {
1699                struct mlx5_core_mkey *mmkey;
1700                struct mlx5_ib_mr *mr;
1701
1702                mmkey = xa_load(&dev->mdev->priv.mkey_table,
1703                                mlx5_base_mkey(sg_list[i].lkey));
1704                if (!mmkey || mmkey->key != sg_list[i].lkey) {
1705                        ret = false;
1706                        break;
1707                }
1708
1709                if (mmkey->type != MLX5_MKEY_MR) {
1710                        ret = false;
1711                        break;
1712                }
1713
1714                mr = container_of(mmkey, struct mlx5_ib_mr, mmkey);
1715
1716                if (mr->ibmr.pd != pd) {
1717                        ret = false;
1718                        break;
1719                }
1720
1721                if (!mr->live) {
1722                        ret = false;
1723                        break;
1724                }
1725
1726                atomic_inc(&mr->num_pending_prefetch);
1727        }
1728
1729        if (!ret)
1730                num_pending_prefetch_dec(dev, sg_list, i, 0);
1731
1732        return ret;
1733}
1734
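    /*
     * Synchronously fault in each SGE by reusing the page-fault resolution
     * path; stop at the first error, return 0 on success.
     */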
1735static int mlx5_ib_prefetch_sg_list(struct ib_pd *pd, u32 pf_flags,
1736                                    struct ib_sge *sg_list, u32 num_sge)
1737{
1738        u32 i;
1739        int ret = 0;
1740        struct mlx5_ib_dev *dev = to_mdev(pd->device);
1741
1742        for (i = 0; i < num_sge; ++i) {
1743                struct ib_sge *sg = &sg_list[i];
1744                int bytes_committed = 0;
1745
1746                ret = pagefault_single_data_segment(dev, pd, sg->lkey, sg->addr,
1747                                                    sg->length,
1748                                                    &bytes_committed, NULL,
1749                                                    pf_flags);
1750                if (ret < 0)
1751                        break;
1752        }
1753
1754        return ret < 0 ? ret : 0;
1755}
1756
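    /*
     * Deferred prefetch work: only touch the device if it is still
     * registered, then drop the pending-prefetch references and free the
     * request.
     */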
1757static void mlx5_ib_prefetch_mr_work(struct work_struct *work)
1758{
1759        struct prefetch_mr_work *w =
1760                container_of(work, struct prefetch_mr_work, work);
1761
1762        if (ib_device_try_get(w->pd->device)) {
1763                mlx5_ib_prefetch_sg_list(w->pd, w->pf_flags, w->sg_list,
1764                                         w->num_sge);
1765                ib_device_put(w->pd->device);
1766        }
1767
1768        num_pending_prefetch_dec(to_mdev(w->pd->device), w->sg_list,
1769                                 w->num_sge, 0);
1770        kvfree(w);
1771}
1772
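    /*
     * Entry point for ib_advise_mr() prefetch.  FLUSH requests are served
     * synchronously; otherwise the SGE list is validated (taking a
     * pending-prefetch reference on each MR under SRCU) and the prefetch
     * itself is queued on system_unbound_wq.
     */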
1773int mlx5_ib_advise_mr_prefetch(struct ib_pd *pd,
1774                               enum ib_uverbs_advise_mr_advice advice,
1775                               u32 flags, struct ib_sge *sg_list, u32 num_sge)
1776{
1777        struct mlx5_ib_dev *dev = to_mdev(pd->device);
1778        u32 pf_flags = MLX5_PF_FLAGS_PREFETCH;
1779        struct prefetch_mr_work *work;
1780        bool valid_req;
1781        int srcu_key;
1782
1783        if (advice == IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH)
1784                pf_flags |= MLX5_PF_FLAGS_DOWNGRADE;
1785
1786        if (flags & IB_UVERBS_ADVISE_MR_FLAG_FLUSH)
1787                return mlx5_ib_prefetch_sg_list(pd, pf_flags, sg_list,
1788                                                num_sge);
1789
1790        work = kvzalloc(struct_size(work, sg_list, num_sge), GFP_KERNEL);
1791        if (!work)
1792                return -ENOMEM;
1793
1794        memcpy(work->sg_list, sg_list, num_sge * sizeof(struct ib_sge));
1795
1796        /* It is guaranteed that the pd seen when the work executes is the
1797         * pd seen when the work was queued: a pd can't be destroyed while
1798         * it holds MRs, and destroying an MR flushes the workqueue.
1799         */
1800        work->pd = pd;
1801        work->pf_flags = pf_flags;
1802        work->num_sge = num_sge;
1803
1804        INIT_WORK(&work->work, mlx5_ib_prefetch_mr_work);
1805
1806        srcu_key = srcu_read_lock(&dev->mr_srcu);
1807
1808        valid_req = num_pending_prefetch_inc(pd, sg_list, num_sge);
1809        if (valid_req)
1810                queue_work(system_unbound_wq, &work->work);
1811        else
1812                kvfree(work);
1813
1814        srcu_read_unlock(&dev->mr_srcu, srcu_key);
1815
1816        return valid_req ? 0 : -EINVAL;
1817}
1818