linux/drivers/infiniband/sw/rdmavt/mr.c
/*
 * Copyright(c) 2016 Intel Corporation.
 *
 * This file is provided under a dual BSD/GPLv2 license.  When using or
 * redistributing this file, you may do so under either license.
 *
 * GPL LICENSE SUMMARY
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * BSD LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *  - Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  - Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in
 *    the documentation and/or other materials provided with the
 *    distribution.
 *  - Neither the name of Intel Corporation nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 */

#include <linux/slab.h>
#include <linux/vmalloc.h>
#include <rdma/ib_umem.h>
#include <rdma/rdma_vt.h>
#include "vt.h"
#include "mr.h"
#include "trace.h"

/**
 * rvt_driver_mr_init - Init MR resources per driver
 * @rdi: rvt dev struct
 *
 * Do any initialization needed when a driver registers with rdmavt.
 *
 * Return: 0 on success or errno on failure
 */
int rvt_driver_mr_init(struct rvt_dev_info *rdi)
{
        unsigned int lkey_table_size = rdi->dparms.lkey_table_size;
        unsigned int lk_tab_size;
        int i;

        /*
         * The top lkey_table_size bits are used to index the
         * table.  The lower 8 bits can be owned by the user (copied from
         * the LKEY).  The remaining bits act as a generation number or tag.
         */
        if (!lkey_table_size)
                return -EINVAL;

        spin_lock_init(&rdi->lkey_table.lock);

        /* ensure generation is at least 4 bits */
        if (lkey_table_size > RVT_MAX_LKEY_TABLE_BITS) {
                rvt_pr_warn(rdi, "lkey bits %u too large, reduced to %u\n",
                            lkey_table_size, RVT_MAX_LKEY_TABLE_BITS);
                rdi->dparms.lkey_table_size = RVT_MAX_LKEY_TABLE_BITS;
                lkey_table_size = rdi->dparms.lkey_table_size;
        }
        rdi->lkey_table.max = 1 << lkey_table_size;
        rdi->lkey_table.shift = 32 - lkey_table_size;
        lk_tab_size = rdi->lkey_table.max * sizeof(*rdi->lkey_table.table);
        rdi->lkey_table.table = (struct rvt_mregion __rcu **)
                               vmalloc_node(lk_tab_size, rdi->dparms.node);
        if (!rdi->lkey_table.table)
                return -ENOMEM;

        RCU_INIT_POINTER(rdi->dma_mr, NULL);
        for (i = 0; i < rdi->lkey_table.max; i++)
                RCU_INIT_POINTER(rdi->lkey_table.table[i], NULL);

        rdi->dparms.props.max_mr = rdi->lkey_table.max;
        return 0;
}

/**
 * rvt_mr_exit - clean up MR
 * @rdi: rvt dev structure
 *
 * called when drivers have unregistered or perhaps failed to register with us
 */
void rvt_mr_exit(struct rvt_dev_info *rdi)
{
        if (rdi->dma_mr)
                rvt_pr_err(rdi, "DMA MR not null!\n");

        vfree(rdi->lkey_table.table);
}

static void rvt_deinit_mregion(struct rvt_mregion *mr)
{
        int i = mr->mapsz;

        mr->mapsz = 0;
        while (i)
                kfree(mr->map[--i]);
        percpu_ref_exit(&mr->refcount);
}

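/*
 * percpu_ref release callback: runs once the last reference to the mregion
 * has been dropped and signals mr->comp so that rvt_check_refs() (and thus
 * deregistration) can make progress.
 */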
static void __rvt_mregion_complete(struct percpu_ref *ref)
{
        struct rvt_mregion *mr = container_of(ref, struct rvt_mregion,
                                              refcount);

        complete(&mr->comp);
}

static int rvt_init_mregion(struct rvt_mregion *mr, struct ib_pd *pd,
                            int count, unsigned int percpu_flags)
{
        int m, i = 0;
        struct rvt_dev_info *dev = ib_to_rvt(pd->device);

        mr->mapsz = 0;
        m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
        for (; i < m; i++) {
                mr->map[i] = kzalloc_node(sizeof(*mr->map[0]), GFP_KERNEL,
                                          dev->dparms.node);
                if (!mr->map[i])
                        goto bail;
                mr->mapsz++;
        }
        init_completion(&mr->comp);
        /* count the reference held by the pointer returned to the user */
        if (percpu_ref_init(&mr->refcount, &__rvt_mregion_complete,
                            percpu_flags, GFP_KERNEL))
                goto bail;

        atomic_set(&mr->lkey_invalid, 0);
        mr->pd = pd;
        mr->max_segs = count;
        return 0;
bail:
        rvt_deinit_mregion(mr);
        return -ENOMEM;
}

/**
 * rvt_alloc_lkey - allocate an lkey
 * @mr: memory region that this lkey protects
 * @dma_region: 0->normal key, 1->restricted DMA key
 *
 * Returns 0 if successful, otherwise returns -errno.
 *
 * Increments mr reference count as required.
 *
 * Sets the lkey field of mr for non-dma regions.
 */
static int rvt_alloc_lkey(struct rvt_mregion *mr, int dma_region)
{
        unsigned long flags;
        u32 r;
        u32 n;
        int ret = 0;
        struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
        struct rvt_lkey_table *rkt = &dev->lkey_table;

        rvt_get_mr(mr);
        spin_lock_irqsave(&rkt->lock, flags);

        /* special case for dma_mr lkey == 0 */
        if (dma_region) {
                struct rvt_mregion *tmr;

                tmr = rcu_access_pointer(dev->dma_mr);
                if (!tmr) {
                        mr->lkey_published = 1;
                        /* Ensure published is written first */
                        rcu_assign_pointer(dev->dma_mr, mr);
                        rvt_get_mr(mr);
                }
                goto success;
        }

        /* Find the next available LKEY */
        r = rkt->next;
        n = r;
        for (;;) {
                if (!rcu_access_pointer(rkt->table[r]))
                        break;
                r = (r + 1) & (rkt->max - 1);
                if (r == n)
                        goto bail;
        }
        rkt->next = (r + 1) & (rkt->max - 1);
        /*
         * Make sure lkey is never zero, which is reserved to indicate an
         * unrestricted LKEY.
         */
        rkt->gen++;
        /*
         * bits are capped to ensure enough bits for generation number
         */
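        /*
         * For example, with lkey_table_size == 16 the resulting key has the
         * table index r in bits 31..16, the low 8 bits of the generation
         * counter in bits 15..8, and bits 7..0 left clear for the user
         * portion of the key.
         */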
        mr->lkey = (r << (32 - dev->dparms.lkey_table_size)) |
                ((((1 << (24 - dev->dparms.lkey_table_size)) - 1) & rkt->gen)
                 << 8);
        if (mr->lkey == 0) {
                mr->lkey |= 1 << 8;
                rkt->gen++;
        }
        mr->lkey_published = 1;
        /* Ensure published is written first */
        rcu_assign_pointer(rkt->table[r], mr);
success:
        spin_unlock_irqrestore(&rkt->lock, flags);
out:
        return ret;
bail:
        rvt_put_mr(mr);
        spin_unlock_irqrestore(&rkt->lock, flags);
        ret = -ENOMEM;
        goto out;
}

/**
 * rvt_free_lkey - free an lkey
 * @mr: mr to free from tables
 */
static void rvt_free_lkey(struct rvt_mregion *mr)
{
        unsigned long flags;
        u32 lkey = mr->lkey;
        u32 r;
        struct rvt_dev_info *dev = ib_to_rvt(mr->pd->device);
        struct rvt_lkey_table *rkt = &dev->lkey_table;
        int freed = 0;

        spin_lock_irqsave(&rkt->lock, flags);
        if (!lkey) {
                if (mr->lkey_published) {
                        mr->lkey_published = 0;
                        /* ensure published is written before pointer */
                        rcu_assign_pointer(dev->dma_mr, NULL);
                        rvt_put_mr(mr);
                }
        } else {
                if (!mr->lkey_published)
                        goto out;
                r = lkey >> (32 - dev->dparms.lkey_table_size);
                mr->lkey_published = 0;
                /* ensure published is written before pointer */
                rcu_assign_pointer(rkt->table[r], NULL);
        }
        freed++;
out:
        spin_unlock_irqrestore(&rkt->lock, flags);
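        /*
         * Killing the percpu reference drops its initial count; once any
         * remaining holders call rvt_put_mr(), __rvt_mregion_complete()
         * fires and wakes anyone waiting on mr->comp.
         */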
        if (freed)
                percpu_ref_kill(&mr->refcount);
}

static struct rvt_mr *__rvt_alloc_mr(int count, struct ib_pd *pd)
{
        struct rvt_mr *mr;
        int rval = -ENOMEM;
        int m;

        /* Allocate struct plus pointers to first level page tables. */
        m = (count + RVT_SEGSZ - 1) / RVT_SEGSZ;
        mr = kzalloc(struct_size(mr, mr.map, m), GFP_KERNEL);
        if (!mr)
                goto bail;

        rval = rvt_init_mregion(&mr->mr, pd, count, 0);
        if (rval)
                goto bail;
        /*
         * The caller will initialize mr->ibmr except for
         * lkey and rkey.
         */
        rval = rvt_alloc_lkey(&mr->mr, 0);
        if (rval)
                goto bail_mregion;
        mr->ibmr.lkey = mr->mr.lkey;
        mr->ibmr.rkey = mr->mr.lkey;
done:
        return mr;

bail_mregion:
        rvt_deinit_mregion(&mr->mr);
bail:
        kfree(mr);
        mr = ERR_PTR(rval);
        goto done;
}

static void __rvt_free_mr(struct rvt_mr *mr)
{
        rvt_free_lkey(&mr->mr);
        rvt_deinit_mregion(&mr->mr);
        kfree(mr);
}

/**
 * rvt_get_dma_mr - get a DMA memory region
 * @pd: protection domain for this memory region
 * @acc: access flags
 *
 * Return: the memory region on success, otherwise returns an errno.
 */
struct ib_mr *rvt_get_dma_mr(struct ib_pd *pd, int acc)
{
        struct rvt_mr *mr;
        struct ib_mr *ret;
        int rval;

        if (ibpd_to_rvtpd(pd)->user)
                return ERR_PTR(-EPERM);

        mr = kzalloc(sizeof(*mr), GFP_KERNEL);
        if (!mr) {
                ret = ERR_PTR(-ENOMEM);
                goto bail;
        }

        rval = rvt_init_mregion(&mr->mr, pd, 0, 0);
        if (rval) {
                ret = ERR_PTR(rval);
                goto bail;
        }

        rval = rvt_alloc_lkey(&mr->mr, 1);
        if (rval) {
                ret = ERR_PTR(rval);
                goto bail_mregion;
        }

        mr->mr.access_flags = acc;
        ret = &mr->ibmr;
done:
        return ret;

bail_mregion:
        rvt_deinit_mregion(&mr->mr);
bail:
        kfree(mr);
        goto done;
}

/**
 * rvt_reg_user_mr - register a userspace memory region
 * @pd: protection domain for this memory region
 * @start: starting userspace address
 * @length: length of region to register
 * @virt_addr: associated virtual address
 * @mr_access_flags: access flags for this memory region
 * @udata: unused by the driver
 *
 * Return: the memory region on success, otherwise returns an errno.
 */
struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
                              u64 virt_addr, int mr_access_flags,
                              struct ib_udata *udata)
{
        struct rvt_mr *mr;
        struct ib_umem *umem;
        struct sg_page_iter sg_iter;
        int n, m;
        struct ib_mr *ret;

        if (length == 0)
                return ERR_PTR(-EINVAL);

        umem = ib_umem_get(pd->device, start, length, mr_access_flags);
        if (IS_ERR(umem))
                return (void *)umem;

        n = ib_umem_num_pages(umem);

        mr = __rvt_alloc_mr(n, pd);
        if (IS_ERR(mr)) {
                ret = (struct ib_mr *)mr;
                goto bail_umem;
        }

        mr->mr.user_base = start;
        mr->mr.iova = virt_addr;
        mr->mr.length = length;
        mr->mr.offset = ib_umem_offset(umem);
        mr->mr.access_flags = mr_access_flags;
        mr->umem = umem;

        mr->mr.page_shift = PAGE_SHIFT;
        m = 0;
        n = 0;
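        /*
         * Walk the pinned user pages and record each page's kernel
         * virtual address and length in the two-level map[m]->segs[n]
         * table backing this MR.
         */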
        for_each_sg_page(umem->sg_head.sgl, &sg_iter, umem->nmap, 0) {
                void *vaddr;

                vaddr = page_address(sg_page_iter_page(&sg_iter));
                if (!vaddr) {
                        ret = ERR_PTR(-EINVAL);
                        goto bail_inval;
                }
                mr->mr.map[m]->segs[n].vaddr = vaddr;
                mr->mr.map[m]->segs[n].length = PAGE_SIZE;
                trace_rvt_mr_user_seg(&mr->mr, m, n, vaddr, PAGE_SIZE);
                if (++n == RVT_SEGSZ) {
                        m++;
                        n = 0;
                }
        }
        return &mr->ibmr;

bail_inval:
        __rvt_free_mr(mr);

bail_umem:
        ib_umem_release(umem);

        return ret;
}

/**
 * rvt_dereg_clean_qp_cb - callback from iterator
 * @qp: the qp
 * @v: the mregion (as u64)
 *
 * This routine fields the callback for all QPs and, for QPs in the
 * same PD as the MR, calls rvt_qp_mr_clean() to potentially clean up
 * references.
 */
static void rvt_dereg_clean_qp_cb(struct rvt_qp *qp, u64 v)
{
        struct rvt_mregion *mr = (struct rvt_mregion *)v;

        /* skip PDs that are not ours */
        if (mr->pd != qp->ibqp.pd)
                return;
        rvt_qp_mr_clean(qp, mr->lkey);
}

/**
 * rvt_dereg_clean_qps - find QPs for reference cleanup
 * @mr: the MR that is being deregistered
 *
 * This routine iterates RC QPs looking for references
 * to the lkey noted in mr.
 */
static void rvt_dereg_clean_qps(struct rvt_mregion *mr)
{
        struct rvt_dev_info *rdi = ib_to_rvt(mr->pd->device);

        rvt_qp_iter(rdi, (u64)mr, rvt_dereg_clean_qp_cb);
}

/**
 * rvt_check_refs - check references
 * @mr: the mregion
 * @t: the caller identification
 *
 * This routine checks MRs that are still holding a reference
 * while being deregistered.
 *
 * If the count is non-zero, the code calls a clean routine and then
 * waits up to the timeout for the count to reach zero.
 */
static int rvt_check_refs(struct rvt_mregion *mr, const char *t)
{
        unsigned long timeout;
        struct rvt_dev_info *rdi = ib_to_rvt(mr->pd->device);

        if (mr->lkey) {
                /* avoid dma mr */
                rvt_dereg_clean_qps(mr);
                /* @mr was indexed on rcu protected @lkey_table */
                synchronize_rcu();
        }

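        /* Wait up to five seconds for the reference count to drop to zero. */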
        timeout = wait_for_completion_timeout(&mr->comp, 5 * HZ);
        if (!timeout) {
                rvt_pr_err(rdi,
                           "%s timeout mr %p pd %p lkey %x refcount %ld\n",
                           t, mr, mr->pd, mr->lkey,
                           atomic_long_read(&mr->refcount.data->count));
                rvt_get_mr(mr);
                return -EBUSY;
        }
        return 0;
}

/**
 * rvt_mr_has_lkey - does this MR have the given lkey
 * @mr: the mregion
 * @lkey: the lkey
 */
bool rvt_mr_has_lkey(struct rvt_mregion *mr, u32 lkey)
{
        return mr && lkey == mr->lkey;
}

/**
 * rvt_ss_has_lkey - is the lkey referenced by the SGE state
 * @ss: the sge state
 * @lkey: the lkey
 *
 * This code tests for an MR with the given lkey in the
 * indicated sge state.
 */
bool rvt_ss_has_lkey(struct rvt_sge_state *ss, u32 lkey)
{
        int i;
        bool rval = false;

        if (!ss->num_sge)
                return rval;
        /* first one */
        rval = rvt_mr_has_lkey(ss->sge.mr, lkey);
        /* any others */
        for (i = 0; !rval && i < ss->num_sge - 1; i++)
                rval = rvt_mr_has_lkey(ss->sg_list[i].mr, lkey);
        return rval;
}

/**
 * rvt_dereg_mr - unregister and free a memory region
 * @ibmr: the memory region to free
 * @udata: unused by the driver
 *
 * Note that this is called to free MRs created by rvt_get_dma_mr()
 * or rvt_reg_user_mr().
 *
 * Returns 0 on success.
 */
int rvt_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata)
{
        struct rvt_mr *mr = to_imr(ibmr);
        int ret;

        rvt_free_lkey(&mr->mr);

        rvt_put_mr(&mr->mr); /* will set completion if last */
        ret = rvt_check_refs(&mr->mr, __func__);
        if (ret)
                goto out;
        rvt_deinit_mregion(&mr->mr);
        ib_umem_release(mr->umem);
        kfree(mr);
out:
        return ret;
}

/**
 * rvt_alloc_mr - Allocate a memory region usable with fast register work requests
 * @pd: protection domain for this memory region
 * @mr_type: mem region type
 * @max_num_sg: Max number of segments allowed
 *
 * Return: the memory region on success, otherwise return an errno.
 */
struct ib_mr *rvt_alloc_mr(struct ib_pd *pd, enum ib_mr_type mr_type,
                           u32 max_num_sg)
{
        struct rvt_mr *mr;

        if (mr_type != IB_MR_TYPE_MEM_REG)
                return ERR_PTR(-EINVAL);

        mr = __rvt_alloc_mr(max_num_sg, pd);
        if (IS_ERR(mr))
                return (struct ib_mr *)mr;

        return &mr->ibmr;
}

/**
 * rvt_set_page - page assignment function called by ib_sg_to_pages
 * @ibmr: memory region
 * @addr: dma address of mapped page
 *
 * Return: 0 on success
 */
static int rvt_set_page(struct ib_mr *ibmr, u64 addr)
{
        struct rvt_mr *mr = to_imr(ibmr);
        u32 ps = 1 << mr->mr.page_shift;
        u32 mapped_segs = mr->mr.length >> mr->mr.page_shift;
        int m, n;

        if (unlikely(mapped_segs == mr->mr.max_segs))
                return -ENOMEM;

        m = mapped_segs / RVT_SEGSZ;
        n = mapped_segs % RVT_SEGSZ;
        mr->mr.map[m]->segs[n].vaddr = (void *)addr;
        mr->mr.map[m]->segs[n].length = ps;
        mr->mr.length += ps;
        trace_rvt_mr_page_seg(&mr->mr, m, n, (void *)addr, ps);

        return 0;
}

/**
 * rvt_map_mr_sg - map an sg list to the memory region
 * @ibmr: memory region
 * @sg: dma mapped scatterlist
 * @sg_nents: number of entries in sg
 * @sg_offset: offset in bytes into sg
 *
 * Overwrites the rvt_mr length with the length calculated by ib_sg_to_pages.
 *
 * Return: number of sg elements mapped to the memory region
 */
int rvt_map_mr_sg(struct ib_mr *ibmr, struct scatterlist *sg,
                  int sg_nents, unsigned int *sg_offset)
{
        struct rvt_mr *mr = to_imr(ibmr);
        int ret;

        mr->mr.length = 0;
        mr->mr.page_shift = PAGE_SHIFT;
        ret = ib_sg_to_pages(ibmr, sg, sg_nents, sg_offset, rvt_set_page);
        mr->mr.user_base = ibmr->iova;
        mr->mr.iova = ibmr->iova;
        mr->mr.offset = ibmr->iova - (u64)mr->mr.map[0]->segs[0].vaddr;
        mr->mr.length = (size_t)ibmr->length;
        trace_rvt_map_mr_sg(ibmr, sg_nents, sg_offset);
        return ret;
}
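
/*
 * Typical flow (sketch, not tied to a specific consumer): a ULP calls
 * ib_map_mr_sg(ibmr, sg, sg_nents, NULL, PAGE_SIZE), which dispatches to
 * rvt_map_mr_sg() above, and then posts an IB_WR_REG_MR work request that
 * the driver services through rvt_fast_reg_mr() below.
 */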

/**
 * rvt_fast_reg_mr - fast register physical MR
 * @qp: the queue pair where the work request comes from
 * @ibmr: the memory region to be registered
 * @key: updated key for this memory region
 * @access: access flags for this memory region
 *
 * Returns 0 on success.
 */
int rvt_fast_reg_mr(struct rvt_qp *qp, struct ib_mr *ibmr, u32 key,
                    int access)
{
        struct rvt_mr *mr = to_imr(ibmr);

        if (qp->ibqp.pd != mr->mr.pd)
                return -EACCES;

        /* not applicable to dma MR or user MR */
        if (!mr->mr.lkey || mr->umem)
                return -EINVAL;

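        /* Only the lower 8 bits of the key may be changed by the consumer. */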
        if ((key & 0xFFFFFF00) != (mr->mr.lkey & 0xFFFFFF00))
                return -EINVAL;

        ibmr->lkey = key;
        ibmr->rkey = key;
        mr->mr.lkey = key;
        mr->mr.access_flags = access;
        mr->mr.iova = ibmr->iova;
        atomic_set(&mr->mr.lkey_invalid, 0);

        return 0;
}
EXPORT_SYMBOL(rvt_fast_reg_mr);

/**
 * rvt_invalidate_rkey - invalidate an MR rkey
 * @qp: queue pair associated with the invalidate op
 * @rkey: rkey to invalidate
 *
 * Returns 0 on success.
 */
int rvt_invalidate_rkey(struct rvt_qp *qp, u32 rkey)
{
        struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
        struct rvt_lkey_table *rkt = &dev->lkey_table;
        struct rvt_mregion *mr;

        if (rkey == 0)
                return -EINVAL;

        rcu_read_lock();
        mr = rcu_dereference(
                rkt->table[(rkey >> (32 - dev->dparms.lkey_table_size))]);
        if (unlikely(!mr || mr->lkey != rkey || qp->ibqp.pd != mr->pd))
                goto bail;

        atomic_set(&mr->lkey_invalid, 1);
        rcu_read_unlock();
        return 0;

bail:
        rcu_read_unlock();
        return -EINVAL;
}
EXPORT_SYMBOL(rvt_invalidate_rkey);

/**
 * rvt_sge_adjacent - is isge compressible
 * @last_sge: last outgoing SGE written
 * @sge: SGE to check
 *
 * If adjacent, update last_sge to add the length.
 *
 * Return: true if isge is adjacent to last sge
 */
static inline bool rvt_sge_adjacent(struct rvt_sge *last_sge,
                                    struct ib_sge *sge)
{
        if (last_sge && sge->lkey == last_sge->mr->lkey &&
            ((uint64_t)(last_sge->vaddr + last_sge->length) == sge->addr)) {
                if (sge->lkey) {
                        if (unlikely((sge->addr - last_sge->mr->user_base +
                              sge->length > last_sge->mr->length)))
                                return false; /* overrun, caller will catch */
                } else {
                        last_sge->length += sge->length;
                }
                last_sge->sge_length += sge->length;
                trace_rvt_sge_adjacent(last_sge, sge);
                return true;
        }
        return false;
}

/**
 * rvt_lkey_ok - check IB SGE for validity and initialize
 * @rkt: table containing lkey to check SGE against
 * @pd: protection domain
 * @isge: outgoing internal SGE
 * @last_sge: last outgoing SGE written
 * @sge: SGE to check
 * @acc: access flags
 *
 * Check the IB SGE for validity and initialize our internal version
 * of it.
 *
 * Increments the reference count when a new sge is stored.
 *
 * Return: 0 if compressed, 1 if added, otherwise returns -errno.
 */
int rvt_lkey_ok(struct rvt_lkey_table *rkt, struct rvt_pd *pd,
                struct rvt_sge *isge, struct rvt_sge *last_sge,
                struct ib_sge *sge, int acc)
{
        struct rvt_mregion *mr;
        unsigned int n, m;
        size_t off;

        /*
         * We use LKEY == zero for kernel virtual addresses
         * (see rvt_get_dma_mr()).
         */
        if (sge->lkey == 0) {
                struct rvt_dev_info *dev = ib_to_rvt(pd->ibpd.device);

                if (pd->user)
                        return -EINVAL;
                if (rvt_sge_adjacent(last_sge, sge))
                        return 0;
                rcu_read_lock();
                mr = rcu_dereference(dev->dma_mr);
                if (!mr)
                        goto bail;
                rvt_get_mr(mr);
                rcu_read_unlock();

                isge->mr = mr;
                isge->vaddr = (void *)sge->addr;
                isge->length = sge->length;
                isge->sge_length = sge->length;
                isge->m = 0;
                isge->n = 0;
                goto ok;
        }
        if (rvt_sge_adjacent(last_sge, sge))
                return 0;
        rcu_read_lock();
        mr = rcu_dereference(rkt->table[sge->lkey >> rkt->shift]);
        if (!mr)
                goto bail;
        rvt_get_mr(mr);
        if (!READ_ONCE(mr->lkey_published))
                goto bail_unref;

        if (unlikely(atomic_read(&mr->lkey_invalid) ||
                     mr->lkey != sge->lkey || mr->pd != &pd->ibpd))
                goto bail_unref;

        off = sge->addr - mr->user_base;
        if (unlikely(sge->addr < mr->user_base ||
                     off + sge->length > mr->length ||
                     (mr->access_flags & acc) != acc))
                goto bail_unref;
        rcu_read_unlock();

        off += mr->offset;
        if (mr->page_shift) {
                /*
                 * Page sizes are a uniform power of 2, so no loop is
                 * necessary. entries_spanned_by_off is the number of times
                 * the loop below would have executed.
                 */
                size_t entries_spanned_by_off;

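                /*
                 * For example, with 4 KiB pages and off == 0x6234, six
                 * whole pages are spanned: m/n then select entry 6 of the
                 * map[]->segs[] table and off is reduced to the in-page
                 * remainder 0x234.
                 */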
                entries_spanned_by_off = off >> mr->page_shift;
                off -= (entries_spanned_by_off << mr->page_shift);
                m = entries_spanned_by_off / RVT_SEGSZ;
                n = entries_spanned_by_off % RVT_SEGSZ;
        } else {
                m = 0;
                n = 0;
                while (off >= mr->map[m]->segs[n].length) {
                        off -= mr->map[m]->segs[n].length;
                        n++;
                        if (n >= RVT_SEGSZ) {
                                m++;
                                n = 0;
                        }
                }
        }
        isge->mr = mr;
        isge->vaddr = mr->map[m]->segs[n].vaddr + off;
        isge->length = mr->map[m]->segs[n].length - off;
        isge->sge_length = sge->length;
        isge->m = m;
        isge->n = n;
ok:
        trace_rvt_sge_new(isge, sge);
        return 1;
bail_unref:
        rvt_put_mr(mr);
bail:
        rcu_read_unlock();
        return -EINVAL;
}
EXPORT_SYMBOL(rvt_lkey_ok);

/**
 * rvt_rkey_ok - check the IB virtual address, length, and RKEY
 * @qp: qp for validation
 * @sge: SGE state
 * @len: length of data
 * @vaddr: virtual address to place data
 * @rkey: rkey to check
 * @acc: access flags
 *
 * Return: 1 if successful, otherwise 0.
 *
 * Increments the reference count upon success.
 */
int rvt_rkey_ok(struct rvt_qp *qp, struct rvt_sge *sge,
                u32 len, u64 vaddr, u32 rkey, int acc)
{
        struct rvt_dev_info *dev = ib_to_rvt(qp->ibqp.device);
        struct rvt_lkey_table *rkt = &dev->lkey_table;
        struct rvt_mregion *mr;
        unsigned int n, m;
        size_t off;

        /*
         * We use RKEY == zero for kernel virtual addresses
         * (see rvt_get_dma_mr()).
         */
        rcu_read_lock();
        if (rkey == 0) {
                struct rvt_pd *pd = ibpd_to_rvtpd(qp->ibqp.pd);
                struct rvt_dev_info *rdi = ib_to_rvt(pd->ibpd.device);

                if (pd->user)
                        goto bail;
                mr = rcu_dereference(rdi->dma_mr);
                if (!mr)
                        goto bail;
                rvt_get_mr(mr);
                rcu_read_unlock();

                sge->mr = mr;
                sge->vaddr = (void *)vaddr;
                sge->length = len;
                sge->sge_length = len;
                sge->m = 0;
                sge->n = 0;
                goto ok;
        }

        mr = rcu_dereference(rkt->table[rkey >> rkt->shift]);
        if (!mr)
                goto bail;
        rvt_get_mr(mr);
        /* ensure the mr read happens before the test */
        if (!READ_ONCE(mr->lkey_published))
                goto bail_unref;
        if (unlikely(atomic_read(&mr->lkey_invalid) ||
                     mr->lkey != rkey || qp->ibqp.pd != mr->pd))
                goto bail_unref;

        off = vaddr - mr->iova;
        if (unlikely(vaddr < mr->iova || off + len > mr->length ||
                     (mr->access_flags & acc) == 0))
                goto bail_unref;
        rcu_read_unlock();

        off += mr->offset;
        if (mr->page_shift) {
                /*
                 * Page sizes are a uniform power of 2, so no loop is
                 * necessary. entries_spanned_by_off is the number of times
                 * the loop below would have executed.
                 */
                size_t entries_spanned_by_off;

                entries_spanned_by_off = off >> mr->page_shift;
                off -= (entries_spanned_by_off << mr->page_shift);
                m = entries_spanned_by_off / RVT_SEGSZ;
                n = entries_spanned_by_off % RVT_SEGSZ;
        } else {
                m = 0;
                n = 0;
                while (off >= mr->map[m]->segs[n].length) {
                        off -= mr->map[m]->segs[n].length;
                        n++;
                        if (n >= RVT_SEGSZ) {
                                m++;
                                n = 0;
                        }
                }
        }
        sge->mr = mr;
        sge->vaddr = mr->map[m]->segs[n].vaddr + off;
        sge->length = mr->map[m]->segs[n].length - off;
        sge->sge_length = len;
        sge->m = m;
        sge->n = n;
ok:
        return 1;
bail_unref:
        rvt_put_mr(mr);
bail:
        rcu_read_unlock();
        return 0;
}
EXPORT_SYMBOL(rvt_rkey_ok);