linux/net/tls/tls_device.c
/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <crypto/aead.h>
#include <linux/highmem.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/dst.h>
#include <net/inet_connection_sock.h>
#include <net/tcp.h>
#include <net/tls.h>

/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications.
 */
static DECLARE_RWSEM(device_offload_lock);

static void tls_device_gc_task(struct work_struct *work);

static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
static LIST_HEAD(tls_device_gc_list);
static LIST_HEAD(tls_device_list);
static DEFINE_SPINLOCK(tls_device_lock);

static void tls_device_free_ctx(struct tls_context *ctx)
{
        struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx);

        kfree(offload_ctx);
        kfree(ctx);
}

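/* Worker that frees contexts queued by tls_device_queue_ctx_destruction():
 * it releases the device TX state (when the netdev is still attached) and
 * the netdev reference, then frees the context itself.
 */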
static void tls_device_gc_task(struct work_struct *work)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(gc_list);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_splice_init(&tls_device_gc_list, &gc_list);
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
                struct net_device *netdev = ctx->netdev;

                if (netdev) {
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_TX);
                        dev_put(netdev);
                }

                list_del(&ctx->list);
                tls_device_free_ctx(ctx);
        }
}

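/* Hand the context over to the gc list and kick the worker; the actual
 * freeing happens in tls_device_gc_task().
 */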
static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
        unsigned long flags;

        spin_lock_irqsave(&tls_device_lock, flags);
        list_move_tail(&ctx->list, &tls_device_gc_list);

        /* schedule_work inside the spinlock
         * to make sure tls_device_down waits for that work.
         */
        schedule_work(&tls_device_gc_work);

        spin_unlock_irqrestore(&tls_device_lock, flags);
}

/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
        struct dst_entry *dst = sk_dst_get(sk);
        struct net_device *netdev = NULL;

        if (likely(dst)) {
                netdev = dst->dev;
                dev_hold(netdev);
        }

        dst_release(dst);

        return netdev;
}

static void destroy_record(struct tls_record_info *record)
{
        int nr_frags = record->num_frags;
        skb_frag_t *frag;

        while (nr_frags-- > 0) {
                frag = &record->frags[nr_frags];
                __skb_frag_unref(frag);
        }
        kfree(record);
}

static void delete_all_records(struct tls_offload_context *offload_ctx)
{
        struct tls_record_info *info, *temp;

        list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
                list_del(&info->list);
                destroy_record(info);
        }

        offload_ctx->retransmit_hint = NULL;
}

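/* Invoked from the TCP clean_acked_data hook: free every record whose
 * end_seq has been fully acknowledged, starting from the retransmit hint,
 * and account for the freed records in unacked_record_sn.
 */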
static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_record_info *info, *temp;
        struct tls_offload_context *ctx;
        u64 deleted_records = 0;
        unsigned long flags;

        if (!tls_ctx)
                return;

        ctx = tls_offload_ctx(tls_ctx);

        spin_lock_irqsave(&ctx->lock, flags);
        info = ctx->retransmit_hint;
        if (info && !before(acked_seq, info->end_seq)) {
                ctx->retransmit_hint = NULL;
                list_del(&info->list);
                destroy_record(info);
                deleted_records++;
        }

        list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
                if (before(acked_seq, info->end_seq))
                        break;
                list_del(&info->list);

                destroy_record(info);
                deleted_records++;
        }

        ctx->unacked_record_sn += deleted_records;
        spin_unlock_irqrestore(&ctx->lock, flags);
}

/* At this point, there should be no references on this
 * socket and no in-flight SKBs associated with this
 * socket, so it is safe to free all the resources.
 */
void tls_device_sk_destruct(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);

        if (ctx->open_record)
                destroy_record(ctx->open_record);

        delete_all_records(ctx);
        crypto_free_aead(ctx->aead_send);
        ctx->sk_destruct(sk);
        clean_acked_data_disable(inet_csk(sk));

        if (refcount_dec_and_test(&tls_ctx->refcount))
                tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL(tls_device_sk_destruct);

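/* Append @size bytes from @pfrag to the open record: extend the last
 * fragment when the new data is contiguous with it, otherwise start a new
 * fragment and take a reference on the page.
 */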
static void tls_append_frag(struct tls_record_info *record,
                            struct page_frag *pfrag,
                            int size)
{
        skb_frag_t *frag;

        frag = &record->frags[record->num_frags - 1];
        if (frag->page.p == pfrag->page &&
            frag->page_offset + frag->size == pfrag->offset) {
                frag->size += size;
        } else {
                ++frag;
                frag->page.p = pfrag->page;
                frag->page_offset = pfrag->offset;
                frag->size = size;
                ++record->num_frags;
                get_page(pfrag->page);
        }

        pfrag->offset += size;
        record->len += size;
}

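/* Close the open record: write the TLS prepend into the first fragment,
 * append a placeholder fragment for the authentication tag (the device
 * fills the tag in), queue the record for retransmission lookups, advance
 * the record sequence number and hand the plaintext to TCP via tls_push_sg().
 */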
static int tls_push_record(struct sock *sk,
                           struct tls_context *ctx,
                           struct tls_offload_context *offload_ctx,
                           struct tls_record_info *record,
                           struct page_frag *pfrag,
                           int flags,
                           unsigned char record_type)
{
        struct tcp_sock *tp = tcp_sk(sk);
        struct page_frag dummy_tag_frag;
        skb_frag_t *frag;
        int i;

        /* fill prepend */
        frag = &record->frags[0];
        tls_fill_prepend(ctx,
                         skb_frag_address(frag),
                         record->len - ctx->tx.prepend_size,
                         record_type);

        /* HW doesn't care about the data in the tag, because it fills it. */
        dummy_tag_frag.page = skb_frag_page(frag);
        dummy_tag_frag.offset = 0;

        tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
        record->end_seq = tp->write_seq + record->len;
        spin_lock_irq(&offload_ctx->lock);
        list_add_tail(&record->list, &offload_ctx->records_list);
        spin_unlock_irq(&offload_ctx->lock);
        offload_ctx->open_record = NULL;
        set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
        tls_advance_record_sn(sk, &ctx->tx);

        for (i = 0; i < record->num_frags; i++) {
                frag = &record->frags[i];
                sg_unmark_end(&offload_ctx->sg_tx_data[i]);
                sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
                            frag->size, frag->page_offset);
                sk_mem_charge(sk, frag->size);
                get_page(skb_frag_page(frag));
        }
        sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);

        /* all ready, send */
        return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}

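/* Start a new open record whose first fragment reserves room for the TLS
 * prepend (header and nonce); payload is appended later by tls_append_frag().
 */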
static int tls_create_new_record(struct tls_offload_context *offload_ctx,
                                 struct page_frag *pfrag,
                                 size_t prepend_size)
{
        struct tls_record_info *record;
        skb_frag_t *frag;

        record = kmalloc(sizeof(*record), GFP_KERNEL);
        if (!record)
                return -ENOMEM;

        frag = &record->frags[0];
        __skb_frag_set_page(frag, pfrag->page);
        frag->page_offset = pfrag->offset;
        skb_frag_size_set(frag, prepend_size);

        get_page(pfrag->page);
        pfrag->offset += prepend_size;

        record->num_frags = 1;
        record->len = prepend_size;
        offload_ctx->open_record = record;
        return 0;
}

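/* Make sure there is an open record and that the socket's page_frag has
 * space for more payload; refill the frag (and enter memory pressure on
 * failure) as needed.
 */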
static int tls_do_allocation(struct sock *sk,
                             struct tls_offload_context *offload_ctx,
                             struct page_frag *pfrag,
                             size_t prepend_size)
{
        int ret;

        if (!offload_ctx->open_record) {
                if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
                                                   sk->sk_allocation))) {
                        sk->sk_prot->enter_memory_pressure(sk);
                        sk_stream_moderate_sndbuf(sk);
                        return -ENOMEM;
                }

                ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
                if (ret)
                        return ret;

                if (pfrag->size > pfrag->offset)
                        return 0;
        }

        if (!sk_page_frag_refill(sk, pfrag))
                return -ENOMEM;

        return 0;
}

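/* Main TX path: copy data from @msg_iter into the open record through the
 * socket page_frag, closing and pushing a record whenever it is full, the
 * fragment array is nearly exhausted, or the caller is done and did not set
 * MSG_MORE. Returns the number of bytes queued or a negative error.
 */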
static int tls_push_data(struct sock *sk,
                         struct iov_iter *msg_iter,
                         size_t size, int flags,
                         unsigned char record_type)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx);
        int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
        int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
        struct tls_record_info *record = ctx->open_record;
        struct page_frag *pfrag;
        size_t orig_size = size;
        u32 max_open_record_len;
        int copy, rc = 0;
        bool done = false;
        long timeo;

        if (flags &
            ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
                return -ENOTSUPP;

        if (sk->sk_err)
                return -sk->sk_err;

        timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
        if (rc < 0)
                return rc;

        pfrag = sk_page_frag(sk);

        /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
         * we need to leave room for an authentication tag.
         */
        max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
                              tls_ctx->tx.prepend_size;
        do {
                rc = tls_do_allocation(sk, ctx, pfrag,
                                       tls_ctx->tx.prepend_size);
                if (rc) {
                        rc = sk_stream_wait_memory(sk, &timeo);
                        if (!rc)
                                continue;

                        record = ctx->open_record;
                        if (!record)
                                break;
handle_error:
                        if (record_type != TLS_RECORD_TYPE_DATA) {
                                /* avoid sending partial
                                 * record with type !=
                                 * application_data
                                 */
                                size = orig_size;
                                destroy_record(record);
                                ctx->open_record = NULL;
                        } else if (record->len > tls_ctx->tx.prepend_size) {
                                goto last_record;
                        }

                        break;
                }

                record = ctx->open_record;
                copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
                copy = min_t(size_t, copy, (max_open_record_len - record->len));

                if (copy_from_iter_nocache(page_address(pfrag->page) +
                                               pfrag->offset,
                                           copy, msg_iter) != copy) {
                        rc = -EFAULT;
                        goto handle_error;
                }
                tls_append_frag(record, pfrag, copy);

                size -= copy;
                if (!size) {
last_record:
                        tls_push_record_flags = flags;
                        if (more) {
                                tls_ctx->pending_open_record_frags =
                                                record->num_frags;
                                break;
                        }

                        done = true;
                }

                if (done || record->len >= max_open_record_len ||
                    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
                        rc = tls_push_record(sk,
                                             tls_ctx,
                                             ctx,
                                             record,
                                             pfrag,
                                             tls_push_record_flags,
                                             record_type);
                        if (rc < 0)
                                break;
                }
        } while (!done);

        if (orig_size - size > 0)
                rc = orig_size - size;

        return rc;
}

int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        int rc;

        lock_sock(sk);

        if (unlikely(msg->msg_controllen)) {
                rc = tls_proccess_cmsg(sk, msg, &record_type);
                if (rc)
                        goto out;
        }

        rc = tls_push_data(sk, &msg->msg_iter, size,
                           msg->msg_flags, record_type);

out:
        release_sock(sk);
        return rc;
}

int tls_device_sendpage(struct sock *sk, struct page *page,
                        int offset, size_t size, int flags)
{
        struct iov_iter msg_iter;
        char *kaddr = kmap(page);
        struct kvec iov;
        int rc;

        if (flags & MSG_SENDPAGE_NOTLAST)
                flags |= MSG_MORE;

        lock_sock(sk);

        if (flags & MSG_OOB) {
                rc = -ENOTSUPP;
                goto out;
        }

        iov.iov_base = kaddr + offset;
        iov.iov_len = size;
        iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
        rc = tls_push_data(sk, &msg_iter, size,
                           flags, TLS_RECORD_TYPE_DATA);
        kunmap(page);

out:
        release_sock(sk);
        return rc;
}

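/* Look up the record that contains TCP sequence number @seq (needed when
 * already-transmitted data has to be re-encrypted); the record's TLS
 * sequence number is returned through @p_record_sn and the retransmit hint
 * is updated for the next lookup. Exported for use by offloading drivers.
 */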
struct tls_record_info *tls_get_record(struct tls_offload_context *context,
                                       u32 seq, u64 *p_record_sn)
{
        u64 record_sn = context->hint_record_sn;
        struct tls_record_info *info;

        info = context->retransmit_hint;
        if (!info ||
            before(seq, info->end_seq - info->len)) {
                /* if retransmit_hint is irrelevant start
                 * from the beginning of the list
                 */
                info = list_first_entry(&context->records_list,
                                        struct tls_record_info, list);
                record_sn = context->unacked_record_sn;
        }

        list_for_each_entry_from(info, &context->records_list, list) {
                if (before(seq, info->end_seq)) {
                        if (!context->retransmit_hint ||
                            after(info->end_seq,
                                  context->retransmit_hint->end_seq)) {
                                context->hint_record_sn = record_sn;
                                context->retransmit_hint = info;
                        }
                        *p_record_sn = record_sn;
                        return info;
                }
                record_sn++;
        }

        return NULL;
}
EXPORT_SYMBOL(tls_get_record);

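/* Flush the currently open record by pushing zero additional bytes;
 * installed as ctx->push_pending_record.
 */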
static int tls_device_push_pending_record(struct sock *sk, int flags)
{
        struct iov_iter msg_iter;

        iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
        return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
}

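/* Set up TX offload for @sk: validate the cipher (only AES-GCM-128 is
 * supported), build the offload context and a zero-length start marker
 * record aligned to the current write_seq, install the clean_acked_data
 * hook and sk_destruct, and ask the netdev to install the key via
 * tls_dev_add() under device_offload_lock.
 */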
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
        struct tls_record_info *start_marker_record;
        struct tls_offload_context *offload_ctx;
        struct tls_crypto_info *crypto_info;
        struct net_device *netdev;
        char *iv, *rec_seq;
        struct sk_buff *skb;
        int rc = -EINVAL;
        __be64 rcd_sn;

        if (!ctx)
                goto out;

        if (ctx->priv_ctx_tx) {
                rc = -EEXIST;
                goto out;
        }

        start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
        if (!start_marker_record) {
                rc = -ENOMEM;
                goto out;
        }

        offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
        if (!offload_ctx) {
                rc = -ENOMEM;
                goto free_marker_record;
        }

        crypto_info = &ctx->crypto_send;
        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128:
                nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
                iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
                iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
                rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
                rec_seq =
                 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
                break;
        default:
                rc = -EINVAL;
                goto free_offload_ctx;
        }

        ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
        ctx->tx.tag_size = tag_size;
        ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
        ctx->tx.iv_size = iv_size;
        ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                             GFP_KERNEL);
        if (!ctx->tx.iv) {
                rc = -ENOMEM;
                goto free_offload_ctx;
        }

        memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);

        ctx->tx.rec_seq_size = rec_seq_size;
        ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
        if (!ctx->tx.rec_seq) {
                rc = -ENOMEM;
                goto free_iv;
        }
        memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);

        rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
        if (rc)
                goto free_rec_seq;

        /* start at rec_seq - 1 to account for the start marker record */
        memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
        offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;

        start_marker_record->end_seq = tcp_sk(sk)->write_seq;
        start_marker_record->len = 0;
        start_marker_record->num_frags = 0;

        INIT_LIST_HEAD(&offload_ctx->records_list);
        list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
        spin_lock_init(&offload_ctx->lock);
        sg_init_table(offload_ctx->sg_tx_data,
                      ARRAY_SIZE(offload_ctx->sg_tx_data));

        clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
        ctx->push_pending_record = tls_device_push_pending_record;
        offload_ctx->sk_destruct = sk->sk_destruct;

        /* TLS offload is greatly simplified if we don't send
         * SKBs where only part of the payload needs to be encrypted.
         * So mark the last skb in the write queue as end of record.
         */
        skb = tcp_write_queue_tail(sk);
        if (skb)
                TCP_SKB_CB(skb)->eor = 1;

        refcount_set(&ctx->refcount, 1);

        /* We support starting offload on multiple sockets
         * concurrently, so we only need a read lock here.
         * This lock must precede get_netdev_for_sock to prevent races between
         * NETDEV_DOWN and setsockopt.
         */
        down_read(&device_offload_lock);
        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                rc = -EINVAL;
                goto release_lock;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
                rc = -ENOTSUPP;
                goto release_netdev;
        }

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         */
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_netdev;
        }

        ctx->priv_ctx_tx = offload_ctx;
        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
                                             &ctx->crypto_send,
                                             tcp_sk(sk)->write_seq);
        if (rc)
                goto release_netdev;

        ctx->netdev = netdev;

        spin_lock_irq(&tls_device_lock);
        list_add_tail(&ctx->list, &tls_device_list);
        spin_unlock_irq(&tls_device_lock);

        sk->sk_validate_xmit_skb = tls_validate_xmit_skb;
        /* following this assignment tls_is_sk_tx_device_offloaded
         * will return true and the context might be accessed
         * by the netdev's xmit function.
         */
        smp_store_release(&sk->sk_destruct,
                          &tls_device_sk_destruct);
        up_read(&device_offload_lock);
        goto out;

release_netdev:
        dev_put(netdev);
release_lock:
        up_read(&device_offload_lock);
        clean_acked_data_disable(inet_csk(sk));
        crypto_free_aead(offload_ctx->aead_send);
free_rec_seq:
        kfree(ctx->tx.rec_seq);
free_iv:
        kfree(ctx->tx.iv);
free_offload_ctx:
        kfree(offload_ctx);
        ctx->priv_ctx_tx = NULL;
free_marker_record:
        kfree(start_marker_record);
out:
        return rc;
}

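/* NETDEV_DOWN handler: with the write lock held (blocking new offload
 * attempts), detach every context bound to this netdev, tell the device to
 * remove its TX state, drop the netdev reference and free contexts whose
 * last reference we held; then wait for any pending gc work.
 */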
static int tls_device_down(struct net_device *netdev)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(list);

        /* Request a write lock to block new offload attempts */
        down_write(&device_offload_lock);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
                if (ctx->netdev != netdev ||
                    !refcount_inc_not_zero(&ctx->refcount))
                        continue;

                list_move(&ctx->list, &list);
        }
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &list, list) {
                netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                TLS_OFFLOAD_CTX_DIR_TX);
                ctx->netdev = NULL;
                dev_put(netdev);
                list_del_init(&ctx->list);

                if (refcount_dec_and_test(&ctx->refcount))
                        tls_device_free_ctx(ctx);
        }

        up_write(&device_offload_lock);

        flush_work(&tls_device_gc_work);

        return NOTIFY_DONE;
}

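/* Netdevice notifier: for devices advertising NETIF_F_HW_TLS_TX, insist on
 * valid tlsdev_ops at REGISTER/FEAT_CHANGE time (veto with NOTIFY_BAD
 * otherwise) and tear down offload state on NETDEV_DOWN.
 */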
static int tls_dev_event(struct notifier_block *this, unsigned long event,
                         void *ptr)
{
        struct net_device *dev = netdev_notifier_info_to_dev(ptr);

        if (!(dev->features & NETIF_F_HW_TLS_TX))
                return NOTIFY_DONE;

        switch (event) {
        case NETDEV_REGISTER:
        case NETDEV_FEAT_CHANGE:
                if  (dev->tlsdev_ops &&
                     dev->tlsdev_ops->tls_dev_add &&
                     dev->tlsdev_ops->tls_dev_del)
                        return NOTIFY_DONE;
                else
                        return NOTIFY_BAD;
        case NETDEV_DOWN:
                return tls_device_down(dev);
        }
        return NOTIFY_DONE;
}

static struct notifier_block tls_dev_notifier = {
        .notifier_call  = tls_dev_event,
};

void __init tls_device_init(void)
{
        register_netdevice_notifier(&tls_dev_notifier);
}

void __exit tls_device_cleanup(void)
{
        unregister_netdevice_notifier(&tls_dev_notifier);
        flush_work(&tls_device_gc_work);
}