linux/net/tls/tls_main.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
   3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
   4 *
   5 * This software is available to you under a choice of one of two
   6 * licenses.  You may choose to be licensed under the terms of the GNU
   7 * General Public License (GPL) Version 2, available from the file
   8 * COPYING in the main directory of this source tree, or the
   9 * OpenIB.org BSD license below:
  10 *
  11 *     Redistribution and use in source and binary forms, with or
  12 *     without modification, are permitted provided that the following
  13 *     conditions are met:
  14 *
  15 *      - Redistributions of source code must retain the above
  16 *        copyright notice, this list of conditions and the following
  17 *        disclaimer.
  18 *
  19 *      - Redistributions in binary form must reproduce the above
  20 *        copyright notice, this list of conditions and the following
  21 *        disclaimer in the documentation and/or other materials
  22 *        provided with the distribution.
  23 *
  24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  31 * SOFTWARE.
  32 */
  33
  34#include <linux/module.h>
  35
  36#include <net/tcp.h>
  37#include <net/inet_common.h>
  38#include <linux/highmem.h>
  39#include <linux/netdevice.h>
  40#include <linux/sched/signal.h>
  41#include <linux/inetdevice.h>
  42#include <linux/inet_diag.h>
  43
  44#include <net/snmp.h>
  45#include <net/tls.h>
  46#include <net/tls_toe.h>
  47
  48MODULE_AUTHOR("Mellanox Technologies");
  49MODULE_DESCRIPTION("Transport Layer Security Support");
  50MODULE_LICENSE("Dual BSD/GPL");
  51MODULE_ALIAS_TCP_ULP("tls");
  52
  53enum {
  54        TLSV4,
  55        TLSV6,
  56        TLS_NUM_PROTS,
  57};
  58
  59static const struct proto *saved_tcpv6_prot;
  60static DEFINE_MUTEX(tcpv6_prot_mutex);
  61static const struct proto *saved_tcpv4_prot;
  62static DEFINE_MUTEX(tcpv4_prot_mutex);
  63static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
  64static struct proto_ops tls_sw_proto_ops;
  65static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
  66                         const struct proto *base);
  67
  68void update_sk_prot(struct sock *sk, struct tls_context *ctx)
  69{
  70        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
  71
  72        WRITE_ONCE(sk->sk_prot,
  73                   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
  74}
  75
  76int wait_on_pending_writer(struct sock *sk, long *timeo)
  77{
  78        int rc = 0;
  79        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  80
  81        add_wait_queue(sk_sleep(sk), &wait);
  82        while (1) {
  83                if (!*timeo) {
  84                        rc = -EAGAIN;
  85                        break;
  86                }
  87
  88                if (signal_pending(current)) {
  89                        rc = sock_intr_errno(*timeo);
  90                        break;
  91                }
  92
  93                if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
  94                        break;
  95        }
  96        remove_wait_queue(sk_sleep(sk), &wait);
  97        return rc;
  98}
  99
/*
 * Transmit the scatterlist chain starting at @sg with do_tcp_sendpages(),
 * beginning @first_offset bytes into the first entry.
 *
 * Every entry but the last is sent with MSG_SENDPAGE_NOTLAST added to
 * @flags.  ctx->in_tcp_sendpages is set for the duration of the call, which
 * makes tls_write_space() forward only to the saved write_space callback.
 * Each fully sent entry has its page reference dropped and its length
 * uncharged from the socket.
 *
 * If do_tcp_sendpages() makes no progress (returns <= 0), the current entry
 * and the offset within it are saved in ctx->partially_sent_record and
 * ctx->partially_sent_offset, so tls_push_partial_record() can resume later.
 * That return value is then passed back to the caller.
 *
 * Returns 0 once the whole chain has been sent.
 */
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	/* Bytes left in the first entry, and the absolute in-page offset. */
	size = sg->length - offset;
	offset += sg->offset;

	ctx->in_tcp_sendpages = true;
	while (1) {
		/* Only the final entry goes out without MSG_SENDPAGE_NOTLAST. */
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			/* Partial progress: keep sending the rest of this entry. */
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			/* No progress: remember where to resume, relative to
			 * the start of this entry's data.
			 */
			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		/* Entry fully sent: release it and move to the next one. */
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->in_tcp_sendpages = false;

	return 0;
}
 154
 155static int tls_handle_open_record(struct sock *sk, int flags)
 156{
 157        struct tls_context *ctx = tls_get_ctx(sk);
 158
 159        if (tls_is_pending_open_record(ctx))
 160                return ctx->push_pending_record(sk, flags);
 161
 162        return 0;
 163}
 164
 165int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 166                      unsigned char *record_type)
 167{
 168        struct cmsghdr *cmsg;
 169        int rc = -EINVAL;
 170
 171        for_each_cmsghdr(cmsg, msg) {
 172                if (!CMSG_OK(msg, cmsg))
 173                        return -EINVAL;
 174                if (cmsg->cmsg_level != SOL_TLS)
 175                        continue;
 176
 177                switch (cmsg->cmsg_type) {
 178                case TLS_SET_RECORD_TYPE:
 179                        if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
 180                                return -EINVAL;
 181
 182                        if (msg->msg_flags & MSG_MORE)
 183                                return -EINVAL;
 184
 185                        rc = tls_handle_open_record(sk, msg->msg_flags);
 186                        if (rc)
 187                                return rc;
 188
 189                        *record_type = *(unsigned char *)CMSG_DATA(cmsg);
 190                        rc = 0;
 191                        break;
 192                default:
 193                        return -EINVAL;
 194                }
 195        }
 196
 197        return rc;
 198}
 199
 200int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 201                            int flags)
 202{
 203        struct scatterlist *sg;
 204        u16 offset;
 205
 206        sg = ctx->partially_sent_record;
 207        offset = ctx->partially_sent_offset;
 208
 209        ctx->partially_sent_record = NULL;
 210        return tls_push_sg(sk, ctx, sg, offset, flags);
 211}
 212
 213void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
 214{
 215        struct scatterlist *sg;
 216
 217        for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
 218                put_page(sg_page(sg));
 219                sk_mem_uncharge(sk, sg->length);
 220        }
 221        ctx->partially_sent_record = NULL;
 222}
 223
 224static void tls_write_space(struct sock *sk)
 225{
 226        struct tls_context *ctx = tls_get_ctx(sk);
 227
 228        /* If in_tcp_sendpages call lower protocol write space handler
 229         * to ensure we wake up any waiting operations there. For example
 230         * if do_tcp_sendpages where to call sk_wait_event.
 231         */
 232        if (ctx->in_tcp_sendpages) {
 233                ctx->sk_write_space(sk);
 234                return;
 235        }
 236
 237#ifdef CONFIG_TLS_DEVICE
 238        if (ctx->tx_conf == TLS_HW)
 239                tls_device_write_space(sk, ctx);
 240        else
 241#endif
 242                tls_sw_write_space(sk, ctx);
 243
 244        ctx->sk_write_space(sk);
 245}
 246
 247/**
 248 * tls_ctx_free() - free TLS ULP context
 249 * @sk:  socket to with @ctx is attached
 250 * @ctx: TLS context structure
 251 *
 252 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 253 * to which @ctx was attached has no outstanding references.
 254 */
 255void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
 256{
 257        if (!ctx)
 258                return;
 259
 260        memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
 261        memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
 262        mutex_destroy(&ctx->tx_lock);
 263
 264        if (sk)
 265                kfree_rcu(ctx, rcu);
 266        else
 267                kfree(ctx);
 268}
 269
 270static void tls_sk_proto_cleanup(struct sock *sk,
 271                                 struct tls_context *ctx, long timeo)
 272{
 273        if (unlikely(sk->sk_write_pending) &&
 274            !wait_on_pending_writer(sk, &timeo))
 275                tls_handle_open_record(sk, 0);
 276
 277        /* We need these for tls_sw_fallback handling of other packets */
 278        if (ctx->tx_conf == TLS_SW) {
 279                kfree(ctx->tx.rec_seq);
 280                kfree(ctx->tx.iv);
 281                tls_sw_release_resources_tx(sk);
 282                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 283        } else if (ctx->tx_conf == TLS_HW) {
 284                tls_device_free_resources_tx(sk);
 285                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
 286        }
 287
 288        if (ctx->rx_conf == TLS_SW) {
 289                tls_sw_release_resources_rx(sk);
 290                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
 291        } else if (ctx->rx_conf == TLS_HW) {
 292                tls_device_offload_cleanup_rx(sk);
 293                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
 294        }
 295}
 296
