linux/drivers/infiniband/ulp/rtrs/rtrs-srv.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * RDMA Transport Layer
 *
 * Copyright (c) 2014 - 2018 ProfitBricks GmbH. All rights reserved.
 * Copyright (c) 2018 - 2019 1&1 IONOS Cloud GmbH. All rights reserved.
 * Copyright (c) 2019 - 2020 1&1 IONOS SE. All rights reserved.
 */

#undef pr_fmt
#define pr_fmt(fmt) KBUILD_MODNAME " L" __stringify(__LINE__) ": " fmt

#include <linux/module.h>
#include <linux/mempool.h>

#include "rtrs-srv.h"
#include "rtrs-log.h"
#include <rdma/ib_cm.h>

MODULE_DESCRIPTION("RDMA Transport Server");
MODULE_LICENSE("GPL");

/* Must be power of 2, see mask from mr->page_size in ib_sg_to_pages() */
#define DEFAULT_MAX_CHUNK_SIZE (128 << 10)
#define DEFAULT_SESS_QUEUE_DEPTH 512
#define MAX_HDR_SIZE PAGE_SIZE

/* We guarantee to serve at least 10 paths */
#define CHUNK_POOL_SZ 10

static struct rtrs_rdma_dev_pd dev_pd;
static mempool_t *chunk_pool;
struct class *rtrs_dev_class;

static int __read_mostly max_chunk_size = DEFAULT_MAX_CHUNK_SIZE;
static int __read_mostly sess_queue_depth = DEFAULT_SESS_QUEUE_DEPTH;

static bool always_invalidate = true;
module_param(always_invalidate, bool, 0444);
MODULE_PARM_DESC(always_invalidate,
                 "Invalidate memory registration for contiguous memory regions before accessing.");

module_param_named(max_chunk_size, max_chunk_size, int, 0444);
MODULE_PARM_DESC(max_chunk_size,
                 "Max size in bytes of each I/O request (default: "
                 __stringify(DEFAULT_MAX_CHUNK_SIZE) ")");

module_param_named(sess_queue_depth, sess_queue_depth, int, 0444);
MODULE_PARM_DESC(sess_queue_depth,
                 "Number of buffers for pending I/O requests to allocate per session. Maximum: "
                 __stringify(MAX_SESS_QUEUE_DEPTH) " (default: "
                 __stringify(DEFAULT_SESS_QUEUE_DEPTH) ")");

static cpumask_t cq_affinity_mask = { CPU_BITS_ALL };

static struct workqueue_struct *rtrs_wq;

static inline struct rtrs_srv_con *to_srv_con(struct rtrs_con *c)
{
        return container_of(c, struct rtrs_srv_con, c);
}

static inline struct rtrs_srv_sess *to_srv_sess(struct rtrs_sess *s)
{
        return container_of(s, struct rtrs_srv_sess, s);
}

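/*
 * Valid session state transitions, as encoded below:
 * CONNECTING -> CONNECTED, CONNECTING/CONNECTED -> CLOSING,
 * CLOSING -> CLOSED.  Returns true iff the transition was applied.
 */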
static bool __rtrs_srv_change_state(struct rtrs_srv_sess *sess,
                                     enum rtrs_srv_state new_state)
{
        enum rtrs_srv_state old_state;
        bool changed = false;

        lockdep_assert_held(&sess->state_lock);
        old_state = sess->state;
        switch (new_state) {
        case RTRS_SRV_CONNECTED:
                switch (old_state) {
                case RTRS_SRV_CONNECTING:
                        changed = true;
                        fallthrough;
                default:
                        break;
                }
                break;
        case RTRS_SRV_CLOSING:
                switch (old_state) {
                case RTRS_SRV_CONNECTING:
                case RTRS_SRV_CONNECTED:
                        changed = true;
                        fallthrough;
                default:
                        break;
                }
                break;
        case RTRS_SRV_CLOSED:
                switch (old_state) {
                case RTRS_SRV_CLOSING:
                        changed = true;
                        fallthrough;
                default:
                        break;
                }
                break;
        default:
                break;
        }
        if (changed)
                sess->state = new_state;

        return changed;
}

static bool rtrs_srv_change_state_get_old(struct rtrs_srv_sess *sess,
                                           enum rtrs_srv_state new_state,
                                           enum rtrs_srv_state *old_state)
{
        bool changed;

        spin_lock_irq(&sess->state_lock);
        *old_state = sess->state;
        changed = __rtrs_srv_change_state(sess, new_state);
        spin_unlock_irq(&sess->state_lock);

        return changed;
}

static bool rtrs_srv_change_state(struct rtrs_srv_sess *sess,
                                   enum rtrs_srv_state new_state)
{
        enum rtrs_srv_state old_state;

        return rtrs_srv_change_state_get_old(sess, new_state, &old_state);
}

static void free_id(struct rtrs_srv_op *id)
{
        if (!id)
                return;
        kfree(id);
}

static void rtrs_srv_free_ops_ids(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;
        int i;

        WARN_ON(atomic_read(&sess->ids_inflight));
        if (sess->ops_ids) {
                for (i = 0; i < srv->queue_depth; i++)
                        free_id(sess->ops_ids[i]);
                kfree(sess->ops_ids);
                sess->ops_ids = NULL;
        }
}

static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc);

static struct ib_cqe io_comp_cqe = {
        .done = rtrs_srv_rdma_done
};

static int rtrs_srv_alloc_ops_ids(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_srv_op *id;
        int i;

        sess->ops_ids = kcalloc(srv->queue_depth, sizeof(*sess->ops_ids),
                                GFP_KERNEL);
        if (!sess->ops_ids)
                goto err;

        for (i = 0; i < srv->queue_depth; ++i) {
                id = kzalloc(sizeof(*id), GFP_KERNEL);
                if (!id)
                        goto err;

                sess->ops_ids[i] = id;
        }
        init_waitqueue_head(&sess->ids_waitq);
        atomic_set(&sess->ids_inflight, 0);

        return 0;

err:
        rtrs_srv_free_ops_ids(sess);
        return -ENOMEM;
}

static inline void rtrs_srv_get_ops_ids(struct rtrs_srv_sess *sess)
{
        atomic_inc(&sess->ids_inflight);
}

static inline void rtrs_srv_put_ops_ids(struct rtrs_srv_sess *sess)
{
        if (atomic_dec_and_test(&sess->ids_inflight))
                wake_up(&sess->ids_waitq);
}

static void rtrs_srv_wait_ops_ids(struct rtrs_srv_sess *sess)
{
        wait_event(sess->ids_waitq, !atomic_read(&sess->ids_inflight));
}

static void rtrs_srv_reg_mr_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = cq->cq_context;
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);

        if (unlikely(wc->status != IB_WC_SUCCESS)) {
                rtrs_err(s, "REG MR failed: %s\n",
                          ib_wc_status_msg(wc->status));
                close_sess(sess);
                return;
        }
}

static struct ib_cqe local_reg_cqe = {
        .done = rtrs_srv_reg_mr_done
};

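/*
 * rdma_write_sg() - reply to a read request by RDMA-writing the data
 * chunk back to the client and notifying it with an IMM message,
 * chained with REG_MR and/or SEND_WITH_INV WRs when invalidation is
 * requested or always_invalidate is set.
 */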
static int rdma_write_sg(struct rtrs_srv_op *id)
{
        struct rtrs_sess *s = id->con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        dma_addr_t dma_addr = sess->dma_addr[id->msg_id];
        struct rtrs_srv_mr *srv_mr;
        struct rtrs_srv *srv = sess->srv;
        struct ib_send_wr inv_wr, imm_wr;
        struct ib_rdma_wr *wr = NULL;
        enum ib_send_flags flags;
        size_t sg_cnt;
        int err, offset;
        bool need_inval;
        u32 rkey = 0;
        struct ib_reg_wr rwr;
        struct ib_sge *plist;
        struct ib_sge list;

        sg_cnt = le16_to_cpu(id->rd_msg->sg_cnt);
        need_inval = le16_to_cpu(id->rd_msg->flags) & RTRS_MSG_NEED_INVAL_F;
        if (unlikely(sg_cnt != 1))
                return -EINVAL;

        offset = 0;

        wr              = &id->tx_wr;
        plist           = &id->tx_sg;
        plist->addr     = dma_addr + offset;
        plist->length   = le32_to_cpu(id->rd_msg->desc[0].len);

        /* WR will fail with a length error if this is 0 */
        if (unlikely(plist->length == 0)) {
                rtrs_err(s, "Invalid RDMA-Write sg list length 0\n");
                return -EINVAL;
        }

        plist->lkey = sess->s.dev->ib_pd->local_dma_lkey;
        offset += plist->length;

        wr->wr.sg_list  = plist;
        wr->wr.num_sge  = 1;
        wr->remote_addr = le64_to_cpu(id->rd_msg->desc[0].addr);
        wr->rkey        = le32_to_cpu(id->rd_msg->desc[0].key);
        if (rkey == 0)
                rkey = wr->rkey;
        else
                /* Only one key is actually used */
                WARN_ON_ONCE(rkey != wr->rkey);

        wr->wr.opcode = IB_WR_RDMA_WRITE;
        wr->wr.ex.imm_data = 0;
        wr->wr.send_flags  = 0;

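        /*
         * Chain the WRs: RDMA_WRITE [-> REG_MR] [-> SEND_WITH_INV],
         * ending with the IMM WR that notifies the client.
         */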
        if (need_inval && always_invalidate) {
                wr->wr.next = &rwr.wr;
                rwr.wr.next = &inv_wr;
                inv_wr.next = &imm_wr;
        } else if (always_invalidate) {
                wr->wr.next = &rwr.wr;
                rwr.wr.next = &imm_wr;
        } else if (need_inval) {
                wr->wr.next = &inv_wr;
                inv_wr.next = &imm_wr;
        } else {
                wr->wr.next = &imm_wr;
        }
        /*
         * From time to time we have to post signaled sends,
         * or the send queue will fill up and only a QP reset can help.
         */
        flags = (atomic_inc_return(&id->con->wr_cnt) % srv->queue_depth) ?
                0 : IB_SEND_SIGNALED;

        if (need_inval) {
                inv_wr.sg_list = NULL;
                inv_wr.num_sge = 0;
                inv_wr.opcode = IB_WR_SEND_WITH_INV;
                inv_wr.send_flags = 0;
                inv_wr.ex.invalidate_rkey = rkey;
        }

        imm_wr.next = NULL;
        if (always_invalidate) {
                struct rtrs_msg_rkey_rsp *msg;

                srv_mr = &sess->mrs[id->msg_id];
                rwr.wr.opcode = IB_WR_REG_MR;
                rwr.wr.num_sge = 0;
                rwr.mr = srv_mr->mr;
                rwr.wr.send_flags = 0;
                rwr.key = srv_mr->mr->rkey;
                rwr.access = (IB_ACCESS_LOCAL_WRITE |
                              IB_ACCESS_REMOTE_WRITE);
                msg = srv_mr->iu->buf;
                msg->buf_id = cpu_to_le16(id->msg_id);
                msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
                msg->rkey = cpu_to_le32(srv_mr->mr->rkey);

                list.addr   = srv_mr->iu->dma_addr;
                list.length = sizeof(*msg);
                list.lkey   = sess->s.dev->ib_pd->local_dma_lkey;
                imm_wr.sg_list = &list;
                imm_wr.num_sge = 1;
                imm_wr.opcode = IB_WR_SEND_WITH_IMM;
                ib_dma_sync_single_for_device(sess->s.dev->ib_dev,
                                              srv_mr->iu->dma_addr,
                                              srv_mr->iu->size, DMA_TO_DEVICE);
        } else {
                imm_wr.sg_list = NULL;
                imm_wr.num_sge = 0;
                imm_wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
        }
        imm_wr.send_flags = flags;
        imm_wr.ex.imm_data = cpu_to_be32(rtrs_to_io_rsp_imm(id->msg_id,
                                                             0, need_inval));

        imm_wr.wr_cqe   = &io_comp_cqe;
        ib_dma_sync_single_for_device(sess->s.dev->ib_dev, dma_addr,
                                      offset, DMA_BIDIRECTIONAL);

        err = ib_post_send(id->con->c.qp, &id->tx_wr.wr, NULL);
        if (unlikely(err))
                rtrs_err(s,
                          "Posting RDMA-Write-Request to QP failed, err: %d\n",
                          err);

        return err;
}

/**
 * send_io_resp_imm() - respond to client with empty IMM on failed READ/WRITE
 *                      requests or on successful WRITE request.
 * @con:        the connection to send back result
 * @id:         the id associated with the IO
 * @errno:      the error number of the IO.
 *
 * Return: 0 on success, errno otherwise.
 */
static int send_io_resp_imm(struct rtrs_srv_con *con, struct rtrs_srv_op *id,
                            int errno)
{
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct ib_send_wr inv_wr, imm_wr, *wr = NULL;
        struct ib_reg_wr rwr;
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_srv_mr *srv_mr;
        bool need_inval = false;
        enum ib_send_flags flags;
        u32 imm;
        int err;

        if (id->dir == READ) {
                struct rtrs_msg_rdma_read *rd_msg = id->rd_msg;
                size_t sg_cnt;

                need_inval = le16_to_cpu(rd_msg->flags) &
                                RTRS_MSG_NEED_INVAL_F;
                sg_cnt = le16_to_cpu(rd_msg->sg_cnt);

                if (need_inval) {
                        if (likely(sg_cnt)) {
                                inv_wr.sg_list = NULL;
                                inv_wr.num_sge = 0;
                                inv_wr.opcode = IB_WR_SEND_WITH_INV;
                                inv_wr.send_flags = 0;
                                /* Only one key is actually used */
                                inv_wr.ex.invalidate_rkey =
                                        le32_to_cpu(rd_msg->desc[0].key);
                        } else {
                                WARN_ON_ONCE(1);
                                need_inval = false;
                        }
                }
        }

        if (need_inval && always_invalidate) {
                wr = &inv_wr;
                inv_wr.next = &rwr.wr;
                rwr.wr.next = &imm_wr;
        } else if (always_invalidate) {
                wr = &rwr.wr;
                rwr.wr.next = &imm_wr;
        } else if (need_inval) {
                wr = &inv_wr;
                inv_wr.next = &imm_wr;
        } else {
                wr = &imm_wr;
        }
        /*
         * From time to time we have to post signaled sends,
         * or the send queue will fill up and only a QP reset can help.
         */
        flags = (atomic_inc_return(&con->wr_cnt) % srv->queue_depth) ?
                0 : IB_SEND_SIGNALED;
        imm = rtrs_to_io_rsp_imm(id->msg_id, errno, need_inval);
        imm_wr.next = NULL;
        if (always_invalidate) {
                struct ib_sge list;
                struct rtrs_msg_rkey_rsp *msg;

                srv_mr = &sess->mrs[id->msg_id];
                rwr.wr.next = &imm_wr;
                rwr.wr.opcode = IB_WR_REG_MR;
                rwr.wr.num_sge = 0;
                rwr.wr.send_flags = 0;
                rwr.mr = srv_mr->mr;
                rwr.key = srv_mr->mr->rkey;
                rwr.access = (IB_ACCESS_LOCAL_WRITE |
                              IB_ACCESS_REMOTE_WRITE);
                msg = srv_mr->iu->buf;
                msg->buf_id = cpu_to_le16(id->msg_id);
                msg->type = cpu_to_le16(RTRS_MSG_RKEY_RSP);
                msg->rkey = cpu_to_le32(srv_mr->mr->rkey);

                list.addr   = srv_mr->iu->dma_addr;
                list.length = sizeof(*msg);
                list.lkey   = sess->s.dev->ib_pd->local_dma_lkey;
                imm_wr.sg_list = &list;
                imm_wr.num_sge = 1;
                imm_wr.opcode = IB_WR_SEND_WITH_IMM;
                ib_dma_sync_single_for_device(sess->s.dev->ib_dev,
                                              srv_mr->iu->dma_addr,
                                              srv_mr->iu->size, DMA_TO_DEVICE);
        } else {
                imm_wr.sg_list = NULL;
                imm_wr.num_sge = 0;
                imm_wr.opcode = IB_WR_RDMA_WRITE_WITH_IMM;
        }
        imm_wr.send_flags = flags;
        imm_wr.wr_cqe   = &io_comp_cqe;

        imm_wr.ex.imm_data = cpu_to_be32(imm);

        err = ib_post_send(id->con->c.qp, wr, NULL);
        if (unlikely(err))
                rtrs_err_rl(s, "Posting RDMA-Reply to QP failed, err: %d\n",
                             err);

        return err;
}

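/*
 * close_sess() - move the session to CLOSING and schedule the close work.
 * Only the first caller wins the state change and queues the work, so it
 * is safe to call this from several error paths concurrently.
 */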
void close_sess(struct rtrs_srv_sess *sess)
{
        enum rtrs_srv_state old_state;

        if (rtrs_srv_change_state_get_old(sess, RTRS_SRV_CLOSING,
                                           &old_state))
                queue_work(rtrs_wq, &sess->close_work);
        WARN_ON(sess->state != RTRS_SRV_CLOSING);
}

static inline const char *rtrs_srv_state_str(enum rtrs_srv_state state)
{
        switch (state) {
        case RTRS_SRV_CONNECTING:
                return "RTRS_SRV_CONNECTING";
        case RTRS_SRV_CONNECTED:
                return "RTRS_SRV_CONNECTED";
        case RTRS_SRV_CLOSING:
                return "RTRS_SRV_CLOSING";
        case RTRS_SRV_CLOSED:
                return "RTRS_SRV_CLOSED";
        default:
                return "UNKNOWN";
        }
}

/**
 * rtrs_srv_resp_rdma() - Finish an RDMA request
 *
 * @id:         Internal RTRS operation identifier
 * @status:     Response code sent to the other side for this operation.
 *              0 = success, negative errno on error
 * Context: any
 *
 * Finish an RDMA operation. A message is sent to the client and the
 * corresponding memory areas will be released.
 *
 * Return: true if the response was sent or the error was handled locally,
 * false if the send queue was full and the response was queued for retry.
 */
bool rtrs_srv_resp_rdma(struct rtrs_srv_op *id, int status)
{
        struct rtrs_srv_sess *sess;
        struct rtrs_srv_con *con;
        struct rtrs_sess *s;
        int err;

        if (WARN_ON(!id))
                return true;

        con = id->con;
        s = con->c.sess;
        sess = to_srv_sess(s);

        id->status = status;

        if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
                rtrs_err_rl(s,
                             "Sending I/O response failed, session is disconnected, sess state %s\n",
                             rtrs_srv_state_str(sess->state));
                goto out;
        }
        if (always_invalidate) {
                struct rtrs_srv_mr *mr = &sess->mrs[id->msg_id];

                ib_update_fast_reg_key(mr->mr, ib_inc_rkey(mr->mr->rkey));
        }
        if (unlikely(atomic_sub_return(1,
                                       &con->sq_wr_avail) < 0)) {
                pr_err("IB send queue full\n");
                atomic_add(1, &con->sq_wr_avail);
                spin_lock(&con->rsp_wr_wait_lock);
                list_add_tail(&id->wait_list, &con->rsp_wr_wait_list);
                spin_unlock(&con->rsp_wr_wait_lock);
                return false;
        }

        if (status || id->dir == WRITE || !id->rd_msg->sg_cnt)
                err = send_io_resp_imm(con, id, status);
        else
                err = rdma_write_sg(id);

        if (unlikely(err)) {
                rtrs_err_rl(s, "IO response failed: %d\n", err);
                close_sess(sess);
        }
out:
        rtrs_srv_put_ops_ids(sess);
        return true;
}
EXPORT_SYMBOL(rtrs_srv_resp_rdma);

/**
 * rtrs_srv_set_sess_priv() - Set private pointer in rtrs_srv.
 * @srv:        Session pointer
 * @priv:       The private pointer that is associated with the session.
 */
void rtrs_srv_set_sess_priv(struct rtrs_srv *srv, void *priv)
{
        srv->priv = priv;
}
EXPORT_SYMBOL(rtrs_srv_set_sess_priv);

static void unmap_cont_bufs(struct rtrs_srv_sess *sess)
{
        int i;

        for (i = 0; i < sess->mrs_num; i++) {
                struct rtrs_srv_mr *srv_mr;

                srv_mr = &sess->mrs[i];
                rtrs_iu_free(srv_mr->iu, DMA_TO_DEVICE,
                              sess->s.dev->ib_dev, 1);
                ib_dereg_mr(srv_mr->mr);
                ib_dma_unmap_sg(sess->s.dev->ib_dev, srv_mr->sgt.sgl,
                                srv_mr->sgt.nents, DMA_BIDIRECTIONAL);
                sg_free_table(&srv_mr->sgt);
        }
        kfree(sess->mrs);
}

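/*
 * map_cont_bufs() - map all queue_depth receive chunks into fast
 * registration MRs.  With always_invalidate each chunk gets its own MR
 * (plus an IU for the rkey response); otherwise chunks are packed into
 * as few MRs as the device's fast-reg page list length allows.
 */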
static int map_cont_bufs(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_sess *ss = &sess->s;
        int i, mri, err, mrs_num;
        unsigned int chunk_bits;
        int chunks_per_mr = 1;

        /*
         * Here we map queue_depth chunks to MRs.  First we have to
         * figure out how many chunks we can map per MR.
         */
        if (always_invalidate) {
                /*
                 * In order to invalidate each chunk of memory separately,
                 * we need more memory regions.
                 */
                mrs_num = srv->queue_depth;
        } else {
                chunks_per_mr =
                        sess->s.dev->ib_dev->attrs.max_fast_reg_page_list_len;
                mrs_num = DIV_ROUND_UP(srv->queue_depth, chunks_per_mr);
                chunks_per_mr = DIV_ROUND_UP(srv->queue_depth, mrs_num);
        }

        sess->mrs = kcalloc(mrs_num, sizeof(*sess->mrs), GFP_KERNEL);
        if (!sess->mrs)
                return -ENOMEM;

        sess->mrs_num = mrs_num;

        for (mri = 0; mri < mrs_num; mri++) {
                struct rtrs_srv_mr *srv_mr = &sess->mrs[mri];
                struct sg_table *sgt = &srv_mr->sgt;
                struct scatterlist *s;
                struct ib_mr *mr;
                int nr, chunks;

                chunks = chunks_per_mr * mri;
                if (!always_invalidate)
                        chunks_per_mr = min_t(int, chunks_per_mr,
                                              srv->queue_depth - chunks);

                err = sg_alloc_table(sgt, chunks_per_mr, GFP_KERNEL);
                if (err)
                        goto err;

                for_each_sg(sgt->sgl, s, chunks_per_mr, i)
                        sg_set_page(s, srv->chunks[chunks + i],
                                    max_chunk_size, 0);

                nr = ib_dma_map_sg(sess->s.dev->ib_dev, sgt->sgl,
                                   sgt->nents, DMA_BIDIRECTIONAL);
                if (nr < sgt->nents) {
                        err = nr < 0 ? nr : -EINVAL;
                        goto free_sg;
                }
                mr = ib_alloc_mr(sess->s.dev->ib_pd, IB_MR_TYPE_MEM_REG,
                                 sgt->nents);
                if (IS_ERR(mr)) {
                        err = PTR_ERR(mr);
                        goto unmap_sg;
                }
                nr = ib_map_mr_sg(mr, sgt->sgl, sgt->nents,
                                  NULL, max_chunk_size);
                if (nr < 0 || nr < sgt->nents) {
                        err = nr < 0 ? nr : -EINVAL;
                        goto dereg_mr;
                }

                if (always_invalidate) {
                        srv_mr->iu = rtrs_iu_alloc(1,
                                        sizeof(struct rtrs_msg_rkey_rsp),
                                        GFP_KERNEL, sess->s.dev->ib_dev,
                                        DMA_TO_DEVICE, rtrs_srv_rdma_done);
                        if (!srv_mr->iu) {
                                err = -ENOMEM;
                                rtrs_err(ss, "rtrs_iu_alloc(), err: %d\n", err);
                                goto free_iu;
                        }
                }
                /* Eventually dma addr for each chunk can be cached */
                for_each_sg(sgt->sgl, s, sgt->orig_nents, i)
                        sess->dma_addr[chunks + i] = sg_dma_address(s);

                ib_update_fast_reg_key(mr, ib_inc_rkey(mr->rkey));
                srv_mr->mr = mr;

                continue;
err:
                while (mri--) {
                        srv_mr = &sess->mrs[mri];
                        sgt = &srv_mr->sgt;
                        mr = srv_mr->mr;
free_iu:
                        rtrs_iu_free(srv_mr->iu, DMA_TO_DEVICE,
                                      sess->s.dev->ib_dev, 1);
dereg_mr:
                        ib_dereg_mr(mr);
unmap_sg:
                        ib_dma_unmap_sg(sess->s.dev->ib_dev, sgt->sgl,
                                        sgt->nents, DMA_BIDIRECTIONAL);
free_sg:
                        sg_free_table(sgt);
                }
                kfree(sess->mrs);

                return err;
        }

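        /*
         * The IMM payload is split between the chunk index and the offset
         * within the chunk: imm_payload = (msg_id << mem_bits) | offset
         * (see the decoding in rtrs_srv_rdma_done()).
         */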
        chunk_bits = ilog2(srv->queue_depth - 1) + 1;
        sess->mem_bits = (MAX_IMM_PAYL_BITS - chunk_bits);

        return 0;
}

static void rtrs_srv_hb_err_handler(struct rtrs_con *c)
{
        close_sess(to_srv_sess(c->sess));
}

static void rtrs_srv_init_hb(struct rtrs_srv_sess *sess)
{
        rtrs_init_hb(&sess->s, &io_comp_cqe,
                      RTRS_HB_INTERVAL_MS,
                      RTRS_HB_MISSED_MAX,
                      rtrs_srv_hb_err_handler,
                      rtrs_wq);
}

static void rtrs_srv_start_hb(struct rtrs_srv_sess *sess)
{
        rtrs_start_hb(&sess->s);
}

static void rtrs_srv_stop_hb(struct rtrs_srv_sess *sess)
{
        rtrs_stop_hb(&sess->s);
}

static void rtrs_srv_info_rsp_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = cq->cq_context;
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_iu *iu;

        iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
        rtrs_iu_free(iu, DMA_TO_DEVICE, sess->s.dev->ib_dev, 1);

        if (unlikely(wc->status != IB_WC_SUCCESS)) {
                rtrs_err(s, "Sess info response send failed: %s\n",
                          ib_wc_status_msg(wc->status));
                close_sess(sess);
                return;
        }
        WARN_ON(wc->opcode != IB_WC_SEND);
}

static void rtrs_srv_sess_up(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;
        int up;

        mutex_lock(&srv->paths_ev_mutex);
        up = ++srv->paths_up;
        if (up == 1)
                ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_CONNECTED, NULL);
        mutex_unlock(&srv->paths_ev_mutex);

        /* Mark session as established */
        sess->established = true;
}

static void rtrs_srv_sess_down(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;

        if (!sess->established)
                return;

        sess->established = false;
        mutex_lock(&srv->paths_ev_mutex);
        WARN_ON(!srv->paths_up);
        if (--srv->paths_up == 0)
                ctx->ops.link_ev(srv, RTRS_SRV_LINK_EV_DISCONNECTED, srv->priv);
        mutex_unlock(&srv->paths_ev_mutex);
}

static int post_recv_sess(struct rtrs_srv_sess *sess);

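/*
 * process_info_req() - handle the client's info request: post receive
 * buffers on all connections, register the chunk MRs and reply with an
 * info response that describes every MR (addr, rkey, length).
 */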
static int process_info_req(struct rtrs_srv_con *con,
                            struct rtrs_msg_info_req *msg)
{
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct ib_send_wr *reg_wr = NULL;
        struct rtrs_msg_info_rsp *rsp;
        struct rtrs_iu *tx_iu;
        struct ib_reg_wr *rwr;
        int mri, err;
        size_t tx_sz;

        err = post_recv_sess(sess);
        if (unlikely(err)) {
                rtrs_err(s, "post_recv_sess(), err: %d\n", err);
                return err;
        }
        rwr = kcalloc(sess->mrs_num, sizeof(*rwr), GFP_KERNEL);
        if (unlikely(!rwr))
                return -ENOMEM;
        strlcpy(sess->s.sessname, msg->sessname, sizeof(sess->s.sessname));

        tx_sz  = sizeof(*rsp);
        tx_sz += sizeof(rsp->desc[0]) * sess->mrs_num;
        tx_iu = rtrs_iu_alloc(1, tx_sz, GFP_KERNEL, sess->s.dev->ib_dev,
                               DMA_TO_DEVICE, rtrs_srv_info_rsp_done);
        if (unlikely(!tx_iu)) {
                err = -ENOMEM;
                goto rwr_free;
        }

        rsp = tx_iu->buf;
        rsp->type = cpu_to_le16(RTRS_MSG_INFO_RSP);
        rsp->sg_cnt = cpu_to_le16(sess->mrs_num);

        for (mri = 0; mri < sess->mrs_num; mri++) {
                struct ib_mr *mr = sess->mrs[mri].mr;

                rsp->desc[mri].addr = cpu_to_le64(mr->iova);
                rsp->desc[mri].key  = cpu_to_le32(mr->rkey);
                rsp->desc[mri].len  = cpu_to_le32(mr->length);

                /*
                 * Fill in reg MR request and chain them *backwards*
                 */
                rwr[mri].wr.next = mri ? &rwr[mri - 1].wr : NULL;
                rwr[mri].wr.opcode = IB_WR_REG_MR;
                rwr[mri].wr.wr_cqe = &local_reg_cqe;
                rwr[mri].wr.num_sge = 0;
                rwr[mri].wr.send_flags = mri ? 0 : IB_SEND_SIGNALED;
                rwr[mri].mr = mr;
                rwr[mri].key = mr->rkey;
                rwr[mri].access = (IB_ACCESS_LOCAL_WRITE |
                                   IB_ACCESS_REMOTE_WRITE);
                reg_wr = &rwr[mri].wr;
        }

        err = rtrs_srv_create_sess_files(sess);
        if (unlikely(err))
                goto iu_free;
        kobject_get(&sess->kobj);
        get_device(&sess->srv->dev);
        rtrs_srv_change_state(sess, RTRS_SRV_CONNECTED);
        rtrs_srv_start_hb(sess);

        /*
         * We do not account for the number of established connections at the
         * moment; we rely on the client, which should send the info request
         * only when all connections have been successfully established.
         * Thus, simply notify the listener with a proper event if we are
         * the first path.
         */
        rtrs_srv_sess_up(sess);

        ib_dma_sync_single_for_device(sess->s.dev->ib_dev, tx_iu->dma_addr,
                                      tx_iu->size, DMA_TO_DEVICE);

        /* Send info response */
        err = rtrs_iu_post_send(&con->c, tx_iu, tx_sz, reg_wr);
        if (unlikely(err)) {
                rtrs_err(s, "rtrs_iu_post_send(), err: %d\n", err);
iu_free:
                rtrs_iu_free(tx_iu, DMA_TO_DEVICE, sess->s.dev->ib_dev, 1);
        }
rwr_free:
        kfree(rwr);

        return err;
}

static void rtrs_srv_info_req_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = cq->cq_context;
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_msg_info_req *msg;
        struct rtrs_iu *iu;
        int err;

        WARN_ON(con->c.cid);

        iu = container_of(wc->wr_cqe, struct rtrs_iu, cqe);
        if (unlikely(wc->status != IB_WC_SUCCESS)) {
                rtrs_err(s, "Sess info request receive failed: %s\n",
                          ib_wc_status_msg(wc->status));
                goto close;
        }
        WARN_ON(wc->opcode != IB_WC_RECV);

        if (unlikely(wc->byte_len < sizeof(*msg))) {
                rtrs_err(s, "Sess info request is malformed: size %d\n",
                          wc->byte_len);
                goto close;
        }
        ib_dma_sync_single_for_cpu(sess->s.dev->ib_dev, iu->dma_addr,
                                   iu->size, DMA_FROM_DEVICE);
        msg = iu->buf;
        if (unlikely(le16_to_cpu(msg->type) != RTRS_MSG_INFO_REQ)) {
                rtrs_err(s, "Sess info request is malformed: type %d\n",
                          le16_to_cpu(msg->type));
                goto close;
        }
        err = process_info_req(con, msg);
        if (unlikely(err))
                goto close;

out:
        rtrs_iu_free(iu, DMA_FROM_DEVICE, sess->s.dev->ib_dev, 1);
        return;
close:
        close_sess(sess);
        goto out;
}

static int post_recv_info_req(struct rtrs_srv_con *con)
{
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_iu *rx_iu;
        int err;

        rx_iu = rtrs_iu_alloc(1, sizeof(struct rtrs_msg_info_req),
                               GFP_KERNEL, sess->s.dev->ib_dev,
                               DMA_FROM_DEVICE, rtrs_srv_info_req_done);
        if (unlikely(!rx_iu))
                return -ENOMEM;
        /* Prepare for receiving the info request */
        err = rtrs_iu_post_recv(&con->c, rx_iu);
        if (unlikely(err)) {
                rtrs_err(s, "rtrs_iu_post_recv(), err: %d\n", err);
                rtrs_iu_free(rx_iu, DMA_FROM_DEVICE, sess->s.dev->ib_dev, 1);
                return err;
        }

        return 0;
}

static int post_recv_io(struct rtrs_srv_con *con, size_t q_size)
{
        int i, err;

        for (i = 0; i < q_size; i++) {
                err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
                if (unlikely(err))
                        return err;
        }

        return 0;
}

static int post_recv_sess(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_sess *s = &sess->s;
        size_t q_size;
        int err, cid;

        for (cid = 0; cid < sess->s.con_num; cid++) {
                if (cid == 0)
                        q_size = SERVICE_CON_QUEUE_DEPTH;
                else
                        q_size = srv->queue_depth;

                err = post_recv_io(to_srv_con(sess->s.con[cid]), q_size);
                if (unlikely(err)) {
                        rtrs_err(s, "post_recv_io(), err: %d\n", err);
                        return err;
                }
        }

        return 0;
}

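/*
 * process_read() - hand a client read request over to the upper layer
 * (ctx->ops.rdma_ev); the response data is later RDMA-written back via
 * rtrs_srv_resp_rdma().  On error an IMM-only response carrying the
 * errno is sent instead.
 */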
static void process_read(struct rtrs_srv_con *con,
                         struct rtrs_msg_rdma_read *msg,
                         u32 buf_id, u32 off)
{
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;
        struct rtrs_srv_op *id;

        size_t usr_len, data_len;
        void *data;
        int ret;

        if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
                rtrs_err_rl(s,
                             "Processing read request failed, session is disconnected, sess state %s\n",
                             rtrs_srv_state_str(sess->state));
                return;
        }
        if (unlikely(le16_to_cpu(msg->sg_cnt) != 1 &&
                     le16_to_cpu(msg->sg_cnt) != 0)) {
                rtrs_err_rl(s,
                            "Processing read request failed, invalid message\n");
                return;
        }
        rtrs_srv_get_ops_ids(sess);
        rtrs_srv_update_rdma_stats(sess->stats, off, READ);
        id = sess->ops_ids[buf_id];
        id->con         = con;
        id->dir         = READ;
        id->msg_id      = buf_id;
        id->rd_msg      = msg;
        usr_len = le16_to_cpu(msg->usr_len);
        data_len = off - usr_len;
        data = page_address(srv->chunks[buf_id]);
        ret = ctx->ops.rdma_ev(srv, srv->priv, id, READ, data, data_len,
                           data + data_len, usr_len);

        if (unlikely(ret)) {
                rtrs_err_rl(s,
                             "Processing read request failed, user module cb reported for msg_id %d, err: %d\n",
                             buf_id, ret);
                goto send_err_msg;
        }

        return;

send_err_msg:
        ret = send_io_resp_imm(con, id, ret);
        if (ret < 0) {
                rtrs_err_rl(s,
                             "Sending err msg for failed RDMA-Write-Req failed, msg_id %d, err: %d\n",
                             buf_id, ret);
                close_sess(sess);
        }
        rtrs_srv_put_ops_ids(sess);
}

static void process_write(struct rtrs_srv_con *con,
                          struct rtrs_msg_rdma_write *req,
                          u32 buf_id, u32 off)
{
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_srv *srv = sess->srv;
        struct rtrs_srv_ctx *ctx = srv->ctx;
        struct rtrs_srv_op *id;

        size_t data_len, usr_len;
        void *data;
        int ret;

        if (unlikely(sess->state != RTRS_SRV_CONNECTED)) {
                rtrs_err_rl(s,
                             "Processing write request failed, session is disconnected, sess state %s\n",
                             rtrs_srv_state_str(sess->state));
                return;
        }
        rtrs_srv_get_ops_ids(sess);
        rtrs_srv_update_rdma_stats(sess->stats, off, WRITE);
        id = sess->ops_ids[buf_id];
        id->con    = con;
        id->dir    = WRITE;
        id->msg_id = buf_id;

        usr_len = le16_to_cpu(req->usr_len);
        data_len = off - usr_len;
        data = page_address(srv->chunks[buf_id]);
        ret = ctx->ops.rdma_ev(srv, srv->priv, id, WRITE, data, data_len,
                           data + data_len, usr_len);
        if (unlikely(ret)) {
                rtrs_err_rl(s,
                             "Processing write request failed, user module callback reports err: %d\n",
                             ret);
                goto send_err_msg;
        }

        return;

send_err_msg:
        ret = send_io_resp_imm(con, id, ret);
        if (ret < 0) {
                rtrs_err_rl(s,
                             "Processing write request failed, sending I/O response failed, msg_id %d, err: %d\n",
                             buf_id, ret);
                close_sess(sess);
        }
        rtrs_srv_put_ops_ids(sess);
}

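/*
 * process_io_req() - sync the chunk for CPU access and dispatch the
 * request by its rtrs message type (read or write).
 */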
static void process_io_req(struct rtrs_srv_con *con, void *msg,
                           u32 id, u32 off)
{
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_msg_rdma_hdr *hdr;
        unsigned int type;

        ib_dma_sync_single_for_cpu(sess->s.dev->ib_dev, sess->dma_addr[id],
                                   max_chunk_size, DMA_BIDIRECTIONAL);
        hdr = msg;
        type = le16_to_cpu(hdr->type);

        switch (type) {
        case RTRS_MSG_WRITE:
                process_write(con, msg, id, off);
                break;
        case RTRS_MSG_READ:
                process_read(con, msg, id, off);
                break;
        default:
                rtrs_err(s,
                          "Processing I/O request failed, unknown message type received: 0x%02x\n",
                          type);
                goto err;
        }

        return;

err:
        close_sess(sess);
}

static void rtrs_srv_inv_rkey_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_mr *mr =
                container_of(wc->wr_cqe, typeof(*mr), inv_cqe);
        struct rtrs_srv_con *con = cq->cq_context;
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_srv *srv = sess->srv;
        u32 msg_id, off;
        void *data;

        if (unlikely(wc->status != IB_WC_SUCCESS)) {
                rtrs_err(s, "Failed IB_WR_LOCAL_INV: %s\n",
                          ib_wc_status_msg(wc->status));
                close_sess(sess);
        }
        msg_id = mr->msg_id;
        off = mr->msg_off;
        data = page_address(srv->chunks[msg_id]) + off;
        process_io_req(con, data, msg_id, off);
}

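/*
 * rtrs_srv_inv_rkey() - post a signaled IB_WR_LOCAL_INV for the chunk MR;
 * the request is then processed in rtrs_srv_inv_rkey_done().
 */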
static int rtrs_srv_inv_rkey(struct rtrs_srv_con *con,
                              struct rtrs_srv_mr *mr)
{
        struct ib_send_wr wr = {
                .opcode             = IB_WR_LOCAL_INV,
                .wr_cqe             = &mr->inv_cqe,
                .send_flags         = IB_SEND_SIGNALED,
                .ex.invalidate_rkey = mr->mr->rkey,
        };
        mr->inv_cqe.done = rtrs_srv_inv_rkey_done;

        return ib_post_send(con->c.qp, &wr, NULL);
}

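/*
 * rtrs_rdma_process_wr_wait_list() - retry responses that were queued
 * because the send queue was full in rtrs_srv_resp_rdma().
 */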
static void rtrs_rdma_process_wr_wait_list(struct rtrs_srv_con *con)
{
        spin_lock(&con->rsp_wr_wait_lock);
        while (!list_empty(&con->rsp_wr_wait_list)) {
                struct rtrs_srv_op *id;
                int ret;

                id = list_entry(con->rsp_wr_wait_list.next,
                                struct rtrs_srv_op, wait_list);
                list_del(&id->wait_list);

                spin_unlock(&con->rsp_wr_wait_lock);
                ret = rtrs_srv_resp_rdma(id, id->status);
                spin_lock(&con->rsp_wr_wait_lock);

                if (!ret) {
                        list_add(&id->wait_list, &con->rsp_wr_wait_list);
                        break;
                }
        }
        spin_unlock(&con->rsp_wr_wait_lock);
}

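/*
 * rtrs_srv_rdma_done() - main I/O completion handler: dispatches
 * incoming RDMA-with-IMM completions (I/O requests and heartbeats),
 * replenishes receive buffers and the send-queue budget, and kicks the
 * response wait list on send completions.
 */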
static void rtrs_srv_rdma_done(struct ib_cq *cq, struct ib_wc *wc)
{
        struct rtrs_srv_con *con = cq->cq_context;
        struct rtrs_sess *s = con->c.sess;
        struct rtrs_srv_sess *sess = to_srv_sess(s);
        struct rtrs_srv *srv = sess->srv;
        u32 imm_type, imm_payload;
        int err;

        if (unlikely(wc->status != IB_WC_SUCCESS)) {
                if (wc->status != IB_WC_WR_FLUSH_ERR) {
                        rtrs_err(s,
                                  "%s (wr_cqe: %p, type: %d, vendor_err: 0x%x, len: %u)\n",
                                  ib_wc_status_msg(wc->status), wc->wr_cqe,
                                  wc->opcode, wc->vendor_err, wc->byte_len);
                        close_sess(sess);
                }
                return;
        }

        switch (wc->opcode) {
        case IB_WC_RECV_RDMA_WITH_IMM:
                /*
                 * post_recv() RDMA write completions of IO reqs (read/write)
                 * and hb
                 */
                if (WARN_ON(wc->wr_cqe != &io_comp_cqe))
                        return;
                err = rtrs_post_recv_empty(&con->c, &io_comp_cqe);
                if (unlikely(err)) {
                        rtrs_err(s, "rtrs_post_recv_empty(), err: %d\n", err);
                        close_sess(sess);
                        break;
                }
                rtrs_from_imm(be32_to_cpu(wc->ex.imm_data),
                               &imm_type, &imm_payload);
                if (likely(imm_type == RTRS_IO_REQ_IMM)) {
                        u32 msg_id, off;
                        void *data;

                        msg_id = imm_payload >> sess->mem_bits;
                        off = imm_payload & ((1 << sess->mem_bits) - 1);
                        if (unlikely(msg_id >= srv->queue_depth ||
                                     off >= max_chunk_size)) {
                                rtrs_err(s, "Wrong msg_id %u, off %u\n",
                                          msg_id, off);
                                close_sess(sess);
                                return;
                        }
                        if (always_invalidate) {
                                struct rtrs_srv_mr *mr = &sess->mrs[msg_id];

                                mr->msg_off = off;
                                mr->msg_id = msg_id;
                                err = rtrs_srv_inv_rkey(con, mr);
                                if (unlikely(err)) {
                                        rtrs_err(s, "rtrs_srv_inv_rkey(), err: %d\n",
                                                  err);
                                        close_sess(sess);
                                        break;
                                }
                        } else {
                                data = page_address(srv->chunks[msg_id]) + off;
                                process_io_req(con, data, msg_id, off);
                        }
                } else if (imm_type == RTRS_HB_MSG_IMM) {
                        WARN_ON(con->c.cid);
                        rtrs_send_hb_ack(&sess->s);
                } else if (imm_type == RTRS_HB_ACK_IMM) {
                        WARN_ON(con->c.cid);
                        sess->s.hb_missed_cnt = 0;
                } else {
                        rtrs_wrn(s, "Unknown IMM type %u\n", imm_type);
                }
                break;
        case IB_WC_RDMA_WRITE:
        case IB_WC_SEND:
                /*
                 * post_send() RDMA write completions of IO reqs (read/write)
                 * and hb
                 */
                atomic_add(srv->queue_depth, &con->sq_wr_avail);

                if (unlikely(!list_empty_careful(&con->rsp_wr_wait_list)))
                        rtrs_rdma_process_wr_wait_list(con);

                break;
        default:
                rtrs_wrn(s, "Unexpected WC type: %d\n", wc->opcode);
                return;
        }
}

/**
 * rtrs_srv_get_sess_name() - Get the session name (peer hostname).
 * @srv:        Session
 * @sessname:   Buffer to copy the session name into
 * @len:        Length of the sessname buffer
 *
 * Return: 0 on success, -ENOTCONN if no path is connected.
 */
int rtrs_srv_get_sess_name(struct rtrs_srv *srv, char *sessname, size_t len)
{
        struct rtrs_srv_sess *sess;
        int err = -ENOTCONN;

        mutex_lock(&srv->paths_mutex);
        list_for_each_entry(sess, &srv->paths_list, s.entry) {
                if (sess->state != RTRS_SRV_CONNECTED)
                        continue;
                strlcpy(sessname, sess->s.sessname,
                       min_t(size_t, sizeof(sess->s.sessname), len));
                err = 0;
                break;
        }
        mutex_unlock(&srv->paths_mutex);

        return err;
}
EXPORT_SYMBOL(rtrs_srv_get_sess_name);

/**
 * rtrs_srv_get_queue_depth() - Get rtrs_srv qdepth.
 * @srv:        Session
 */
int rtrs_srv_get_queue_depth(struct rtrs_srv *srv)
{
        return srv->queue_depth;
}
EXPORT_SYMBOL(rtrs_srv_get_queue_depth);

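/*
 * find_next_bit_ring() - pick the next CPU from cq_affinity_mask in a
 * round-robin fashion, wrapping when running past nr_cpu_ids or the
 * device's completion vector count.
 */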
static int find_next_bit_ring(struct rtrs_srv_sess *sess)
{
        struct ib_device *ib_dev = sess->s.dev->ib_dev;
        int v;

        v = cpumask_next(sess->cur_cq_vector, &cq_affinity_mask);
        if (v >= nr_cpu_ids || v >= ib_dev->num_comp_vectors)
                v = cpumask_first(&cq_affinity_mask);
        return v;
}

static int rtrs_srv_get_next_cq_vector(struct rtrs_srv_sess *sess)
{
        sess->cur_cq_vector = find_next_bit_ring(sess);

        return sess->cur_cq_vector;
}

static void rtrs_srv_dev_release(struct device *dev)
{
        struct rtrs_srv *srv = container_of(dev, struct rtrs_srv, dev);

        kfree(srv);
}

static struct rtrs_srv *__alloc_srv(struct rtrs_srv_ctx *ctx,
                                     const uuid_t *paths_uuid)
{
        struct rtrs_srv *srv;
        int i;

        srv = kzalloc(sizeof(*srv), GFP_KERNEL);
        if (!srv)
                return NULL;

        refcount_set(&srv->refcount, 1);
        INIT_LIST_HEAD(&srv->paths_list);
        mutex_init(&srv->paths_mutex);
        mutex_init(&srv->paths_ev_mutex);
        uuid_copy(&srv->paths_uuid, paths_uuid);
        srv->queue_depth = sess_queue_depth;
        srv->ctx = ctx;
        device_initialize(&srv->dev);
        srv->dev.release = rtrs_srv_dev_release;

        srv->chunks = kcalloc(srv->queue_depth, sizeof(*srv->chunks),
                              GFP_KERNEL);
        if (!srv->chunks)
                goto err_free_srv;

        for (i = 0; i < srv->queue_depth; i++) {
                srv->chunks[i] = mempool_alloc(chunk_pool, GFP_KERNEL);
                if (!srv->chunks[i])
                        goto err_free_chunks;
        }
        list_add(&srv->ctx_list, &ctx->srv_list);

        return srv;

err_free_chunks:
        while (i--)
                mempool_free(srv->chunks[i], chunk_pool);
        kfree(srv->chunks);

err_free_srv:
        kfree(srv);

        return NULL;
}

static void free_srv(struct rtrs_srv *srv)
{
        int i;

        WARN_ON(refcount_read(&srv->refcount));
        for (i = 0; i < srv->queue_depth; i++)
                mempool_free(srv->chunks[i], chunk_pool);
        kfree(srv->chunks);
        mutex_destroy(&srv->paths_mutex);
        mutex_destroy(&srv->paths_ev_mutex);
        /* last put to release the srv structure */
        put_device(&srv->dev);
}

static inline struct rtrs_srv *__find_srv_and_get(struct rtrs_srv_ctx *ctx,
                                                   const uuid_t *paths_uuid)
{
        struct rtrs_srv *srv;

        list_for_each_entry(srv, &ctx->srv_list, ctx_list) {
                if (uuid_equal(&srv->paths_uuid, paths_uuid) &&
                    refcount_inc_not_zero(&srv->refcount))
                        return srv;
        }

        return NULL;
}

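/*
 * get_or_create_srv() - look up an rtrs_srv by its paths_uuid (taking a
 * reference) or allocate a fresh one, all under ctx->srv_mutex.
 */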
static struct rtrs_srv *get_or_create_srv(struct rtrs_srv_ctx *ctx,
                                           const uuid_t *paths_uuid)
{
        struct rtrs_srv *srv;

        mutex_lock(&ctx->srv_mutex);
        srv = __find_srv_and_get(ctx, paths_uuid);
        if (!srv)
                srv = __alloc_srv(ctx, paths_uuid);
        mutex_unlock(&ctx->srv_mutex);

        return srv;
}

static void put_srv(struct rtrs_srv *srv)
{
        if (refcount_dec_and_test(&srv->refcount)) {
                struct rtrs_srv_ctx *ctx = srv->ctx;

                WARN_ON(srv->dev.kobj.state_in_sysfs);

                mutex_lock(&ctx->srv_mutex);
                list_del(&srv->ctx_list);
                mutex_unlock(&ctx->srv_mutex);
                free_srv(srv);
        }
}

static void __add_path_to_srv(struct rtrs_srv *srv,
                              struct rtrs_srv_sess *sess)
{
        list_add_tail(&sess->s.entry, &srv->paths_list);
        srv->paths_num++;
        WARN_ON(srv->paths_num >= MAX_PATHS_NUM);
}

static void del_path_from_srv(struct rtrs_srv_sess *sess)
{
        struct rtrs_srv *srv = sess->srv;

        if (WARN_ON(!srv))
                return;

        mutex_lock(&srv->paths_mutex);
        list_del(&sess->s.entry);
        WARN_ON(!srv->paths_num);
        srv->paths_num--;
        mutex_unlock(&srv->paths_mutex);
}

/*
 * Return 0 if both addresses have the same family and address bytes,
 * non-zero if they differ, -ENOENT on an unsupported address family.
 */
static int sockaddr_cmp(const struct sockaddr *a, const struct sockaddr *b)
{
        if (a->sa_family != b->sa_family)
                return 1;

        switch (a->sa_family) {
        case AF_IB:
                return memcmp(&((struct sockaddr_ib *)a)->sib_addr,
                              &((struct sockaddr_ib *)b)->sib_addr,
                              sizeof(struct ib_addr));
        case AF_INET:
                return memcmp(&((struct sockaddr_in *)a)->sin_addr,
                              &((struct sockaddr_in *)b)->sin_addr,
                              sizeof(struct in_addr));
        case AF_INET6:
                return memcmp(&((struct sockaddr_in6 *)a)->sin6_addr,
                              &((struct sockaddr_in6 *)b)->sin6_addr,
                              sizeof(struct in6_addr));
        default:
                return -ENOENT;
        }
}

static bool __is_path_w_addr_exists(struct rtrs_srv *srv,
                                    struct rdma_addr *addr)
{
        struct rtrs_srv_sess *sess;

        list_for_each_entry(sess, &srv->paths_list, s.entry)
                if (!sockaddr_cmp((struct sockaddr *)&sess->s.dst_addr,
                                  (struct sockaddr *)&addr->dst_addr) &&
                    !sockaddr_cmp((struct sockaddr *)&sess->s.src_addr,
                                  (struct sockaddr *)&addr->src_addr))
                        return true;

        return false;
}

static void free_sess(struct rtrs_srv_sess *sess)
{
        if (sess->kobj.state_in_sysfs)
                kobject_put(&sess->kobj);
        else
                kfree(sess);
}

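/*
 * rtrs_srv_close_work() - tear a session down: remove sysfs entries,
 * stop the heartbeat, disconnect and drain all QPs, wait for inflight
 * operations, notify the upper layer, then free all per-session
 * resources.
 */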
1499static void rtrs_srv_close_work(struct work_struct *work)
1500{
1501        struct rtrs_srv_sess *sess;
1502        struct rtrs_srv_con *con;
1503        int i;
1504
1505        sess = container_of(work, typeof(*sess), close_work);
1506
1507        rtrs_srv_destroy_sess_files(sess);
1508        rtrs_srv_stop_hb(sess);
1509
1510        for (i = 0; i < sess->s.con_num; i++) {
1511                if (!sess->s.con[i])
1512                        continue;
1513                con = to_srv_con(sess->s.con[i]);
1514                rdma_disconnect(con->c.cm_id);
1515                ib_drain_qp(con->c.qp);
1516        }
1517        /* Wait for all inflights */
1518        rtrs_srv_wait_ops_ids(sess);
1519
1520        /* Notify upper layer if we are the last path */
1521        rtrs_srv_sess_down(sess);
1522
1523        unmap_cont_bufs(sess);
1524        rtrs_srv_free_ops_ids(sess);
1525
1526        for (i = 0; i < sess->s.con_num; i++) {
1527                if (!sess->s.con[i])
1528                        continue;
1529                con = to_srv_con(sess->s.con[i]);
1530                rtrs_cq_qp_destroy(&con->c);
1531                rdma_destroy_id(con->c.cm_id);
1532                kfree(con);
1533        }
1534        rtrs_ib_dev_put(sess->s.dev);
1535
1536        del_path_from_srv(sess);
1537        put_srv(sess->srv);
1538        sess->srv = NULL;
1539        rtrs_srv_change_state(sess, RTRS_SRV_CLOSED);
1540
1541        kfree(sess->dma_addr);
1542        kfree(sess->s.con);
1543        free_sess(sess);
1544}
1545
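    /*
     * Accept the CM request and piggy-back the connection response
     * (magic, protocol version, queue depth and IO size limits) as CM
     * private data, so the client learns the session parameters without
     * an extra message exchange.
     */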
1546static int rtrs_rdma_do_accept(struct rtrs_srv_sess *sess,
1547                               struct rdma_cm_id *cm_id)
1548{
1549        struct rtrs_srv *srv = sess->srv;
1550        struct rtrs_msg_conn_rsp msg;
1551        struct rdma_conn_param param;
1552        int err;
1553
1554        param = (struct rdma_conn_param) {
1555                .rnr_retry_count = 7,
1556                .private_data = &msg,
1557                .private_data_len = sizeof(msg),
1558        };
1559
1560        msg = (struct rtrs_msg_conn_rsp) {
1561                .magic = cpu_to_le16(RTRS_MAGIC),
1562                .version = cpu_to_le16(RTRS_PROTO_VER),
1563                .queue_depth = cpu_to_le16(srv->queue_depth),
1564                .max_io_size = cpu_to_le32(max_chunk_size - MAX_HDR_SIZE),
1565                .max_hdr_size = cpu_to_le32(MAX_HDR_SIZE),
1566        };
1567
1568        if (always_invalidate)
1569                msg.flags = cpu_to_le32(RTRS_MSG_NEW_RKEY_F);
1570
1571        err = rdma_accept(cm_id, &param);
1572        if (err)
1573                pr_err("rdma_accept(), err: %d\n", err);
1574
1575        return err;
1576}
1577
1578static int rtrs_rdma_do_reject(struct rdma_cm_id *cm_id, int errno)
1579{
1580        struct rtrs_msg_conn_rsp msg;
1581        int err;
1582
1583        msg = (struct rtrs_msg_conn_rsp) {
1584                .magic = cpu_to_le16(RTRS_MAGIC),
1585                .version = cpu_to_le16(RTRS_PROTO_VER),
1586                .errno = cpu_to_le16(errno),
1587        };
1588
1589        err = rdma_reject(cm_id, &msg, sizeof(msg), IB_CM_REJ_CONSUMER_DEFINED);
1590        if (err)
1591                pr_err("rdma_reject(), err: %d\n", err);
1592
1593        /* Bounce errno back */
1594        return errno;
1595}
1596
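    /*
     * Reject the CM request, but still reply with a well-formed
     * rtrs_msg_conn_rsp carrying the errno, so the client can report a
     * meaningful error for the consumer-defined reject.
     */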
1597static struct rtrs_srv_sess *
1598__find_sess(struct rtrs_srv *srv, const uuid_t *sess_uuid)
1599{
1600        struct rtrs_srv_sess *sess;
1601
1602        list_for_each_entry(sess, &srv->paths_list, s.entry) {
1603                if (uuid_equal(&sess->s.uuid, sess_uuid))
1604                        return sess;
1605        }
1606
1607        return NULL;
1608}
1609
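    /*
     * cid 0 is the service connection used for the info/session setup
     * exchange; all other cids are regular IO connections.
     */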
1610static int create_con(struct rtrs_srv_sess *sess,
1611                      struct rdma_cm_id *cm_id,
1612                      unsigned int cid)
1613{
1614        struct rtrs_srv *srv = sess->srv;
1615        struct rtrs_sess *s = &sess->s;
1616        struct rtrs_srv_con *con;
1617
1618        u16 cq_size, wr_queue_size;
1619        int err, cq_vector;
1620
1621        con = kzalloc(sizeof(*con), GFP_KERNEL);
1622        if (!con) {
1623                err = -ENOMEM;
1624                goto err;
1625        }
1626
1627        spin_lock_init(&con->rsp_wr_wait_lock);
1628        INIT_LIST_HEAD(&con->rsp_wr_wait_list);
1629        con->c.cm_id = cm_id;
1630        con->c.sess = &sess->s;
1631        con->c.cid = cid;
1632        atomic_set(&con->wr_cnt, 0);
1633
1634        if (con->c.cid == 0) {
1635                /*
1636                 * All receive and all send (each requiring invalidate)
1637                 * + 2 for drain and heartbeat
1638                 */
1639                wr_queue_size = SERVICE_CON_QUEUE_DEPTH * 3 + 2;
1640                cq_size = wr_queue_size;
1641        } else {
1642                /*
1643                 * Worst case: all receive requests posted, all write
1644                 * requests posted, one extra invalidate request per
1645                 * read request, plus one drain WR for when the QP
1646                 * goes into the error state.
1647                 */
1648                cq_size = srv->queue_depth * 3 + 1;
1649                /*
1650                 * In theory we might have queue_depth * 32 outstanding
1651                 * requests if an unsafe global key is used and each of
1652                 * queue_depth read requests consists of 32 different
1653                 * addresses. Divide by 3 to stay within the mlx5 limit.
1654                 */
1655                wr_queue_size = sess->s.dev->ib_dev->attrs.max_qp_wr / 3;
1656        }
1657        atomic_set(&con->sq_wr_avail, wr_queue_size);
1658        cq_vector = rtrs_srv_get_next_cq_vector(sess);
1659
1660        /* TODO: SOFTIRQ can be faster, but be careful with softirq context */
1661        err = rtrs_cq_qp_create(&sess->s, &con->c, 1, cq_vector, cq_size,
1662                                 wr_queue_size, IB_POLL_WORKQUEUE);
1663        if (err) {
1664                rtrs_err(s, "rtrs_cq_qp_create(), err: %d\n", err);
1665                goto free_con;
1666        }
1667        if (con->c.cid == 0) {
1668                err = post_recv_info_req(con);
1669                if (err)
1670                        goto free_cqqp;
1671        }
1672        WARN_ON(sess->s.con[cid]);
1673        sess->s.con[cid] = &con->c;
1674
1675        /*
1676         * Change context from server to current connection.  The other
1677         * way is to use cm_id->qp->qp_context, which does not work on OFED.
1678         */
1679        cm_id->context = &con->c;
1680
1681        return 0;
1682
1683free_cqqp:
1684        rtrs_cq_qp_destroy(&con->c);
1685free_con:
1686        kfree(con);
1687
1688err:
1689        return err;
1690}
1691
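    /*
     * Caller must hold srv->paths_mutex: paths_num is checked here and
     * the new session is linked into srv->paths_list.
     */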
1692static struct rtrs_srv_sess *__alloc_sess(struct rtrs_srv *srv,
1693                                           struct rdma_cm_id *cm_id,
1694                                           unsigned int con_num,
1695                                           unsigned int recon_cnt,
1696                                           const uuid_t *uuid)
1697{
1698        struct rtrs_srv_sess *sess;
1699        int err = -ENOMEM;
1700
1701        if (srv->paths_num >= MAX_PATHS_NUM) {
1702                err = -ECONNRESET;
1703                goto err;
1704        }
1705        if (__is_path_w_addr_exists(srv, &cm_id->route.addr)) {
1706                err = -EEXIST;
1707                pr_err("Path with same addr exists\n");
1708                goto err;
1709        }
1710        sess = kzalloc(sizeof(*sess), GFP_KERNEL);
1711        if (!sess)
1712                goto err;
1713
1714        sess->stats = kzalloc(sizeof(*sess->stats), GFP_KERNEL);
1715        if (!sess->stats)
1716                goto err_free_sess;
1717
1718        sess->stats->sess = sess;
1719
1720        sess->dma_addr = kcalloc(srv->queue_depth, sizeof(*sess->dma_addr),
1721                                 GFP_KERNEL);
1722        if (!sess->dma_addr)
1723                goto err_free_stats;
1724
1725        sess->s.con = kcalloc(con_num, sizeof(*sess->s.con), GFP_KERNEL);
1726        if (!sess->s.con)
1727                goto err_free_dma_addr;
1728
1729        sess->state = RTRS_SRV_CONNECTING;
1730        sess->srv = srv;
1731        sess->cur_cq_vector = -1;
1732        sess->s.dst_addr = cm_id->route.addr.dst_addr;
1733        sess->s.src_addr = cm_id->route.addr.src_addr;
1734        sess->s.con_num = con_num;
1735        sess->s.recon_cnt = recon_cnt;
1736        uuid_copy(&sess->s.uuid, uuid);
1737        spin_lock_init(&sess->state_lock);
1738        INIT_WORK(&sess->close_work, rtrs_srv_close_work);
1739        rtrs_srv_init_hb(sess);
1740
1741        sess->s.dev = rtrs_ib_dev_find_or_add(cm_id->device, &dev_pd);
1742        if (!sess->s.dev) {
1743                err = -ENOMEM;
1744                goto err_free_con;
1745        }
1746        err = map_cont_bufs(sess);
1747        if (err)
1748                goto err_put_dev;
1749
1750        err = rtrs_srv_alloc_ops_ids(sess);
1751        if (err)
1752                goto err_unmap_bufs;
1753
1754        __add_path_to_srv(srv, sess);
1755
1756        return sess;
1757
1758err_unmap_bufs:
1759        unmap_cont_bufs(sess);
1760err_put_dev:
1761        rtrs_ib_dev_put(sess->s.dev);
1762err_free_con:
1763        kfree(sess->s.con);
1764err_free_dma_addr:
1765        kfree(sess->dma_addr);
1766err_free_stats:
1767        kfree(sess->stats);
1768err_free_sess:
1769        kfree(sess);
1770err:
1771        return ERR_PTR(err);
1772}
1773
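    /*
     * Handle an incoming RDMA_CM_EVENT_CONNECT_REQUEST: validate the
     * rtrs_msg_conn_req carried as CM private data, find or create the
     * server and session matching the UUIDs from the client, create the
     * per-cid connection and accept or reject the CM request.
     */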
1774static int rtrs_rdma_connect(struct rdma_cm_id *cm_id,
1775                              const struct rtrs_msg_conn_req *msg,
1776                              size_t len)
1777{
1778        struct rtrs_srv_ctx *ctx = cm_id->context;
1779        struct rtrs_srv_sess *sess;
1780        struct rtrs_srv *srv;
1781
1782        u16 version, con_num, cid;
1783        u16 recon_cnt;
1784        int err;
1785
1786        if (len < sizeof(*msg)) {
1787                pr_err("Invalid RTRS connection request\n");
1788                goto reject_w_econnreset;
1789        }
1790        if (le16_to_cpu(msg->magic) != RTRS_MAGIC) {
1791                pr_err("Invalid RTRS magic\n");
1792                goto reject_w_econnreset;
1793        }
1794        version = le16_to_cpu(msg->version);
1795        if (version >> 8 != RTRS_PROTO_VER_MAJOR) {
1796                pr_err("Unsupported major RTRS version: %d, expected %d\n",
1797                       version >> 8, RTRS_PROTO_VER_MAJOR);
1798                goto reject_w_econnreset;
1799        }
1800        con_num = le16_to_cpu(msg->cid_num);
1801        if (con_num > 4096) {
1802                /* Sanity check */
1803                pr_err("Too many connections requested: %d\n", con_num);
1804                goto reject_w_econnreset;
1805        }
1806        cid = le16_to_cpu(msg->cid);
1807        if (cid >= con_num) {
1808                /* Sanity check */
1809                pr_err("Incorrect cid: %d >= %d\n", cid, con_num);
1810                goto reject_w_econnreset;
1811        }
1812        recon_cnt = le16_to_cpu(msg->recon_cnt);
1813        srv = get_or_create_srv(ctx, &msg->paths_uuid);
1814        if (!srv) {
1815                err = -ENOMEM;
1816                goto reject_w_err;
1817        }
1818        mutex_lock(&srv->paths_mutex);
1819        sess = __find_sess(srv, &msg->sess_uuid);
1820        if (sess) {
1821                struct rtrs_sess *s = &sess->s;
1822
1823                /* Session already holds a reference */
1824                put_srv(srv);
1825
1826                if (sess->state != RTRS_SRV_CONNECTING) {
1827                        rtrs_err(s, "Session in wrong state: %s\n",
1828                                  rtrs_srv_state_str(sess->state));
1829                        mutex_unlock(&srv->paths_mutex);
1830                        goto reject_w_econnreset;
1831                }
1832                /*
1833                 * Sanity checks
1834                 */
1835                if (con_num != s->con_num || cid >= s->con_num) {
1836                        rtrs_err(s, "Incorrect request: %d, %d\n",
1837                                  cid, con_num);
1838                        mutex_unlock(&srv->paths_mutex);
1839                        goto reject_w_econnreset;
1840                }
1841                if (s->con[cid]) {
1842                        rtrs_err(s, "Connection already exists: %d\n",
1843                                  cid);
1844                        mutex_unlock(&srv->paths_mutex);
1845                        goto reject_w_econnreset;
1846                }
1847        } else {
1848                sess = __alloc_sess(srv, cm_id, con_num, recon_cnt,
1849                                    &msg->sess_uuid);
1850                if (IS_ERR(sess)) {
1851                        mutex_unlock(&srv->paths_mutex);
1852                        put_srv(srv);
1853                        err = PTR_ERR(sess);
1854                        goto reject_w_err;
1855                }
1856        }
1857        err = create_con(sess, cm_id, cid);
1858        if (err) {
1859                (void)rtrs_rdma_do_reject(cm_id, err);
1860                /*
1861                 * Since the session may have other connections, close it
1862                 * the normal way through the workqueue, but still return
1863                 * an error so cma.c calls rdma_destroy_id() for this id.
1864                 */
1865                goto close_and_return_err;
1866        }
1867        err = rtrs_rdma_do_accept(sess, cm_id);
1868        if (err) {
1869                (void)rtrs_rdma_do_reject(cm_id, err);
1870                /*
1871                 * Since the current connection was successfully added to
1872                 * the session, close the session the normal way through the
1873                 * workqueue and return 0 to tell cma.c that we will call
1874                 * rdma_destroy_id() ourselves.
1875                 */
1876                err = 0;
1877                goto close_and_return_err;
1878        }
1879        mutex_unlock(&srv->paths_mutex);
1880
1881        return 0;
1882
1883reject_w_err:
1884        return rtrs_rdma_do_reject(cm_id, err);
1885
1886reject_w_econnreset:
1887        return rtrs_rdma_do_reject(cm_id, -ECONNRESET);
1888
1889close_and_return_err:
1890        close_sess(sess);
1891        mutex_unlock(&srv->paths_mutex);
1892
1893        return err;
1894}
1895
1896static int rtrs_srv_rdma_cm_handler(struct rdma_cm_id *cm_id,
1897                                     struct rdma_cm_event *ev)
1898{
1899        struct rtrs_srv_sess *sess = NULL;
1900        struct rtrs_sess *s = NULL;
1901
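            /*
             * For CONNECT_REQUEST cm_id->context is still the server
             * context set in rtrs_srv_cm_init(); for all later events
             * create_con() has already pointed it at the connection.
             */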
1902        if (ev->event != RDMA_CM_EVENT_CONNECT_REQUEST) {
1903                struct rtrs_con *c = cm_id->context;
1904
1905                s = c->sess;
1906                sess = to_srv_sess(s);
1907        }
1908
1909        switch (ev->event) {
1910        case RDMA_CM_EVENT_CONNECT_REQUEST:
1911                /*
1912                 * In case of error cma.c will destroy cm_id,
1913                 * see cma_process_remove()
1914                 */
1915                return rtrs_rdma_connect(cm_id, ev->param.conn.private_data,
1916                                          ev->param.conn.private_data_len);
1917        case RDMA_CM_EVENT_ESTABLISHED:
1918                /* Nothing here */
1919                break;
1920        case RDMA_CM_EVENT_REJECTED:
1921        case RDMA_CM_EVENT_CONNECT_ERROR:
1922        case RDMA_CM_EVENT_UNREACHABLE:
1923                rtrs_err(s, "CM error (CM event: %s, err: %d)\n",
1924                          rdma_event_msg(ev->event), ev->status);
1925                close_sess(sess);
1926                break;
1927        case RDMA_CM_EVENT_DISCONNECTED:
1928        case RDMA_CM_EVENT_ADDR_CHANGE:
1929        case RDMA_CM_EVENT_TIMEWAIT_EXIT:
1930                close_sess(sess);
1931                break;
1932        case RDMA_CM_EVENT_DEVICE_REMOVAL:
1933                close_sess(sess);
1934                break;
1935        default:
1936                pr_err("Ignoring unexpected CM event %s, err %d\n",
1937                       rdma_event_msg(ev->event), ev->status);
1938                break;
1939        }
1940
1941        return 0;
1942}
1943
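    /*
     * Create a cm_id bound to @addr in port space @ps and start
     * listening on it; the server context is stored as the cm_id
     * context so incoming connect requests can find it.
     */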
1944static struct rdma_cm_id *rtrs_srv_cm_init(struct rtrs_srv_ctx *ctx,
1945                                            struct sockaddr *addr,
1946                                            enum rdma_ucm_port_space ps)
1947{
1948        struct rdma_cm_id *cm_id;
1949        int ret;
1950
1951        cm_id = rdma_create_id(&init_net, rtrs_srv_rdma_cm_handler,
1952                               ctx, ps, IB_QPT_RC);
1953        if (IS_ERR(cm_id)) {
1954                ret = PTR_ERR(cm_id);
1955                pr_err("Creating id for RDMA connection failed, err: %d\n",
1956                       ret);
1957                goto err_out;
1958        }
1959        ret = rdma_bind_addr(cm_id, addr);
1960        if (ret) {
1961                pr_err("Binding RDMA address failed, err: %d\n", ret);
1962                goto err_cm;
1963        }
1964        ret = rdma_listen(cm_id, 64);
1965        if (ret) {
1966                pr_err("Listening on RDMA connection failed, err: %d\n",
1967                       ret);
1968                goto err_cm;
1969        }
1970
1971        return cm_id;
1972
1973err_cm:
1974        rdma_destroy_id(cm_id);
1975err_out:
1976
1977        return ERR_PTR(ret);
1978}
1979
1980static int rtrs_srv_rdma_init(struct rtrs_srv_ctx *ctx, u16 port)
1981{
1982        struct sockaddr_in6 sin = {
1983                .sin6_family    = AF_INET6,
1984                .sin6_addr      = IN6ADDR_ANY_INIT,
1985                .sin6_port      = htons(port),
1986        };
1987        struct sockaddr_ib sib = {
1988                .sib_family                     = AF_IB,
1989                .sib_sid        = cpu_to_be64(RDMA_IB_IP_PS_IB | port),
1990                .sib_sid_mask   = cpu_to_be64(0xffffffffffffffffULL),
1991                .sib_pkey       = cpu_to_be16(0xffff),
1992        };
1993        struct rdma_cm_id *cm_ip, *cm_ib;
1994        int ret;
1995
1996        /*
1997         * We accept both IPoIB and IB connections, so we need to keep
1998         * two cm id's, one for each socket type and port space.
1999         * If the cm initialization of one of the id's fails, we abort
2000         * everything.
2001         */
2002        cm_ip = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sin, RDMA_PS_TCP);
2003        if (IS_ERR(cm_ip))
2004                return PTR_ERR(cm_ip);
2005
2006        cm_ib = rtrs_srv_cm_init(ctx, (struct sockaddr *)&sib, RDMA_PS_IB);
2007        if (IS_ERR(cm_ib)) {
2008                ret = PTR_ERR(cm_ib);
2009                goto free_cm_ip;
2010        }
2011
2012        ctx->cm_id_ip = cm_ip;
2013        ctx->cm_id_ib = cm_ib;
2014
2015        return 0;
2016
2017free_cm_ip:
2018        rdma_destroy_id(cm_ip);
2019
2020        return ret;
2021}
2022
2023static struct rtrs_srv_ctx *alloc_srv_ctx(struct rtrs_srv_ops *ops)
2024{
2025        struct rtrs_srv_ctx *ctx;
2026
2027        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
2028        if (!ctx)
2029                return NULL;
2030
2031        ctx->ops = *ops;
2032        mutex_init(&ctx->srv_mutex);
2033        INIT_LIST_HEAD(&ctx->srv_list);
2034
2035        return ctx;
2036}
2037
2038static void free_srv_ctx(struct rtrs_srv_ctx *ctx)
2039{
2040        WARN_ON(!list_empty(&ctx->srv_list));
2041        mutex_destroy(&ctx->srv_mutex);
2042        kfree(ctx);
2043}
2044
2045/**
2046 * rtrs_srv_open() - open RTRS server context
2047 * @ops:                callback functions
2048 * @port:               port to listen on
2049 *
2050 * Creates server context with specified callbacks.
2051 *
2052 * Return: a valid pointer on success, an ERR_PTR() otherwise.
2053 */
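    /*
     * A minimal caller sketch (my_rdma_ev/my_link_ev are hypothetical
     * names; the ops fields come from struct rtrs_srv_ops in rtrs.h):
     *
     *    static struct rtrs_srv_ops ops = {
     *            .rdma_ev = my_rdma_ev,
     *            .link_ev = my_link_ev,
     *    };
     *    struct rtrs_srv_ctx *ctx = rtrs_srv_open(&ops, port);
     *
     *    if (IS_ERR(ctx))
     *            return PTR_ERR(ctx);
     */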
2054struct rtrs_srv_ctx *rtrs_srv_open(struct rtrs_srv_ops *ops, u16 port)
2055{
2056        struct rtrs_srv_ctx *ctx;
2057        int err;
2058
2059        ctx = alloc_srv_ctx(ops);
2060        if (!ctx)
2061                return ERR_PTR(-ENOMEM);
2062
2063        err = rtrs_srv_rdma_init(ctx, port);
2064        if (err) {
2065                free_srv_ctx(ctx);
2066                return ERR_PTR(err);
2067        }
2068
2069        return ctx;
2070}
2071EXPORT_SYMBOL(rtrs_srv_open);
2072
2073static void close_sessions(struct rtrs_srv *srv)
2074{
2075        struct rtrs_srv_sess *sess;
2076
2077        mutex_lock(&srv->paths_mutex);
2078        list_for_each_entry(sess, &srv->paths_list, s.entry)
2079                close_sess(sess);
2080        mutex_unlock(&srv->paths_mutex);
2081}
2082
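    /*
     * close_sess() only schedules close_work, so flush rtrs_wq
     * afterwards to wait until every path has actually been torn down.
     */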
2083static void close_ctx(struct rtrs_srv_ctx *ctx)
2084{
2085        struct rtrs_srv *srv;
2086
2087        mutex_lock(&ctx->srv_mutex);
2088        list_for_each_entry(srv, &ctx->srv_list, ctx_list)
2089                close_sessions(srv);
2090        mutex_unlock(&ctx->srv_mutex);
2091        flush_workqueue(rtrs_wq);
2092}
2093
2094/**
2095 * rtrs_srv_close() - close RTRS server context
2096 * @ctx: pointer to server context
2097 *
2098 * Closes RTRS server context with all client sessions.
2099 */
2100void rtrs_srv_close(struct rtrs_srv_ctx *ctx)
2101{
2102        rdma_destroy_id(ctx->cm_id_ip);
2103        rdma_destroy_id(ctx->cm_id_ib);
2104        close_ctx(ctx);
2105        free_srv_ctx(ctx);
2106}
2107EXPORT_SYMBOL(rtrs_srv_close);
2108
2109static int check_module_params(void)
2110{
2111        if (sess_queue_depth < 1 || sess_queue_depth > MAX_SESS_QUEUE_DEPTH) {
2112                pr_err("Invalid sess_queue_depth value %d, has to be >= %d, <= %d.\n",
2113                       sess_queue_depth, 1, MAX_SESS_QUEUE_DEPTH);
2114                return -EINVAL;
2115        }
2116        if (max_chunk_size < 4096 || !is_power_of_2(max_chunk_size)) {
2117                pr_err("Invalid max_chunk_size value %d, has to be >= %d and a power of two.\n",
2118                       max_chunk_size, 4096);
2119                return -EINVAL;
2120        }
2121
2122        /*
2123         * Check if IB immediate data size is enough to hold the mem_id and the
2124         * offset inside the memory chunk
2125         */
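            /*
             * With the defaults (sess_queue_depth 512, max_chunk_size
             * 128K) that is ilog2(511) + 1 + ilog2(131071) + 1 =
             * 9 + 17 = 26 bits of immediate payload.
             */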
2126        if ((ilog2(sess_queue_depth - 1) + 1) +
2127            (ilog2(max_chunk_size - 1) + 1) > MAX_IMM_PAYL_BITS) {
2128                pr_err("RDMA immediate size (%db) not enough to encode %d buffers of size %dB. Reduce 'sess_queue_depth' or 'max_chunk_size' parameters.\n",
2129                       MAX_IMM_PAYL_BITS, sess_queue_depth, max_chunk_size);
2130                return -EINVAL;
2131        }
2132
2133        return 0;
2134}
2135
2136static int __init rtrs_server_init(void)
2137{
2138        int err;
2139
2140        pr_info("Loading module %s, proto %s: (max_chunk_size: %d (pure IO %ld, headers %ld), sess_queue_depth: %d, always_invalidate: %d)\n",
2141                KBUILD_MODNAME, RTRS_PROTO_VER_STRING,
2142                max_chunk_size, max_chunk_size - MAX_HDR_SIZE, MAX_HDR_SIZE,
2143                sess_queue_depth, always_invalidate);
2144
2145        rtrs_rdma_dev_pd_init(0, &dev_pd);
2146
2147        err = check_module_params();
2148        if (err) {
2149                pr_err("Failed to load module, invalid module parameters, err: %d\n",
2150                       err);
2151                return err;
2152        }
2153        chunk_pool = mempool_create_page_pool(sess_queue_depth * CHUNK_POOL_SZ,
2154                                              get_order(max_chunk_size));
2155        if (!chunk_pool)
2156                return -ENOMEM;
2157        rtrs_dev_class = class_create(THIS_MODULE, "rtrs-server");
2158        if (IS_ERR(rtrs_dev_class)) {
2159                err = PTR_ERR(rtrs_dev_class);
2160                goto out_chunk_pool;
2161        }
2162        rtrs_wq = alloc_workqueue("rtrs_server_wq", 0, 0);
2163        if (!rtrs_wq) {
2164                err = -ENOMEM;
2165                goto out_dev_class;
2166        }
2167
2168        return 0;
2169
2170out_dev_class:
2171        class_destroy(rtrs_dev_class);
2172out_chunk_pool:
2173        mempool_destroy(chunk_pool);
2174
2175        return err;
2176}
2177
2178static void __exit rtrs_server_exit(void)
2179{
2180        destroy_workqueue(rtrs_wq);
2181        class_destroy(rtrs_dev_class);
2182        mempool_destroy(chunk_pool);
2183        rtrs_rdma_dev_pd_deinit(&dev_pd);
2184}
2185
2186module_init(rtrs_server_init);
2187module_exit(rtrs_server_exit);
2188