linux/net/tls/tls_main.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
   3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
   4 *
   5 * This software is available to you under a choice of one of two
   6 * licenses.  You may choose to be licensed under the terms of the GNU
   7 * General Public License (GPL) Version 2, available from the file
   8 * COPYING in the main directory of this source tree, or the
   9 * OpenIB.org BSD license below:
  10 *
  11 *     Redistribution and use in source and binary forms, with or
  12 *     without modification, are permitted provided that the following
  13 *     conditions are met:
  14 *
  15 *      - Redistributions of source code must retain the above
  16 *        copyright notice, this list of conditions and the following
  17 *        disclaimer.
  18 *
  19 *      - Redistributions in binary form must reproduce the above
  20 *        copyright notice, this list of conditions and the following
  21 *        disclaimer in the documentation and/or other materials
  22 *        provided with the distribution.
  23 *
  24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  31 * SOFTWARE.
  32 */
  33
  34#include <linux/module.h>
  35
  36#include <net/tcp.h>
  37#include <net/inet_common.h>
  38#include <linux/highmem.h>
  39#include <linux/netdevice.h>
  40#include <linux/sched/signal.h>
  41#include <linux/inetdevice.h>
  42
  43#include <net/tls.h>
  44
/* Module metadata; MODULE_ALIAS_TCP_ULP lets the "tls" TCP ULP name map to this module. */
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");
  49
/* First index of tls_prots[]: which address family's base proto an entry was built from. */
enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};
  55
/* Base TCP protos that tls_prots[TLSV6] / tls_prots[TLSV4] were last built
 * from; each is paired with a mutex that serializes rebuilding the table
 * (see tls_build_proto()).
 */
static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
/* Registered struct tls_device entries; protected by device_spinlock. */
static LIST_HEAD(device_list);
static DEFINE_SPINLOCK(device_spinlock);
/* Protocol table indexed [address family][tx_conf][rx_conf]. */
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
/* Socket ops installed once RX is configured; initialized in tls_register(). */
static struct proto_ops tls_sw_proto_ops;
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 struct proto *base);
  66
  67static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
  68{
  69        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
  70
  71        sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
  72}
  73
  74int wait_on_pending_writer(struct sock *sk, long *timeo)
  75{
  76        int rc = 0;
  77        DEFINE_WAIT_FUNC(wait, woken_wake_function);
  78
  79        add_wait_queue(sk_sleep(sk), &wait);
  80        while (1) {
  81                if (!*timeo) {
  82                        rc = -EAGAIN;
  83                        break;
  84                }
  85
  86                if (signal_pending(current)) {
  87                        rc = sock_intr_errno(*timeo);
  88                        break;
  89                }
  90
  91                if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
  92                        break;
  93        }
  94        remove_wait_queue(sk_sleep(sk), &wait);
  95        return rc;
  96}
  97