/*
 * ->close() hook of the TLS protos built in build_protos().
 *
 * Tears down the TLS state of @sk.  It then restores the original struct
 * proto and write_space callback saved in the context, and calls the
 * original protocol's close() with @timeout.  The context is freed at the
 * end unless either direction uses TLS_HW.
 */
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	/* Done before lock_sock().  NOTE(review): presumably because the
	 * cancelled TX work itself takes the socket lock - confirm.
	 */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	/* The context is freed here only if no direction uses device offload.
	 * NOTE(review): with TLS_HW it is presumably released by the device
	 * offload code instead - confirm.
	 */
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	/* Detach TLS under sk_callback_lock.  Clear the ULP pointer if we are
	 * going to free ctx, restore the original proto, and restore the
	 * original write_space only if our hook is still the installed one.
	 */
	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	/* Final SW teardown happens after the socket lock is dropped. */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}
 332
 333static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
 334                                  int __user *optlen, int tx)
 335{
 336        int rc = 0;
 337        struct tls_context *ctx = tls_get_ctx(sk);
 338        struct tls_crypto_info *crypto_info;
 339        struct cipher_context *cctx;
 340        int len;
 341
 342        if (get_user(len, optlen))
 343                return -EFAULT;
 344
 345        if (!optval || (len < sizeof(*crypto_info))) {
 346                rc = -EINVAL;
 347                goto out;
 348        }
 349
 350        if (!ctx) {
 351                rc = -EBUSY;
 352                goto out;
 353        }
 354
 355        /* get user crypto info */
 356        if (tx) {
 357                crypto_info = &ctx->crypto_send.info;
 358                cctx = &ctx->tx;
 359        } else {
 360                crypto_info = &ctx->crypto_recv.info;
 361                cctx = &ctx->rx;
 362        }
 363
 364        if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
 365                rc = -EBUSY;
 366                goto out;
 367        }
 368
 369        if (len == sizeof(*crypto_info)) {
 370                if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
 371                        rc = -EFAULT;
 372                goto out;
 373        }
 374
 375        switch (crypto_info->cipher_type) {
 376        case TLS_CIPHER_AES_GCM_128: {
 377                struct tls12_crypto_info_aes_gcm_128 *
 378                  crypto_info_aes_gcm_128 =
 379                  container_of(crypto_info,
 380                               struct tls12_crypto_info_aes_gcm_128,
 381                               info);
 382
 383                if (len != sizeof(*crypto_info_aes_gcm_128)) {
 384                        rc = -EINVAL;
 385                        goto out;
 386                }
 387                lock_sock(sk);
 388                memcpy(crypto_info_aes_gcm_128->iv,
 389                       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 390                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
 391                memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
 392                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 393                release_sock(sk);
 394                if (copy_to_user(optval,
 395                                 crypto_info_aes_gcm_128,
 396                                 sizeof(*crypto_info_aes_gcm_128)))
 397                        rc = -EFAULT;
 398                break;
 399        }
 400        case TLS_CIPHER_AES_GCM_256: {
 401                struct tls12_crypto_info_aes_gcm_256 *
 402                  crypto_info_aes_gcm_256 =
 403                  container_of(crypto_info,
 404                               struct tls12_crypto_info_aes_gcm_256,
 405                               info);
 406
 407                if (len != sizeof(*crypto_info_aes_gcm_256)) {
 408                        rc = -EINVAL;
 409                        goto out;
 410                }
 411                lock_sock(sk);
 412                memcpy(crypto_info_aes_gcm_256->iv,
 413                       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
 414                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
 415                memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
 416                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
 417                release_sock(sk);
 418                if (copy_to_user(optval,
 419                                 crypto_info_aes_gcm_256,
 420                                 sizeof(*crypto_info_aes_gcm_256)))
 421                        rc = -EFAULT;
 422                break;
 423        }
 424        default:
 425                rc = -EINVAL;
 426        }
 427
 428out:
 429        return rc;
 430}
 431
 432static int do_tls_getsockopt(struct sock *sk, int optname,
 433                             char __user *optval, int __user *optlen)
 434{
 435        int rc = 0;
 436
 437        switch (optname) {
 438        case TLS_TX:
 439        case TLS_RX:
 440                rc = do_tls_getsockopt_conf(sk, optval, optlen,
 441                                            optname == TLS_TX);
 442                break;
 443        default:
 444                rc = -ENOPROTOOPT;
 445                break;
 446        }
 447        return rc;
 448}
 449
 450static int tls_getsockopt(struct sock *sk, int level, int optname,
 451                          char __user *optval, int __user *optlen)
 452{
 453        struct tls_context *ctx = tls_get_ctx(sk);
 454
 455        if (level != SOL_TLS)
 456                return ctx->sk_proto->getsockopt(sk, level,
 457                                                 optname, optval, optlen);
 458
 459        return do_tls_getsockopt(sk, optname, optval, optlen);
 460}
 461
/*
 * Install the TX (@tx != 0) or RX crypto configuration passed in @optval.
 * Called from do_tls_setsockopt() with the socket locked.
 *
 * The cipher-specific structure is copied directly into ctx->crypto_send
 * or ctx->crypto_recv.  Device offload is tried first, and the software
 * implementation is used if that fails.  The chosen mode is recorded in
 * ctx->tx_conf / ctx->rx_conf, and sk->sk_prot is switched to match.
 *
 * Return: 0 on success, or one of:
 *   -EINVAL  bad length, version or cipher, or a version/cipher that
 *            differs from the already configured other direction;
 *   -EBUSY   this direction is already configured;
 *   -EFAULT  the user buffer cannot be read;
 *   the error returned by tls_set_sw_offload().
 * Failures that reach err_crypto_info wipe this direction's crypto info
 * again.
 */
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support set crypto info more than one time */
	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	/* Copy the generic header (version, cipher type) first. */
	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and ciphers are same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	/* The cipher determines the exact size the caller must pass. */
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256: {
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	}
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	case TLS_CIPHER_CHACHA20_POLY1305:
		optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Copy the cipher-specific remainder right after the header. */
	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* Prefer device offload; fall back to the software implementation. */
	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		/* Called for both TLS_HW and TLS_SW RX. */
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		/* Save the lower write_space callback; tls_write_space()
		 * chains to it.
		 */
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		/* NOTE(review): tls_sw_proto_ops is derived from
		 * inet_stream_ops in tls_register() and also overrides
		 * sendpage_locked.  It is installed here for every RX mode,
		 * including TLS_HW, AF_INET6 sockets, and sockets whose TX is
		 * still TLS_BASE.  Confirm that this is safe in all of those
		 * cases.
		 */
		sk->sk_socket->ops = &tls_sw_proto_ops;
	}
	goto out;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
