linux/net/tls/tls_main.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
   3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
   4 *
   5 * This software is available to you under a choice of one of two
   6 * licenses.  You may choose to be licensed under the terms of the GNU
   7 * General Public License (GPL) Version 2, available from the file
   8 * COPYING in the main directory of this source tree, or the
   9 * OpenIB.org BSD license below:
  10 *
  11 *     Redistribution and use in source and binary forms, with or
  12 *     without modification, are permitted provided that the following
  13 *     conditions are met:
  14 *
  15 *      - Redistributions of source code must retain the above
  16 *        copyright notice, this list of conditions and the following
  17 *        disclaimer.
  18 *
  19 *      - Redistributions in binary form must reproduce the above
  20 *        copyright notice, this list of conditions and the following
  21 *        disclaimer in the documentation and/or other materials
  22 *        provided with the distribution.
  23 *
  24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  31 * SOFTWARE.
  32 */
  33
  34#include <linux/module.h>
  35
  36#include <net/tcp.h>
  37#include <net/inet_common.h>
  38#include <linux/highmem.h>
  39#include <linux/netdevice.h>
  40#include <linux/sched/signal.h>
  41#include <linux/inetdevice.h>
  42#include <linux/inet_diag.h>
  43
  44#include <net/snmp.h>
  45#include <net/tls.h>
  46#include <net/tls_toe.h>
  47
  48MODULE_AUTHOR("Mellanox Technologies");
  49MODULE_DESCRIPTION("Transport Layer Security Support");
  50MODULE_LICENSE("Dual BSD/GPL");
  51MODULE_ALIAS_TCP_ULP("tls");
  52
  53enum {
  54        TLSV4,
  55        TLSV6,
  56        TLS_NUM_PROTS,
  57};
  58
  59static const struct proto *saved_tcpv6_prot;
  60static DEFINE_MUTEX(tcpv6_prot_mutex);
  61static const struct proto *saved_tcpv4_prot;
  62static DEFINE_MUTEX(tcpv4_prot_mutex);
  63static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
  64static struct proto_ops tls_sw_proto_ops;
  65static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
  66                         const struct proto *base);
  67
  68void update_sk_prot(struct sock *sk, struct tls_context *ctx)
  69{
  70        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
  71
  72        WRITE_ONCE(sk->sk_prot,
  73                   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
  74}
  75
  76int wait_on_pending_writer(struct sock *sk, long *timeo)
  77{
  78        int rc = 0;
  79        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  80
  81        add_wait_queue(sk_sleep(sk), &wait);
  82        while (1) {
  83                if (!*timeo) {
  84                        rc = -EAGAIN;
  85                        break;
  86                }
  87
  88                if (signal_pending(current)) {
  89                        rc = sock_intr_errno(*timeo);
  90                        break;
  91                }
  92
  93                if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
  94                        break;
  95        }
  96        remove_wait_queue(sk_sleep(sk), &wait);
  97        return rc;
  98}
  99
 100int tls_push_sg(struct sock *sk,
 101                struct tls_context *ctx,
 102                struct scatterlist *sg,
 103                u16 first_offset,
 104                int flags)
 105{
 106        int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
 107        int ret = 0;
 108        struct page *p;
 109        size_t size;
 110        int offset = first_offset;
 111
 112        size = sg->length - offset;
 113        offset += sg->offset;
 114
 115        ctx->in_tcp_sendpages = true;
 116        while (1) {
 117                if (sg_is_last(sg))
 118                        sendpage_flags = flags;
 119
 120                /* is sending application-limited? */
 121                tcp_rate_check_app_limited(sk);
 122                p = sg_page(sg);
 123retry:
 124                ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
 125
 126                if (ret != size) {
 127                        if (ret > 0) {
 128                                offset += ret;
 129                                size -= ret;
 130                                goto retry;
 131                        }
 132
 133                        offset -= sg->offset;
 134                        ctx->partially_sent_offset = offset;
 135                        ctx->partially_sent_record = (void *)sg;
 136                        ctx->in_tcp_sendpages = false;
 137                        return ret;
 138                }
 139
 140                put_page(p);
 141                sk_mem_uncharge(sk, sg->length);
 142                sg = sg_next(sg);
 143                if (!sg)
 144                        break;
 145
 146                offset = sg->offset;
 147                size = sg->length;
 148        }
 149
 150        ctx->in_tcp_sendpages = false;
 151
 152        return 0;
 153}
 154
 155static int tls_handle_open_record(struct sock *sk, int flags)
 156{
 157        struct tls_context *ctx = tls_get_ctx(sk);
 158
 159        if (tls_is_pending_open_record(ctx))
 160                return ctx->push_pending_record(sk, flags);
 161
 162        return 0;
 163}
 164
 165int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 166                      unsigned char *record_type)
 167{
 168        struct cmsghdr *cmsg;
 169        int rc = -EINVAL;
 170
 171        for_each_cmsghdr(cmsg, msg) {
 172                if (!CMSG_OK(msg, cmsg))
 173                        return -EINVAL;
 174                if (cmsg->cmsg_level != SOL_TLS)
 175                        continue;
 176
 177                switch (cmsg->cmsg_type) {
 178                case TLS_SET_RECORD_TYPE:
 179                        if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
 180                                return -EINVAL;
 181
 182                        if (msg->msg_flags & MSG_MORE)
 183                                return -EINVAL;
 184
 185                        rc = tls_handle_open_record(sk, msg->msg_flags);
 186                        if (rc)
 187                                return rc;
 188
 189                        *record_type = *(unsigned char *)CMSG_DATA(cmsg);
 190                        rc = 0;
 191                        break;
 192                default:
 193                        return -EINVAL;
 194                }
 195        }
 196
 197        return rc;
 198}
 199
 200int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 201                            int flags)
 202{
 203        struct scatterlist *sg;
 204        u16 offset;
 205
 206        sg = ctx->partially_sent_record;
 207        offset = ctx->partially_sent_offset;
 208
 209        ctx->partially_sent_record = NULL;
 210        return tls_push_sg(sk, ctx, sg, offset, flags);
 211}
 212
 213void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
 214{
 215        struct scatterlist *sg;
 216
 217        for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
 218                put_page(sg_page(sg));
 219                sk_mem_uncharge(sk, sg->length);
 220        }
 221        ctx->partially_sent_record = NULL;
 222}
 223
 224static void tls_write_space(struct sock *sk)
 225{
 226        struct tls_context *ctx = tls_get_ctx(sk);
 227
 228        /* If in_tcp_sendpages call lower protocol write space handler
 229         * to ensure we wake up any waiting operations there. For example
 230         * if do_tcp_sendpages where to call sk_wait_event.
 231         */
 232        if (ctx->in_tcp_sendpages) {
 233                ctx->sk_write_space(sk);
 234                return;
 235        }
 236
 237#ifdef CONFIG_TLS_DEVICE
 238        if (ctx->tx_conf == TLS_HW)
 239                tls_device_write_space(sk, ctx);
 240        else
 241#endif
 242                tls_sw_write_space(sk, ctx);
 243
 244        ctx->sk_write_space(sk);
 245}
 246
 247/**
 248 * tls_ctx_free() - free TLS ULP context
 249 * @sk:  socket to with @ctx is attached
 250 * @ctx: TLS context structure
 251 *
 252 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 253 * to which @ctx was attached has no outstanding references.
 254 */
 255void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
 256{
 257        if (!ctx)
 258                return;
 259
 260        memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
 261        memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
 262        mutex_destroy(&ctx->tx_lock);
 263
 264        if (sk)
 265                kfree_rcu(ctx, rcu);
 266        else
 267                kfree(ctx);
 268}
 269
 270static void tls_sk_proto_cleanup(struct sock *sk,
 271                                 struct tls_context *ctx, long timeo)
 272{
 273        if (unlikely(sk->sk_write_pending) &&
 274            !wait_on_pending_writer(sk, &timeo))
 275                tls_handle_open_record(sk, 0);
 276
 277        /* We need these for tls_sw_fallback handling of other packets */
 278        if (ctx->tx_conf == TLS_SW) {
 279                kfree(ctx->tx.rec_seq);
 280                kfree(ctx->tx.iv);
 281                tls_sw_release_resources_tx(sk);
 282                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 283        } else if (ctx->tx_conf == TLS_HW) {
 284                tls_device_free_resources_tx(sk);
 285                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
 286        }
 287
 288        if (ctx->rx_conf == TLS_SW) {
 289                tls_sw_release_resources_rx(sk);
 290                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
 291        } else if (ctx->rx_conf == TLS_HW) {
 292                tls_device_offload_cleanup_rx(sk);
 293                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
 294        }
 295}
 296
 297static void tls_sk_proto_close(struct sock *sk, long timeout)
 298{
 299        struct inet_connection_sock *icsk = inet_csk(sk);
 300        struct tls_context *ctx = tls_get_ctx(sk);
 301        long timeo = sock_sndtimeo(sk, 0);
 302        bool free_ctx;
 303
 304        if (ctx->tx_conf == TLS_SW)
 305                tls_sw_cancel_work_tx(ctx);
 306
 307        lock_sock(sk);
 308        free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
 309
 310        if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
 311                tls_sk_proto_cleanup(sk, ctx, timeo);
 312
 313        write_lock_bh(&sk->sk_callback_lock);
 314        if (free_ctx)
 315                rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
 316        WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
 317        if (sk->sk_write_space == tls_write_space)
 318                sk->sk_write_space = ctx->sk_write_space;
 319        write_unlock_bh(&sk->sk_callback_lock);
 320        release_sock(sk);
 321        if (ctx->tx_conf == TLS_SW)
 322                tls_sw_free_ctx_tx(ctx);
 323        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
 324                tls_sw_strparser_done(ctx);
 325        if (ctx->rx_conf == TLS_SW)
 326                tls_sw_free_ctx_rx(ctx);
 327        ctx->sk_proto->close(sk, timeout);
 328
 329        if (free_ctx)
 330                tls_ctx_free(sk, ctx);
 331}
 332
 333static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
 334                                int __user *optlen)
 335{
 336        int rc = 0;
 337        struct tls_context *ctx = tls_get_ctx(sk);
 338        struct tls_crypto_info *crypto_info;
 339        int len;
 340
 341        if (get_user(len, optlen))
 342                return -EFAULT;
 343
 344        if (!optval || (len < sizeof(*crypto_info))) {
 345                rc = -EINVAL;
 346                goto out;
 347        }
 348
 349        if (!ctx) {
 350                rc = -EBUSY;
 351                goto out;
 352        }
 353
 354        /* get user crypto info */
 355        crypto_info = &ctx->crypto_send.info;
 356
 357        if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
 358                rc = -EBUSY;
 359                goto out;
 360        }
 361
 362        if (len == sizeof(*crypto_info)) {
 363                if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
 364                        rc = -EFAULT;
 365                goto out;
 366        }
 367
 368        switch (crypto_info->cipher_type) {
 369        case TLS_CIPHER_AES_GCM_128: {
 370                struct tls12_crypto_info_aes_gcm_128 *
 371                  crypto_info_aes_gcm_128 =
 372                  container_of(crypto_info,
 373                               struct tls12_crypto_info_aes_gcm_128,
 374                               info);
 375
 376                if (len != sizeof(*crypto_info_aes_gcm_128)) {
 377                        rc = -EINVAL;
 378                        goto out;
 379                }
 380                lock_sock(sk);
 381                memcpy(crypto_info_aes_gcm_128->iv,
 382                       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 383                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
 384                memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
 385                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 386                release_sock(sk);
 387                if (copy_to_user(optval,
 388                                 crypto_info_aes_gcm_128,
 389                                 sizeof(*crypto_info_aes_gcm_128)))
 390                        rc = -EFAULT;
 391                break;
 392        }
 393        case TLS_CIPHER_AES_GCM_256: {
 394                struct tls12_crypto_info_aes_gcm_256 *
 395                  crypto_info_aes_gcm_256 =
 396                  container_of(crypto_info,
 397                               struct tls12_crypto_info_aes_gcm_256,
 398                               info);
 399
 400                if (len != sizeof(*crypto_info_aes_gcm_256)) {
 401                        rc = -EINVAL;
 402                        goto out;
 403                }
 404                lock_sock(sk);
 405                memcpy(crypto_info_aes_gcm_256->iv,
 406                       ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
 407                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
 408                memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
 409                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
 410                release_sock(sk);
 411                if (copy_to_user(optval,
 412                                 crypto_info_aes_gcm_256,
 413                                 sizeof(*crypto_info_aes_gcm_256)))
 414                        rc = -EFAULT;
 415                break;
 416        }
 417        default:
 418                rc = -EINVAL;
 419        }
 420
 421out:
 422        return rc;
 423}
 424
 425static int do_tls_getsockopt(struct sock *sk, int optname,
 426                             char __user *optval, int __user *optlen)
 427{
 428        int rc = 0;
 429
 430        switch (optname) {
 431        case TLS_TX:
 432                rc = do_tls_getsockopt_tx(sk, optval, optlen);
 433                break;
 434        default:
 435                rc = -ENOPROTOOPT;
 436                break;
 437        }
 438        return rc;
 439}
 440
 441static int tls_getsockopt(struct sock *sk, int level, int optname,
 442                          char __user *optval, int __user *optlen)
 443{
 444        struct tls_context *ctx = tls_get_ctx(sk);
 445
 446        if (level != SOL_TLS)
 447                return ctx->sk_proto->getsockopt(sk, level,
 448                                                 optname, optval, optlen);
 449
 450        return do_tls_getsockopt(sk, optname, optval, optlen);
 451}
 452
 453static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 454                                  unsigned int optlen, int tx)
 455{
 456        struct tls_crypto_info *crypto_info;
 457        struct tls_crypto_info *alt_crypto_info;
 458        struct tls_context *ctx = tls_get_ctx(sk);
 459        size_t optsize;
 460        int rc = 0;
 461        int conf;
 462
 463        if (!optval || (optlen < sizeof(*crypto_info))) {
 464                rc = -EINVAL;
 465                goto out;
 466        }
 467
 468        if (tx) {
 469                crypto_info = &ctx->crypto_send.info;
 470                alt_crypto_info = &ctx->crypto_recv.info;
 471        } else {
 472                crypto_info = &ctx->crypto_recv.info;
 473                alt_crypto_info = &ctx->crypto_send.info;
 474        }
 475
 476        /* Currently we don't support set crypto info more than one time */
 477        if (TLS_CRYPTO_INFO_READY(crypto_info)) {
 478                rc = -EBUSY;
 479                goto out;
 480        }
 481
 482        rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
 483        if (rc) {
 484                rc = -EFAULT;
 485                goto err_crypto_info;
 486        }
 487
 488        /* check version */
 489        if (crypto_info->version != TLS_1_2_VERSION &&
 490            crypto_info->version != TLS_1_3_VERSION) {
 491                rc = -EINVAL;
 492                goto err_crypto_info;
 493        }
 494
 495        /* Ensure that TLS version and ciphers are same in both directions */
 496        if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
 497                if (alt_crypto_info->version != crypto_info->version ||
 498                    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
 499                        rc = -EINVAL;
 500                        goto err_crypto_info;
 501                }
 502        }
 503
 504        switch (crypto_info->cipher_type) {
 505        case TLS_CIPHER_AES_GCM_128:
 506                optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
 507                break;
 508        case TLS_CIPHER_AES_GCM_256: {
 509                optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
 510                break;
 511        }
 512        case TLS_CIPHER_AES_CCM_128:
 513                optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
 514                break;
 515        default:
 516                rc = -EINVAL;
 517                goto err_crypto_info;
 518        }
 519
 520        if (optlen != optsize) {
 521                rc = -EINVAL;
 522                goto err_crypto_info;
 523        }
 524
 525        rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
 526                            optlen - sizeof(*crypto_info));
 527        if (rc) {
 528                rc = -EFAULT;
 529                goto err_crypto_info;
 530        }
 531
 532        if (tx) {
 533                rc = tls_set_device_offload(sk, ctx);
 534                conf = TLS_HW;
 535                if (!rc) {
 536                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
 537                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
 538                } else {
 539                        rc = tls_set_sw_offload(sk, ctx, 1);
 540                        if (rc)
 541                                goto err_crypto_info;
 542                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
 543                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 544                        conf = TLS_SW;
 545                }
 546        } else {
 547                rc = tls_set_device_offload_rx(sk, ctx);
 548                conf = TLS_HW;
 549                if (!rc) {
 550                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
 551                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
 552                } else {
 553                        rc = tls_set_sw_offload(sk, ctx, 0);
 554                        if (rc)
 555                                goto err_crypto_info;
 556                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
 557                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
 558                        conf = TLS_SW;
 559                }
 560                tls_sw_strparser_arm(sk, ctx);
 561        }
 562
 563        if (tx)
 564                ctx->tx_conf = conf;
 565        else
 566                ctx->rx_conf = conf;
 567        update_sk_prot(sk, ctx);
 568        if (tx) {
 569                ctx->sk_write_space = sk->sk_write_space;
 570                sk->sk_write_space = tls_write_space;
 571        } else {
 572                sk->sk_socket->ops = &tls_sw_proto_ops;
 573        }
 574        goto out;
 575
 576err_crypto_info:
 577        memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
 578out:
 579        return rc;
 580}
 581
 582static int do_tls_setsockopt(struct sock *sk, int optname,
 583                             char __user *optval, unsigned int optlen)
 584{
 585        int rc = 0;
 586
 587        switch (optname) {
 588        case TLS_TX:
 589        case TLS_RX:
 590                lock_sock(sk);
 591                rc = do_tls_setsockopt_conf(sk, optval, optlen,
 592                                            optname == TLS_TX);
 593                release_sock(sk);
 594                break;
 595        default:
 596                rc = -ENOPROTOOPT;
 597                break;
 598        }
 599        return rc;
 600}
 601
 602static int tls_setsockopt(struct sock *sk, int level, int optname,
 603                          char __user *optval, unsigned int optlen)
 604{
 605        struct tls_context *ctx = tls_get_ctx(sk);
 606
 607        if (level != SOL_TLS)
 608                return ctx->sk_proto->setsockopt(sk, level, optname, optval,
 609                                                 optlen);
 610
 611        return do_tls_setsockopt(sk, optname, optval, optlen);
 612}
 613
 614struct tls_context *tls_ctx_create(struct sock *sk)
 615{
 616        struct inet_connection_sock *icsk = inet_csk(sk);
 617        struct tls_context *ctx;
 618
 619        ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
 620        if (!ctx)
 621                return NULL;
 622
 623        mutex_init(&ctx->tx_lock);
 624        rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
 625        ctx->sk_proto = READ_ONCE(sk->sk_prot);
 626        return ctx;
 627}
 628
 629static void tls_build_proto(struct sock *sk)
 630{
 631        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 632        const struct proto *prot = READ_ONCE(sk->sk_prot);
 633
 634        /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
 635        if (ip_ver == TLSV6 &&
 636            unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
 637                mutex_lock(&tcpv6_prot_mutex);
 638                if (likely(prot != saved_tcpv6_prot)) {
 639                        build_protos(tls_prots[TLSV6], prot);
 640                        smp_store_release(&saved_tcpv6_prot, prot);
 641                }
 642                mutex_unlock(&tcpv6_prot_mutex);
 643        }
 644
 645        if (ip_ver == TLSV4 &&
 646            unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
 647                mutex_lock(&tcpv4_prot_mutex);
 648                if (likely(prot != saved_tcpv4_prot)) {
 649                        build_protos(tls_prots[TLSV4], prot);
 650                        smp_store_release(&saved_tcpv4_prot, prot);
 651                }
 652                mutex_unlock(&tcpv4_prot_mutex);
 653        }
 654}
 655
 656static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 657                         const struct proto *base)
 658{
 659        prot[TLS_BASE][TLS_BASE] = *base;
 660        prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
 661        prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
 662        prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;
 663
 664        prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 665        prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
 666        prot[TLS_SW][TLS_BASE].sendpage         = tls_sw_sendpage;
 667
 668        prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
 669        prot[TLS_BASE][TLS_SW].recvmsg            = tls_sw_recvmsg;
 670        prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
 671        prot[TLS_BASE][TLS_SW].close              = tls_sk_proto_close;
 672
 673        prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
 674        prot[TLS_SW][TLS_SW].recvmsg            = tls_sw_recvmsg;
 675        prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read;
 676        prot[TLS_SW][TLS_SW].close              = tls_sk_proto_close;
 677
 678#ifdef CONFIG_TLS_DEVICE
 679        prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 680        prot[TLS_HW][TLS_BASE].sendmsg          = tls_device_sendmsg;
 681        prot[TLS_HW][TLS_BASE].sendpage         = tls_device_sendpage;
 682
 683        prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
 684        prot[TLS_HW][TLS_SW].sendmsg            = tls_device_sendmsg;
 685        prot[TLS_HW][TLS_SW].sendpage           = tls_device_sendpage;
 686
 687        prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
 688
 689        prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
 690
 691        prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 692#endif
 693#ifdef CONFIG_TLS_TOE
 694        prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
 695        prot[TLS_HW_RECORD][TLS_HW_RECORD].hash         = tls_toe_hash;
 696        prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash       = tls_toe_unhash;
 697#endif
 698}
 699
 700static int tls_init(struct sock *sk)
 701{
 702        struct tls_context *ctx;
 703        int rc = 0;
 704
 705        tls_build_proto(sk);
 706
 707#ifdef CONFIG_TLS_TOE
 708        if (tls_toe_bypass(sk))
 709                return 0;
 710#endif
 711
 712        /* The TLS ulp is currently supported only for TCP sockets
 713         * in ESTABLISHED state.
 714         * Supporting sockets in LISTEN state will require us
 715         * to modify the accept implementation to clone rather then
 716         * share the ulp context.
 717         */
 718        if (sk->sk_state != TCP_ESTABLISHED)
 719                return -ENOTCONN;
 720
 721        /* allocate tls context */
 722        write_lock_bh(&sk->sk_callback_lock);
 723        ctx = tls_ctx_create(sk);
 724        if (!ctx) {
 725                rc = -ENOMEM;
 726                goto out;
 727        }
 728
 729        ctx->tx_conf = TLS_BASE;
 730        ctx->rx_conf = TLS_BASE;
 731        update_sk_prot(sk, ctx);
 732out:
 733        write_unlock_bh(&sk->sk_callback_lock);
 734        return rc;
 735}
 736
 737static void tls_update(struct sock *sk, struct proto *p,
 738                       void (*write_space)(struct sock *sk))
 739{
 740        struct tls_context *ctx;
 741
 742        ctx = tls_get_ctx(sk);
 743        if (likely(ctx)) {
 744                ctx->sk_write_space = write_space;
 745                ctx->sk_proto = p;
 746        } else {
 747                /* Pairs with lockless read in sk_clone_lock(). */
 748                WRITE_ONCE(sk->sk_prot, p);
 749                sk->sk_write_space = write_space;
 750        }
 751}
 752
 753static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
 754{
 755        u16 version, cipher_type;
 756        struct tls_context *ctx;
 757        struct nlattr *start;
 758        int err;
 759
 760        start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
 761        if (!start)
 762                return -EMSGSIZE;
 763
 764        rcu_read_lock();
 765        ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
 766        if (!ctx) {
 767                err = 0;
 768                goto nla_failure;
 769        }
 770        version = ctx->prot_info.version;
 771        if (version) {
 772                err = nla_put_u16(skb, TLS_INFO_VERSION, version);
 773                if (err)
 774                        goto nla_failure;
 775        }
 776        cipher_type = ctx->prot_info.cipher_type;
 777        if (cipher_type) {
 778                err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
 779                if (err)
 780                        goto nla_failure;
 781        }
 782        err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
 783        if (err)
 784                goto nla_failure;
 785
 786        err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
 787        if (err)
 788                goto nla_failure;
 789
 790        rcu_read_unlock();
 791        nla_nest_end(skb, start);
 792        return 0;
 793
 794nla_failure:
 795        rcu_read_unlock();
 796        nla_nest_cancel(skb, start);
 797        return err;
 798}
 799
 800static size_t tls_get_info_size(const struct sock *sk)
 801{
 802        size_t size = 0;
 803
 804        size += nla_total_size(0) +             /* INET_ULP_INFO_TLS */
 805                nla_total_size(sizeof(u16)) +   /* TLS_INFO_VERSION */
 806                nla_total_size(sizeof(u16)) +   /* TLS_INFO_CIPHER */
 807                nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
 808                nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
 809                0;
 810
 811        return size;
 812}
 813
 814static int __net_init tls_init_net(struct net *net)
 815{
 816        int err;
 817
 818        net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
 819        if (!net->mib.tls_statistics)
 820                return -ENOMEM;
 821
 822        err = tls_proc_init(net);
 823        if (err)
 824                goto err_free_stats;
 825
 826        return 0;
 827err_free_stats:
 828        free_percpu(net->mib.tls_statistics);
 829        return err;
 830}
 831
 832static void __net_exit tls_exit_net(struct net *net)
 833{
 834        tls_proc_fini(net);
 835        free_percpu(net->mib.tls_statistics);
 836}
 837
 838static struct pernet_operations tls_proc_ops = {
 839        .init = tls_init_net,
 840        .exit = tls_exit_net,
 841};
 842
 843static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 844        .name                   = "tls",
 845        .owner                  = THIS_MODULE,
 846        .init                   = tls_init,
 847        .update                 = tls_update,
 848        .get_info               = tls_get_info,
 849        .get_info_size          = tls_get_info_size,
 850};
 851
 852static int __init tls_register(void)
 853{
 854        int err;
 855
 856        err = register_pernet_subsys(&tls_proc_ops);
 857        if (err)
 858                return err;
 859
 860        tls_sw_proto_ops = inet_stream_ops;
 861        tls_sw_proto_ops.splice_read = tls_sw_splice_read;
 862        tls_sw_proto_ops.sendpage_locked   = tls_sw_sendpage_locked,
 863
 864        tls_device_init();
 865        tcp_register_ulp(&tcp_tls_ulp_ops);
 866
 867        return 0;
 868}
 869
 870static void __exit tls_unregister(void)
 871{
 872        tcp_unregister_ulp(&tcp_tls_ulp_ops);
 873        tls_device_cleanup();
 874        unregister_pernet_subsys(&tls_proc_ops);
 875}
 876
 877module_init(tls_register);
 878module_exit(tls_unregister);
 879