/*
 * tls_push_sg() - hand a scatterlist of record pages to TCP.
 * @sk: socket to transmit on
 * @ctx: TLS context; in_tcp_sendpages and partially_sent_* are updated
 * @sg: first scatterlist entry to send
 * @first_offset: byte offset into @sg at which to start sending
 * @flags: MSG_* flags; MSG_SENDPAGE_NOTLAST is added for all but the last entry
 *
 * Every fully sent entry has its page reference dropped and its length
 * uncharged from the socket. If do_tcp_sendpages() makes no progress
 * (returns <= 0), the current entry and its offset are stored in
 * ctx->partially_sent_record / ctx->partially_sent_offset so that
 * tls_push_partial_record() can resume from there.
 *
 * Return: 0 when every entry was sent, otherwise the do_tcp_sendpages() result.
 */
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	/* tls_write_space() checks this to defer to the lower write_space. */
	ctx->in_tcp_sendpages = true;
	while (1) {
		/* Only the final entry is sent without MSG_SENDPAGE_NOTLAST. */
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			/* Partial progress: advance within the same page and retry. */
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			/* No progress: remember where to resume. The stored
			 * offset is relative to the entry, like @first_offset.
			 */
			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		/* Entry fully handed to TCP: release our page ref and charge. */
		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->in_tcp_sendpages = false;

	return 0;
}
 152
 153static int tls_handle_open_record(struct sock *sk, int flags)
 154{
 155        struct tls_context *ctx = tls_get_ctx(sk);
 156
 157        if (tls_is_pending_open_record(ctx))
 158                return ctx->push_pending_record(sk, flags);
 159
 160        return 0;
 161}
 162
 163int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
 164                      unsigned char *record_type)
 165{
 166        struct cmsghdr *cmsg;
 167        int rc = -EINVAL;
 168
 169        for_each_cmsghdr(cmsg, msg) {
 170                if (!CMSG_OK(msg, cmsg))
 171                        return -EINVAL;
 172                if (cmsg->cmsg_level != SOL_TLS)
 173                        continue;
 174
 175                switch (cmsg->cmsg_type) {
 176                case TLS_SET_RECORD_TYPE:
 177                        if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
 178                                return -EINVAL;
 179
 180                        if (msg->msg_flags & MSG_MORE)
 181                                return -EINVAL;
 182
 183                        rc = tls_handle_open_record(sk, msg->msg_flags);
 184                        if (rc)
 185                                return rc;
 186
 187                        *record_type = *(unsigned char *)CMSG_DATA(cmsg);
 188                        rc = 0;
 189                        break;
 190                default:
 191                        return -EINVAL;
 192                }
 193        }
 194
 195        return rc;
 196}
 197
 198int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
 199                            int flags)
 200{
 201        struct scatterlist *sg;
 202        u16 offset;
 203
 204        sg = ctx->partially_sent_record;
 205        offset = ctx->partially_sent_offset;
 206
 207        ctx->partially_sent_record = NULL;
 208        return tls_push_sg(sk, ctx, sg, offset, flags);
 209}
 210
 211bool tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
 212{
 213        struct scatterlist *sg;
 214
 215        sg = ctx->partially_sent_record;
 216        if (!sg)
 217                return false;
 218
 219        while (1) {
 220                put_page(sg_page(sg));
 221                sk_mem_uncharge(sk, sg->length);
 222
 223                if (sg_is_last(sg))
 224                        break;
 225                sg++;
 226        }
 227        ctx->partially_sent_record = NULL;
 228        return true;
 229}
 230
 231static void tls_write_space(struct sock *sk)
 232{
 233        struct tls_context *ctx = tls_get_ctx(sk);
 234
 235        /* If in_tcp_sendpages call lower protocol write space handler
 236         * to ensure we wake up any waiting operations there. For example
 237         * if do_tcp_sendpages where to call sk_wait_event.
 238         */
 239        if (ctx->in_tcp_sendpages) {
 240                ctx->sk_write_space(sk);
 241                return;
 242        }
 243
 244#ifdef CONFIG_TLS_DEVICE
 245        if (ctx->tx_conf == TLS_HW)
 246                tls_device_write_space(sk, ctx);
 247        else
 248#endif
 249                tls_sw_write_space(sk, ctx);
 250
 251        ctx->sk_write_space(sk);
 252}
 253
 254void tls_ctx_free(struct tls_context *ctx)
 255{
 256        if (!ctx)
 257                return;
 258
 259        memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
 260        memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
 261        kfree(ctx);
 262}
 263
/*
 * tls_sk_proto_cleanup() - release per-direction TLS resources on close.
 * @sk: socket being closed (called with the socket lock held)
 * @ctx: TLS context
 * @timeo: how long to wait for pending writers before giving up
 *
 * If writers are pending and they finish within @timeo, any open record is
 * pushed first. Then the TX and RX resources are released according to
 * tx_conf / rx_conf.
 */
static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	/* Push the open record only if waiting for writers succeeded (returned 0). */
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets:
	 * tx.rec_seq / tx.iv are freed here only for TLS_SW.
	 * NOTE(review): for TLS_HW they are presumably released by the device
	 * code; that is not visible in this file, so confirm.
	 */
	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_release_resources_tx(sk);