out:
	return rc;
}
 594
 595static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
 596                             unsigned int optlen)
 597{
 598        int rc = 0;
 599
 600        switch (optname) {
 601        case TLS_TX:
 602        case TLS_RX:
 603                lock_sock(sk);
 604                rc = do_tls_setsockopt_conf(sk, optval, optlen,
 605                                            optname == TLS_TX);
 606                release_sock(sk);
 607                break;
 608        default:
 609                rc = -ENOPROTOOPT;
 610                break;
 611        }
 612        return rc;
 613}
 614
 615static int tls_setsockopt(struct sock *sk, int level, int optname,
 616                          sockptr_t optval, unsigned int optlen)
 617{
 618        struct tls_context *ctx = tls_get_ctx(sk);
 619
 620        if (level != SOL_TLS)
 621                return ctx->sk_proto->setsockopt(sk, level, optname, optval,
 622                                                 optlen);
 623
 624        return do_tls_setsockopt(sk, optname, optval, optlen);
 625}
 626
 627struct tls_context *tls_ctx_create(struct sock *sk)
 628{
 629        struct inet_connection_sock *icsk = inet_csk(sk);
 630        struct tls_context *ctx;
 631
 632        ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
 633        if (!ctx)
 634                return NULL;
 635
 636        mutex_init(&ctx->tx_lock);
 637        rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
 638        ctx->sk_proto = READ_ONCE(sk->sk_prot);
 639        ctx->sk = sk;
 640        return ctx;
 641}
 642