#ifdef CONFIG_TLS_DEVICE
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
#endif
	}

	if (ctx->rx_conf == TLS_SW)
		tls_sw_release_resources_rx(sk);

#ifdef CONFIG_TLS_DEVICE
	if (ctx->rx_conf == TLS_HW)
		tls_device_offload_cleanup_rx(sk);
#endif
}
 290
/*
 * tls_sk_proto_close() - proto->close handler installed by build_protos().
 * @sk: socket being closed
 * @timeout: passed unchanged to the saved ctx->sk_proto_close handler
 *
 * Releases TLS state and restores the socket's original sk_prot and
 * sk_write_space before calling the original close handler. The context
 * itself is freed here only when neither direction is TLS_HW. Otherwise
 * icsk_ulp_data is left pointing at it.
 * NOTE(review): in the TLS_HW case the ctx is presumably freed by the
 * device-offload code; confirm.
 */
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	/* Send timeout used when waiting for pending writers in cleanup. */
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	/* Cancel SW TX work before taking the socket lock. */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	/* Detach TLS from the socket under sk_callback_lock; write_space is
	 * restored only if it is still ours.
	 */
	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		icsk->icsk_ulp_data = NULL;
	sk->sk_prot = ctx->sk_proto;
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	/* The remaining teardown runs without the socket lock held. */
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto_close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(ctx);
}
 326
 327static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
 328                                int __user *optlen)
 329{
 330        int rc = 0;
 331        struct tls_context *ctx = tls_get_ctx(sk);
 332        struct tls_crypto_info *crypto_info;
 333        int len;
 334
 335        if (get_user(len, optlen))
 336                return -EFAULT;
 337
 338        if (!optval || (len < sizeof(*crypto_info))) {
 339                rc = -EINVAL;
 340                goto out;
 341        }
 342
 343        if (!ctx) {
 344                rc = -EBUSY;
 345                goto out;
 346        }
 347
 348        /* get user crypto info */
 349        crypto_info = &ctx->crypto_send.info;
 350
 351        if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
 352                rc = -EBUSY;
 353                goto out;
 354        }
 355
 356        if (len == sizeof(*crypto_info)) {
 357                if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
 358                        rc = -EFAULT;
 359                goto out;
 360        }
 361
 362        switch (crypto_info->cipher_type) {
 363        case TLS_CIPHER_AES_GCM_128: {
 364                struct tls12_crypto_info_aes_gcm_128 *
 365                  crypto_info_aes_gcm_128 =
 366                  container_of(crypto_info,
 367                               struct tls12_crypto_info_aes_gcm_128,
 368                               info);
 369
 370                if (len != sizeof(*crypto_info_aes_gcm_128)) {
 371                        rc = -EINVAL;
 372                        goto out;
 373                }
 374                lock_sock(sk);
 375                memcpy(crypto_info_aes_gcm_128->iv,
 376                       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 377                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
 378                memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
 379                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 380                release_sock(sk);
 381                if (copy_to_user(optval,
 382                                 crypto_info_aes_gcm_128,
 383                                 sizeof(*crypto_info_aes_gcm_128)))
 384                        rc = -EFAULT;
 385                break;
 386        }
 387        case TLS_CIPHER_AES_GCM_256: {
 388                struct tls12_crypto_info_aes_gcm_256 *
 389                  crypto_info_aes_gcm_256 =
 390                  container_of(crypto_info,
 391                               struct tls12_crypto_info_aes_gcm_256,
 392                               info);
 393
 394                if (len != sizeof(*crypto_info_aes_gcm_256)) {
 395                        rc = -EINVAL;
 396                        goto out;
 397                }
 398                lock_sock(sk);
 399                memcpy(crypto_info_aes_gcm_256->iv,
 400                       ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
 401                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
 402                memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
 403                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
 404                release_sock(sk);
 405                if (copy_to_user(optval,
 406                                 crypto_info_aes_gcm_256,
 407                                 sizeof(*crypto_info_aes_gcm_256)))
 408                        rc = -EFAULT;
 409                break;
 410        }
 411        default:
 412                rc = -EINVAL;
 413        }
 414
 415out:
 416        return rc;
 417}
 418
 419static int do_tls_getsockopt(struct sock *sk, int optname,
 420                             char __user *optval, int __user *optlen)
 421{
 422        int rc = 0;
 423
 424        switch (optname) {
 425        case TLS_TX:
 426                rc = do_tls_getsockopt_tx(sk, optval, optlen);
 427                break;
 428        default:
 429                rc = -ENOPROTOOPT;
 430                break;
 431        }
 432        return rc;
 433}
 434
 435static int tls_getsockopt(struct sock *sk, int level, int optname,
 436                          char __user *optval, int __user *optlen)
 437{
 438        struct tls_context *ctx = tls_get_ctx(sk);
 439
 440        if (level != SOL_TLS)
 441                return ctx->getsockopt(sk, level, optname, optval, optlen);
 442
 443        return do_tls_getsockopt(sk, optname, optval, optlen);
 444}
 445
 446static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
 447                                  unsigned int optlen, int tx)
 448{
 449        struct tls_crypto_info *crypto_info;
 450        struct tls_crypto_info *alt_crypto_info;
 451        struct tls_context *ctx = tls_get_ctx(sk);
 452        size_t optsize;
 453        int rc = 0;
 454        int conf;
 455
 456        if (!optval || (optlen < sizeof(*crypto_info))) {
 457                rc = -EINVAL;
 458                goto out;
 459        }
 460
 461        if (tx) {
 462                crypto_info = &ctx->crypto_send.info;
 463                alt_crypto_info = &ctx->crypto_recv.info;
 464        } else {
 465                crypto_info = &ctx->crypto_recv.info;
 466                alt_crypto_info = &ctx->crypto_send.info;
 467        }
 468
 469        /* Currently we don't support set crypto info more than one time */
 470        if (TLS_CRYPTO_INFO_READY(crypto_info)) {
 471                rc = -EBUSY;
 472                goto out;
 473        }
 474
 475        rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
 476        if (rc) {
 477                rc = -EFAULT;
 478                goto err_crypto_info;
 479        }
 480
 481        /* check version */
 482        if (crypto_info->version != TLS_1_2_VERSION &&
 483            crypto_info->version != TLS_1_3_VERSION) {
 484                rc = -ENOTSUPP;
 485                goto err_crypto_info;
 486        }
 487
 488        /* Ensure that TLS version and ciphers are same in both directions */
 489        if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
 490                if (alt_crypto_info->version != crypto_info->version ||
 491                    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
 492                        rc = -EINVAL;
 493                        goto err_crypto_info;
 494                }
 495        }
 496
 497        switch (crypto_info->cipher_type) {
 498        case TLS_CIPHER_AES_GCM_128:
 499                optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
 500                break;
 501        case TLS_CIPHER_AES_GCM_256: {
 502                optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
 503                break;
 504        }
 505        case TLS_CIPHER_AES_CCM_128:
 506                optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
 507                break;
 508        default:
 509                rc = -EINVAL;
 510                goto err_crypto_info;
 511        }
 512
 513        if (optlen != optsize) {
 514                rc = -EINVAL;
 515                goto err_crypto_info;
 516        }
 517
 518        rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
 519                            optlen - sizeof(*crypto_info));
 520        if (rc) {
 521                rc = -EFAULT;
 522                goto err_crypto_info;
 523        }
 524
 525        if (tx) {
 526#ifdef CONFIG_TLS_DEVICE
 527                rc = tls_set_device_offload(sk, ctx);
 528                conf = TLS_HW;
 529                if (rc) {
 530#else
 531                {
 532#endif
 533                        rc = tls_set_sw_offload(sk, ctx, 1);
 534                        if (rc)
 535                                goto err_crypto_info;
 536                        conf = TLS_SW;
 537                }
 538        } else {
 539#ifdef CONFIG_TLS_DEVICE
 540                rc = tls_set_device_offload_rx(sk, ctx);
 541                conf = TLS_HW;
 542                if (rc) {
 543#else
 544                {
 545#endif
 546                        rc = tls_set_sw_offload(sk, ctx, 0);
 547                        if (rc)
 548                                goto err_crypto_info;
 549                        conf = TLS_SW;
 550                }
 551                tls_sw_strparser_arm(sk, ctx);
 552        }
 553
 554        if (tx)
 555                ctx->tx_conf = conf;
 556        else
 557                ctx->rx_conf = conf;
 558        update_sk_prot(sk, ctx);
 559        if (tx) {
 560                ctx->sk_write_space = sk->sk_write_space;
 561                sk->sk_write_space = tls_write_space;
 562        } else {
 563                sk->sk_socket->ops = &tls_sw_proto_ops;
 564        }
 565        goto out;
 566
 567err_crypto_info:
 568        memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
 569out:
 570        return rc;
 571}
 572
 573static int do_tls_setsockopt(struct sock *sk, int optname,
 574                             char __user *optval, unsigned int optlen)
 575{
 576        int rc = 0;
 577
 578        switch (optname) {
 579        case TLS_TX:
 580        case TLS_RX:
 581                lock_sock(sk);
 582                rc = do_tls_setsockopt_conf(sk, optval, optlen,
 583                                            optname == TLS_TX);
 584                release_sock(sk);
 585                break;
 586        default:
 587                rc = -ENOPROTOOPT;
 588                break;
 589        }
 590        return rc;
 591}
 592
 593static int tls_setsockopt(struct sock *sk, int level, int optname,
 594                          char __user *optval, unsigned int optlen)
 595{
 596        struct tls_context *ctx = tls_get_ctx(sk);
 597
 598        if (level != SOL_TLS)
 599                return ctx->setsockopt(sk, level, optname, optval, optlen);
 600
 601        return do_tls_setsockopt(sk, optname, optval, optlen);
 602}
 603
 604static struct tls_context *create_ctx(struct sock *sk)
 605{
 606        struct inet_connection_sock *icsk = inet_csk(sk);
 607        struct tls_context *ctx;
 608
 609        ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
 610        if (!ctx)
 611                return NULL;
 612
 613        icsk->icsk_ulp_data = ctx;
 614        ctx->setsockopt = sk->sk_prot->setsockopt;
 615        ctx->getsockopt = sk->sk_prot->getsockopt;
 616        ctx->sk_proto_close = sk->sk_prot->close;
 617        ctx->unhash = sk->sk_prot->unhash;
 618        return ctx;
 619}
 620
 621static void tls_build_proto(struct sock *sk)
 622{
 623        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 624
 625        /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
 626        if (ip_ver == TLSV6 &&
 627            unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
 628                mutex_lock(&tcpv6_prot_mutex);
 629                if (likely(sk->sk_prot != saved_tcpv6_prot)) {
 630                        build_protos(tls_prots[TLSV6], sk->sk_prot);
 631                        smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
 632                }
 633                mutex_unlock(&tcpv6_prot_mutex);
 634        }
 635
 636        if (ip_ver == TLSV4 &&
 637            unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
 638                mutex_lock(&tcpv4_prot_mutex);
 639                if (likely(sk->sk_prot != saved_tcpv4_prot)) {
 640                        build_protos(tls_prots[TLSV4], sk->sk_prot);
 641                        smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
 642                }
 643                mutex_unlock(&tcpv4_prot_mutex);
 644        }
 645}
 646
 647static void tls_hw_sk_destruct(struct sock *sk)
 648{
 649        struct tls_context *ctx = tls_get_ctx(sk);
 650        struct inet_connection_sock *icsk = inet_csk(sk);
 651
 652        ctx->sk_destruct(sk);
 653        /* Free ctx */
 654        tls_ctx_free(ctx);
 655        icsk->icsk_ulp_data = NULL;
 656}
 657
/*
 * tls_hw_prot() - attach TLS_HW_RECORD mode if a registered device wants @sk.
 * @sk: socket being set up by tls_init()
 *
 * Uses the first device in device_list whose ->feature() returns true. A
 * context is created (GFP_ATOMIC, under device_spinlock) with both
 * directions set to TLS_HW_RECORD, and tls_hw_sk_destruct() is installed as
 * the socket destructor.
 *
 * Return: 1 if HW record mode was installed. Returns 0 otherwise, including
 * when context allocation fails; tls_init() then takes the non-HW path.
 */
static int tls_hw_prot(struct sock *sk)
{
	struct tls_context *ctx;
	struct tls_device *dev;
	int rc = 0;

	spin_lock_bh(&device_spinlock);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->feature && dev->feature(dev)) {
			ctx = create_ctx(sk);
			if (!ctx)
				goto out;

			/* tls_build_proto() may take a mutex, so the
			 * spinlock is dropped around the proto setup.
			 */
			spin_unlock_bh(&device_spinlock);
			tls_build_proto(sk);
			ctx->hash = sk->sk_prot->hash;
			ctx->unhash = sk->sk_prot->unhash;
			ctx->sk_proto_close = sk->sk_prot->close;
			ctx->sk_destruct = sk->sk_destruct;
			sk->sk_destruct = tls_hw_sk_destruct;
			ctx->rx_conf = TLS_HW_RECORD;
			ctx->tx_conf = TLS_HW_RECORD;
			update_sk_prot(sk, ctx);
			/* Re-take the lock so the unlock at "out" is balanced. */
			spin_lock_bh(&device_spinlock);
			rc = 1;
			break;
		}
	}
out:
	spin_unlock_bh(&device_spinlock);
	return rc;
}
 690
/*
 * tls_hw_unhash() - proto->unhash for TLS_HW_RECORD sockets.
 * @sk: socket being unhashed
 *
 * Calls every registered device's ->unhash() and then the saved base
 * ctx->unhash(). Each device callback runs with device_spinlock dropped
 * and a kref held on the device.
 *
 * NOTE(review): iteration resumes from @dev after the lock was released.
 * If tls_unregister_device() does list_del() on @dev in that window, the
 * next-pointer read by list_for_each_entry may be stale. The kref only keeps
 * @dev's memory alive; confirm whether callers prevent this.
 */
static void tls_hw_unhash(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_device *dev;

	spin_lock_bh(&device_spinlock);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->unhash) {
			kref_get(&dev->kref);
			spin_unlock_bh(&device_spinlock);
			dev->unhash(dev, sk);
			kref_put(&dev->kref, dev->release);
			spin_lock_bh(&device_spinlock);
		}
	}
	spin_unlock_bh(&device_spinlock);
	ctx->unhash(sk);
}
 709
/*
 * tls_hw_hash() - proto->hash for TLS_HW_RECORD sockets.
 * @sk: socket being hashed
 *
 * Calls the saved base ctx->hash(), then every registered device's
 * ->hash(). Each device callback runs with device_spinlock dropped and a
 * kref held. On any failure the socket is unhashed again via
 * tls_hw_unhash().
 *
 * Return: 0 on success. On failure, a non-zero value that is the bitwise OR
 * of the individual results; it is only meaningful as zero vs. non-zero and
 * is not necessarily a valid errno.
 */