/*
 * Make sure the half of tls_prots[] for @sk's address family was built from
 * the struct proto that @sk currently uses.  It is rebuilt whenever that
 * base pointer differs from the one recorded at the last build.
 *
 * Double-checked locking: the unlocked check uses smp_load_acquire() and
 * is repeated under the per-family mutex.  The new base pointer is
 * published with smp_store_release() only after build_protos() has filled
 * the table.  A fast-path reader that sees the new pointer therefore also
 * sees the completed table.
 */
static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Rebuild the IPv6 table whenever the base IPv6 proto pointer changes. */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		/* Re-check: another caller may have rebuilt it meanwhile. */
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	/* Same for the IPv4 table. */
	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}
 669
 670static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
 671                         const struct proto *base)
 672{
 673        prot[TLS_BASE][TLS_BASE] = *base;
 674        prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
 675        prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
 676        prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;
 677
 678        prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 679        prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
 680        prot[TLS_SW][TLS_BASE].sendpage         = tls_sw_sendpage;
 681
 682        prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
 683        prot[TLS_BASE][TLS_SW].recvmsg            = tls_sw_recvmsg;
 684        prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
 685        prot[TLS_BASE][TLS_SW].close              = tls_sk_proto_close;
 686
 687        prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
 688        prot[TLS_SW][TLS_SW].recvmsg            = tls_sw_recvmsg;
 689        prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
 690        prot[TLS_SW][TLS_SW].close              = tls_sk_proto_close;
 691
 692#ifdef CONFIG_TLS_DEVICE
 693        prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
 694        prot[TLS_HW][TLS_BASE].sendmsg          = tls_device_sendmsg;
 695        prot[TLS_HW][TLS_BASE].sendpage         = tls_device_sendpage;
 696
 697        prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
 698        prot[TLS_HW][TLS_SW].sendmsg            = tls_device_sendmsg;
 699        prot[TLS_HW][TLS_SW].sendpage           = tls_device_sendpage;
 700
 701        prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
 702
 703        prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
 704
 705        prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 706#endif
 707#ifdef CONFIG_TLS_TOE
 708        prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
 709        prot[TLS_HW_RECORD][TLS_HW_RECORD].hash         = tls_toe_hash;
 710        prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash       = tls_toe_unhash;
 711#endif
 712}
 713
 714static int tls_init(struct sock *sk)
 715{
 716        struct tls_context *ctx;
 717        int rc = 0;
 718
 719        tls_build_proto(sk);
 720
 721#ifdef CONFIG_TLS_TOE
 722        if (tls_toe_bypass(sk))
 723                return 0;
 724#endif
 725
 726        /* The TLS ulp is currently supported only for TCP sockets
 727         * in ESTABLISHED state.
 728         * Supporting sockets in LISTEN state will require us
 729         * to modify the accept implementation to clone rather then
 730         * share the ulp context.
 731         */
 732        if (sk->sk_state != TCP_ESTABLISHED)
 733                return -ENOTCONN;
 734
 735        /* allocate tls context */
 736        write_lock_bh(&sk->sk_callback_lock);
 737        ctx = tls_ctx_create(sk);
 738        if (!ctx) {
 739                rc = -ENOMEM;
 740                goto out;
 741        }
 742
 743        ctx->tx_conf = TLS_BASE;
 744        ctx->rx_conf = TLS_BASE;
 745        update_sk_prot(sk, ctx);
 746out:
 747        write_unlock_bh(&sk->sk_callback_lock);
 748        return rc;
 749}
 750
 751static void tls_update(struct sock *sk, struct proto *p,
 752                       void (*write_space)(struct sock *sk))
 753{
 754        struct tls_context *ctx;
 755
 756        ctx = tls_get_ctx(sk);
 757        if (likely(ctx)) {
 758                ctx->sk_write_space = write_space;
 759                ctx->sk_proto = p;
 760        } else {
 761                /* Pairs with lockless read in sk_clone_lock(). */
 762                WRITE_ONCE(sk->sk_prot, p);
 763                sk->sk_write_space = write_space;
 764        }
 765}
 766
 767static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
 768{
 769        u16 version, cipher_type;
 770        struct tls_context *ctx;
 771        struct nlattr *start;
 772        int err;
 773
 774        start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
 775        if (!start)
 776                return -EMSGSIZE;
 777
 778        rcu_read_lock();
 779        ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
 780        if (!ctx) {
 781                err = 0;
 782                goto nla_failure;
 783        }
 784        version = ctx->prot_info.version;
 785        if (version) {
 786                err = nla_put_u16(skb, TLS_INFO_VERSION, version);
 787                if (err)
 788                        goto nla_failure;
 789        }
 790        cipher_type = ctx->prot_info.cipher_type;
 791        if (cipher_type) {
 792                err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
 793                if (err)
 794                        goto nla_failure;
 795        }
 796        err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
 797        if (err)
 798                goto nla_failure;
 799
 800        err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
 801        if (err)
 802                goto nla_failure;
 803
 804        rcu_read_unlock();
 805        nla_nest_end(skb, start);
 806        return 0;
 807
 808nla_failure:
 809        rcu_read_unlock();
 810        nla_nest_cancel(skb, start);
 811        return err;
 812}
 813
 814static size_t tls_get_info_size(const struct sock *sk)
 815{
 816        size_t size = 0;
 817
 818        size += nla_total_size(0) +             /* INET_ULP_INFO_TLS */
 819                nla_total_size(sizeof(u16)) +   /* TLS_INFO_VERSION */
 820                nla_total_size(sizeof(u16)) +   /* TLS_INFO_CIPHER */
 821                nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
 822                nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
 823                0;
 824
 825        return size;
 826}
 827
 828static int __net_init tls_init_net(struct net *net)
 829{
 830        int err;
 831
 832        net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
 833        if (!net->mib.tls_statistics)
 834                return -ENOMEM;
 835
 836        err = tls_proc_init(net);
 837        if (err)
 838                goto err_free_stats;
 839
 840        return 0;
 841err_free_stats:
 842        free_percpu(net->mib.tls_statistics);
 843        return err;
 844}
 845
 846static void __net_exit tls_exit_net(struct net *net)
 847{
 848        tls_proc_fini(net);
 849        free_percpu(net->mib.tls_statistics);
 850}
 851
 852static struct pernet_operations tls_proc_ops = {
 853        .init = tls_init_net,
 854        .exit = tls_exit_net,
 855};
 856
 857static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 858        .name                   = "tls",
 859        .owner                  = THIS_MODULE,
 860        .init                   = tls_init,
 861        .update                 = tls_update,
 862        .get_info               = tls_get_info,
 863        .get_info_size          = tls_get_info_size,
 864};
 865
 866static int __init tls_register(void)
 867{
 868        int err;
 869
 870        err = register_pernet_subsys(&tls_proc_ops);
 871        if (err)
 872                return err;
 873
 874        tls_sw_proto_ops = inet_stream_ops;
 875        tls_sw_proto_ops.splice_read = tls_sw_splice_read;
 876        tls_sw_proto_ops.sendpage_locked   = tls_sw_sendpage_locked;
 877
 878        tls_device_init();
 879        tcp_register_ulp(&tcp_tls_ulp_ops);
 880
 881        return 0;
 882}
 883
 884static void __exit tls_unregister(void)
 885{
 886        tcp_unregister_ulp(&tcp_tls_ulp_ops);
 887        tls_device_cleanup();
 888        unregister_pernet_subsys(&tls_proc_ops);
 889}
 890
 891module_init(tls_register);
 892module_exit(tls_unregister);
 893