static int tls_hw_hash(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_device *dev;
	int err;

	err = ctx->hash(sk);
	spin_lock_bh(&device_spinlock);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->hash) {
			kref_get(&dev->kref);
			spin_unlock_bh(&device_spinlock);
			err |= dev->hash(dev, sk);
			kref_put(&dev->kref, dev->release);
			spin_lock_bh(&device_spinlock);
		}
	}
	spin_unlock_bh(&device_spinlock);

	if (err)
		tls_hw_unhash(sk);
	return err;
}
 733
/*
 * build_protos() - fill one family's proto table from a base TCP proto.
 * @prot: table indexed [tx_conf][rx_conf] (see update_sk_prot())
 * @base: the underlying TCP proto to copy
 *
 * Most entries start as a copy of an entry built earlier in this function
 * and then override a few handlers, so the statement order matters.
 */
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 struct proto *base)
{
	/* No crypto yet: only setsockopt/getsockopt/close are intercepted. */
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	/* SW TX: software sendmsg/sendpage. */
	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage;

	/* SW RX: software recvmsg and stream_memory_read. */
	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].stream_memory_read	= tls_sw_stream_read;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	/* HW TX uses the device sendmsg/sendpage. HW RX reuses the SW RX
	 * entries unchanged.
	 */
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].sendpage		= tls_device_sendpage;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].sendpage		= tls_device_sendpage;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif

	/* HW record mode: only hash/unhash are overridden from the base proto. */
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash;
}
 776
 777static int tls_init(struct sock *sk)
 778{
 779        struct tls_context *ctx;
 780        int rc = 0;
 781
 782        if (tls_hw_prot(sk))
 783                return 0;
 784
 785        /* The TLS ulp is currently supported only for TCP sockets
 786         * in ESTABLISHED state.
 787         * Supporting sockets in LISTEN state will require us
 788         * to modify the accept implementation to clone rather then
 789         * share the ulp context.
 790         */
 791        if (sk->sk_state != TCP_ESTABLISHED)
 792                return -ENOTSUPP;
 793
 794        tls_build_proto(sk);
 795
 796        /* allocate tls context */
 797        write_lock_bh(&sk->sk_callback_lock);
 798        ctx = create_ctx(sk);
 799        if (!ctx) {
 800                rc = -ENOMEM;
 801                goto out;
 802        }
 803
 804        ctx->tx_conf = TLS_BASE;
 805        ctx->rx_conf = TLS_BASE;
 806        ctx->sk_proto = sk->sk_prot;
 807        update_sk_prot(sk, ctx);
 808out:
 809        write_unlock_bh(&sk->sk_callback_lock);
 810        return rc;
 811}
 812
 813static void tls_update(struct sock *sk, struct proto *p)
 814{
 815        struct tls_context *ctx;
 816
 817        ctx = tls_get_ctx(sk);
 818        if (likely(ctx)) {
 819                ctx->sk_proto_close = p->close;
 820                ctx->sk_proto = p;
 821        } else {
 822                sk->sk_prot = p;
 823        }
 824}
 825
/* tls_register_device() - append @device to device_list under device_spinlock. */
void tls_register_device(struct tls_device *device)
{
	spin_lock_bh(&device_spinlock);
	list_add_tail(&device->dev_list, &device_list);
	spin_unlock_bh(&device_spinlock);
}
EXPORT_SYMBOL(tls_register_device);
 833
/* tls_unregister_device() - remove @device from device_list under device_spinlock. */
void tls_unregister_device(struct tls_device *device)
{
	spin_lock_bh(&device_spinlock);
	list_del(&device->dev_list);
	spin_unlock_bh(&device_spinlock);
}
EXPORT_SYMBOL(tls_unregister_device);
 841
/* TCP ULP descriptor for "tls": tls_init attaches TLS to a socket and
 * tls_update handles replacement of the underlying base proto.
 */
static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
};
 848
/*
 * tls_register() - module init.
 *
 * Builds tls_sw_proto_ops (inet_stream_ops with TLS splice_read), which is
 * installed on sockets once RX is configured. Then initializes device
 * offload (if built in) and registers the "tls" ULP.
 * NOTE(review): the return value of tcp_register_ulp() is ignored.
 */
static int __init tls_register(void)
{
	tls_sw_proto_ops = inet_stream_ops;
	tls_sw_proto_ops.splice_read = tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	tls_device_init();
#endif
	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
}
 861
/* tls_unregister() - module exit: unregister the ULP, then clean up device offload. */
static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
#ifdef CONFIG_TLS_DEVICE
	tls_device_cleanup();
#endif
}
 869
/* Module entry and exit points. */
module_init(tls_register);
module_exit(tls_unregister);
 872