linux/net/tls/tls_sw.c
<<
>>
Prefs
   1/*
   2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
   3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
   4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
   5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
   6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
   7 *
   8 * This software is available to you under a choice of one of two
   9 * licenses.  You may choose to be licensed under the terms of the GNU
  10 * General Public License (GPL) Version 2, available from the file
  11 * COPYING in the main directory of this source tree, or the
  12 * OpenIB.org BSD license below:
  13 *
  14 *     Redistribution and use in source and binary forms, with or
  15 *     without modification, are permitted provided that the following
  16 *     conditions are met:
  17 *
  18 *      - Redistributions of source code must retain the above
  19 *        copyright notice, this list of conditions and the following
  20 *        disclaimer.
  21 *
  22 *      - Redistributions in binary form must reproduce the above
  23 *        copyright notice, this list of conditions and the following
  24 *        disclaimer in the documentation and/or other materials
  25 *        provided with the distribution.
  26 *
  27 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  28 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  29 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  30 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
  31 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
  32 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  33 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  34 * SOFTWARE.
  35 */
  36
  37#include <linux/sched/signal.h>
  38#include <linux/module.h>
  39#include <crypto/aead.h>
  40
  41#include <net/strparser.h>
  42#include <net/tls.h>
  43
  44#define MAX_IV_SIZE     TLS_CIPHER_AES_GCM_128_IV_SIZE
  45
  46static int tls_do_decryption(struct sock *sk,
  47                             struct scatterlist *sgin,
  48                             struct scatterlist *sgout,
  49                             char *iv_recv,
  50                             size_t data_len,
  51                             struct sk_buff *skb,
  52                             gfp_t flags)
  53{
  54        struct tls_context *tls_ctx = tls_get_ctx(sk);
  55        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
  56        struct strp_msg *rxm = strp_msg(skb);
  57        struct aead_request *aead_req;
  58
  59        int ret;
  60        unsigned int req_size = sizeof(struct aead_request) +
  61                crypto_aead_reqsize(ctx->aead_recv);
  62
  63        aead_req = kzalloc(req_size, flags);
  64        if (!aead_req)
  65                return -ENOMEM;
  66
  67        aead_request_set_tfm(aead_req, ctx->aead_recv);
  68        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
  69        aead_request_set_crypt(aead_req, sgin, sgout,
  70                               data_len + tls_ctx->rx.tag_size,
  71                               (u8 *)iv_recv);
  72        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
  73                                  crypto_req_done, &ctx->async_wait);
  74
  75        ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
  76
  77        if (ret < 0)
  78                goto out;
  79
  80        rxm->offset += tls_ctx->rx.prepend_size;
  81        rxm->full_len -= tls_ctx->rx.overhead_size;
  82        tls_advance_record_sn(sk, &tls_ctx->rx);
  83
  84        ctx->decrypted = true;
  85
  86        ctx->saved_data_ready(sk);
  87
  88out:
  89        kfree(aead_req);
  90        return ret;
  91}
  92
  93static void trim_sg(struct sock *sk, struct scatterlist *sg,
  94                    int *sg_num_elem, unsigned int *sg_size, int target_size)
  95{
  96        int i = *sg_num_elem - 1;
  97        int trim = *sg_size - target_size;
  98
  99        if (trim <= 0) {
 100                WARN_ON(trim < 0);
 101                return;
 102        }
 103
 104        *sg_size = target_size;
 105        while (trim >= sg[i].length) {
 106                trim -= sg[i].length;
 107                sk_mem_uncharge(sk, sg[i].length);
 108                put_page(sg_page(&sg[i]));
 109                i--;
 110
 111                if (i < 0)
 112                        goto out;
 113        }
 114
 115        sg[i].length -= trim;
 116        sk_mem_uncharge(sk, trim);
 117
 118out:
 119        *sg_num_elem = i + 1;
 120}
 121
 122static void trim_both_sgl(struct sock *sk, int target_size)
 123{
 124        struct tls_context *tls_ctx = tls_get_ctx(sk);
 125        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 126
 127        trim_sg(sk, ctx->sg_plaintext_data,
 128                &ctx->sg_plaintext_num_elem,
 129                &ctx->sg_plaintext_size,
 130                target_size);
 131
 132        if (target_size > 0)
 133                target_size += tls_ctx->tx.overhead_size;
 134
 135        trim_sg(sk, ctx->sg_encrypted_data,
 136                &ctx->sg_encrypted_num_elem,
 137                &ctx->sg_encrypted_size,
 138                target_size);
 139}
 140
 141static int alloc_encrypted_sg(struct sock *sk, int len)
 142{
 143        struct tls_context *tls_ctx = tls_get_ctx(sk);
 144        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 145        int rc = 0;
 146
 147        rc = sk_alloc_sg(sk, len,
 148                         ctx->sg_encrypted_data, 0,
 149                         &ctx->sg_encrypted_num_elem,
 150                         &ctx->sg_encrypted_size, 0);
 151
 152        return rc;
 153}
 154
 155static int alloc_plaintext_sg(struct sock *sk, int len)
 156{
 157        struct tls_context *tls_ctx = tls_get_ctx(sk);
 158        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 159        int rc = 0;
 160
 161        rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
 162                         &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
 163                         tls_ctx->pending_open_record_frags);
 164
 165        return rc;
 166}
 167
 168static void free_sg(struct sock *sk, struct scatterlist *sg,
 169                    int *sg_num_elem, unsigned int *sg_size)
 170{
 171        int i, n = *sg_num_elem;
 172
 173        for (i = 0; i < n; ++i) {
 174                sk_mem_uncharge(sk, sg[i].length);
 175                put_page(sg_page(&sg[i]));
 176        }
 177        *sg_num_elem = 0;
 178        *sg_size = 0;
 179}
 180
 181static void tls_free_both_sg(struct sock *sk)
 182{
 183        struct tls_context *tls_ctx = tls_get_ctx(sk);
 184        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 185
 186        free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
 187                &ctx->sg_encrypted_size);
 188
 189        free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
 190                &ctx->sg_plaintext_size);
 191}
 192
 193static int tls_do_encryption(struct tls_context *tls_ctx,
 194                             struct tls_sw_context *ctx, size_t data_len,
 195                             gfp_t flags)
 196{
 197        unsigned int req_size = sizeof(struct aead_request) +
 198                crypto_aead_reqsize(ctx->aead_send);
 199        struct aead_request *aead_req;
 200        int rc;
 201
 202        aead_req = kzalloc(req_size, flags);
 203        if (!aead_req)
 204                return -ENOMEM;
 205
 206        ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
 207        ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
 208
 209        aead_request_set_tfm(aead_req, ctx->aead_send);
 210        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
 211        aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
 212                               data_len, tls_ctx->tx.iv);
 213
 214        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
 215                                  crypto_req_done, &ctx->async_wait);
 216
 217        rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
 218
 219        ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
 220        ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
 221
 222        kfree(aead_req);
 223        return rc;
 224}
 225
 226static int tls_push_record(struct sock *sk, int flags,
 227                           unsigned char record_type)
 228{
 229        struct tls_context *tls_ctx = tls_get_ctx(sk);
 230        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 231        int rc;
 232
 233        sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
 234        sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
 235
 236        tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
 237                     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
 238                     record_type);
 239
 240        tls_fill_prepend(tls_ctx,
 241                         page_address(sg_page(&ctx->sg_encrypted_data[0])) +
 242                         ctx->sg_encrypted_data[0].offset,
 243                         ctx->sg_plaintext_size, record_type);
 244
 245        tls_ctx->pending_open_record_frags = 0;
 246        set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
 247
 248        rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size,
 249                               sk->sk_allocation);
 250        if (rc < 0) {
 251                /* If we are called from write_space and
 252                 * we fail, we need to set this SOCK_NOSPACE
 253                 * to trigger another write_space in the future.
 254                 */
 255                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 256                return rc;
 257        }
 258
 259        free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
 260                &ctx->sg_plaintext_size);
 261
 262        ctx->sg_encrypted_num_elem = 0;
 263        ctx->sg_encrypted_size = 0;
 264
 265        /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
 266        rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
 267        if (rc < 0 && rc != -EAGAIN)
 268                tls_err_abort(sk, EBADMSG);
 269
 270        tls_advance_record_sn(sk, &tls_ctx->tx);
 271        return rc;
 272}
 273
 274static int tls_sw_push_pending_record(struct sock *sk, int flags)
 275{
 276        return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
 277}
 278
 279static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
 280                              int length, int *pages_used,
 281                              unsigned int *size_used,
 282                              struct scatterlist *to, int to_max_pages,
 283                              bool charge)
 284{
 285        struct page *pages[MAX_SKB_FRAGS];
 286
 287        size_t offset;
 288        ssize_t copied, use;
 289        int i = 0;
 290        unsigned int size = *size_used;
 291        int num_elem = *pages_used;
 292        int rc = 0;
 293        int maxpages;
 294
 295        while (length > 0) {
 296                i = 0;
 297                maxpages = to_max_pages - num_elem;
 298                if (maxpages == 0) {
 299                        rc = -EFAULT;
 300                        goto out;
 301                }
 302                copied = iov_iter_get_pages(from, pages,
 303                                            length,
 304                                            maxpages, &offset);
 305                if (copied <= 0) {
 306                        rc = -EFAULT;
 307                        goto out;
 308                }
 309
 310                iov_iter_advance(from, copied);
 311
 312                length -= copied;
 313                size += copied;
 314                while (copied) {
 315                        use = min_t(int, copied, PAGE_SIZE - offset);
 316
 317                        sg_set_page(&to[num_elem],
 318                                    pages[i], use, offset);
 319                        sg_unmark_end(&to[num_elem]);
 320                        if (charge)
 321                                sk_mem_charge(sk, use);
 322
 323                        offset = 0;
 324                        copied -= use;
 325
 326                        ++i;
 327                        ++num_elem;
 328                }
 329        }
 330
 331out:
 332        *size_used = size;
 333        *pages_used = num_elem;
 334
 335        return rc;
 336}
 337
 338static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
 339                             int bytes)
 340{
 341        struct tls_context *tls_ctx = tls_get_ctx(sk);
 342        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 343        struct scatterlist *sg = ctx->sg_plaintext_data;
 344        int copy, i, rc = 0;
 345
 346        for (i = tls_ctx->pending_open_record_frags;
 347             i < ctx->sg_plaintext_num_elem; ++i) {
 348                copy = sg[i].length;
 349                if (copy_from_iter(
 350                                page_address(sg_page(&sg[i])) + sg[i].offset,
 351                                copy, from) != copy) {
 352                        rc = -EFAULT;
 353                        goto out;
 354                }
 355                bytes -= copy;
 356
 357                ++tls_ctx->pending_open_record_frags;
 358
 359                if (!bytes)
 360                        break;
 361        }
 362
 363out:
 364        return rc;
 365}
 366
 367int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 368{
 369        struct tls_context *tls_ctx = tls_get_ctx(sk);
 370        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 371        int ret = 0;
 372        int required_size;
 373        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 374        bool eor = !(msg->msg_flags & MSG_MORE);
 375        size_t try_to_copy, copied = 0;
 376        unsigned char record_type = TLS_RECORD_TYPE_DATA;
 377        int record_room;
 378        bool full_record;
 379        int orig_size;
 380
 381        if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
 382                return -ENOTSUPP;
 383
 384        lock_sock(sk);
 385
 386        if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
 387                goto send_end;
 388
 389        if (unlikely(msg->msg_controllen)) {
 390                ret = tls_proccess_cmsg(sk, msg, &record_type);
 391                if (ret)
 392                        goto send_end;
 393        }
 394
 395        while (msg_data_left(msg)) {
 396                if (sk->sk_err) {
 397                        ret = -sk->sk_err;
 398                        goto send_end;
 399                }
 400
 401                orig_size = ctx->sg_plaintext_size;
 402                full_record = false;
 403                try_to_copy = msg_data_left(msg);
 404                record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
 405                if (try_to_copy >= record_room) {
 406                        try_to_copy = record_room;
 407                        full_record = true;
 408                }
 409
 410                required_size = ctx->sg_plaintext_size + try_to_copy +
 411                                tls_ctx->tx.overhead_size;
 412
 413                if (!sk_stream_memory_free(sk))
 414                        goto wait_for_sndbuf;
 415alloc_encrypted:
 416                ret = alloc_encrypted_sg(sk, required_size);
 417                if (ret) {
 418                        if (ret != -ENOSPC)
 419                                goto wait_for_memory;
 420
 421                        /* Adjust try_to_copy according to the amount that was
 422                         * actually allocated. The difference is due
 423                         * to max sg elements limit
 424                         */
 425                        try_to_copy -= required_size - ctx->sg_encrypted_size;
 426                        full_record = true;
 427                }
 428
 429                if (full_record || eor) {
 430                        ret = zerocopy_from_iter(sk, &msg->msg_iter,
 431                                try_to_copy, &ctx->sg_plaintext_num_elem,
 432                                &ctx->sg_plaintext_size,
 433                                ctx->sg_plaintext_data,
 434                                ARRAY_SIZE(ctx->sg_plaintext_data),
 435                                true);
 436                        if (ret)
 437                                goto fallback_to_reg_send;
 438
 439                        copied += try_to_copy;
 440                        ret = tls_push_record(sk, msg->msg_flags, record_type);
 441                        if (!ret)
 442                                continue;
 443                        if (ret == -EAGAIN)
 444                                goto send_end;
 445
 446                        copied -= try_to_copy;
 447fallback_to_reg_send:
 448                        iov_iter_revert(&msg->msg_iter,
 449                                        ctx->sg_plaintext_size - orig_size);
 450                        trim_sg(sk, ctx->sg_plaintext_data,
 451                                &ctx->sg_plaintext_num_elem,
 452                                &ctx->sg_plaintext_size,
 453                                orig_size);
 454                }
 455
 456                required_size = ctx->sg_plaintext_size + try_to_copy;
 457alloc_plaintext:
 458                ret = alloc_plaintext_sg(sk, required_size);
 459                if (ret) {
 460                        if (ret != -ENOSPC)
 461                                goto wait_for_memory;
 462
 463                        /* Adjust try_to_copy according to the amount that was
 464                         * actually allocated. The difference is due
 465                         * to max sg elements limit
 466                         */
 467                        try_to_copy -= required_size - ctx->sg_plaintext_size;
 468                        full_record = true;
 469
 470                        trim_sg(sk, ctx->sg_encrypted_data,
 471                                &ctx->sg_encrypted_num_elem,
 472                                &ctx->sg_encrypted_size,
 473                                ctx->sg_plaintext_size +
 474                                tls_ctx->tx.overhead_size);
 475                }
 476
 477                ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
 478                if (ret)
 479                        goto trim_sgl;
 480
 481                copied += try_to_copy;
 482                if (full_record || eor) {
 483push_record:
 484                        ret = tls_push_record(sk, msg->msg_flags, record_type);
 485                        if (ret) {
 486                                if (ret == -ENOMEM)
 487                                        goto wait_for_memory;
 488
 489                                goto send_end;
 490                        }
 491                }
 492
 493                continue;
 494
 495wait_for_sndbuf:
 496                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 497wait_for_memory:
 498                ret = sk_stream_wait_memory(sk, &timeo);
 499                if (ret) {
 500trim_sgl:
 501                        trim_both_sgl(sk, orig_size);
 502                        goto send_end;
 503                }
 504
 505                if (tls_is_pending_closed_record(tls_ctx))
 506                        goto push_record;
 507
 508                if (ctx->sg_encrypted_size < required_size)
 509                        goto alloc_encrypted;
 510
 511                goto alloc_plaintext;
 512        }
 513
 514send_end:
 515        ret = sk_stream_error(sk, msg->msg_flags, ret);
 516
 517        release_sock(sk);
 518        return copied ? copied : ret;
 519}
 520
 521int tls_sw_sendpage(struct sock *sk, struct page *page,
 522                    int offset, size_t size, int flags)
 523{
 524        struct tls_context *tls_ctx = tls_get_ctx(sk);
 525        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 526        int ret = 0;
 527        long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
 528        bool eor;
 529        size_t orig_size = size;
 530        unsigned char record_type = TLS_RECORD_TYPE_DATA;
 531        struct scatterlist *sg;
 532        bool full_record;
 533        int record_room;
 534
 535        if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
 536                      MSG_SENDPAGE_NOTLAST))
 537                return -ENOTSUPP;
 538
 539        /* No MSG_EOR from splice, only look at MSG_MORE */
 540        eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
 541
 542        lock_sock(sk);
 543
 544        sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
 545
 546        if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
 547                goto sendpage_end;
 548
 549        /* Call the sk_stream functions to manage the sndbuf mem. */
 550        while (size > 0) {
 551                size_t copy, required_size;
 552
 553                if (sk->sk_err) {
 554                        ret = -sk->sk_err;
 555                        goto sendpage_end;
 556                }
 557
 558                full_record = false;
 559                record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
 560                copy = size;
 561                if (copy >= record_room) {
 562                        copy = record_room;
 563                        full_record = true;
 564                }
 565                required_size = ctx->sg_plaintext_size + copy +
 566                              tls_ctx->tx.overhead_size;
 567
 568                if (!sk_stream_memory_free(sk))
 569                        goto wait_for_sndbuf;
 570alloc_payload:
 571                ret = alloc_encrypted_sg(sk, required_size);
 572                if (ret) {
 573                        if (ret != -ENOSPC)
 574                                goto wait_for_memory;
 575
 576                        /* Adjust copy according to the amount that was
 577                         * actually allocated. The difference is due
 578                         * to max sg elements limit
 579                         */
 580                        copy -= required_size - ctx->sg_plaintext_size;
 581                        full_record = true;
 582                }
 583
 584                get_page(page);
 585                sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
 586                sg_set_page(sg, page, copy, offset);
 587                sg_unmark_end(sg);
 588
 589                ctx->sg_plaintext_num_elem++;
 590
 591                sk_mem_charge(sk, copy);
 592                offset += copy;
 593                size -= copy;
 594                ctx->sg_plaintext_size += copy;
 595                tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
 596
 597                if (full_record || eor ||
 598                    ctx->sg_plaintext_num_elem ==
 599                    ARRAY_SIZE(ctx->sg_plaintext_data)) {
 600push_record:
 601                        ret = tls_push_record(sk, flags, record_type);
 602                        if (ret) {
 603                                if (ret == -ENOMEM)
 604                                        goto wait_for_memory;
 605
 606                                goto sendpage_end;
 607                        }
 608                }
 609                continue;
 610wait_for_sndbuf:
 611                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
 612wait_for_memory:
 613                ret = sk_stream_wait_memory(sk, &timeo);
 614                if (ret) {
 615                        trim_both_sgl(sk, ctx->sg_plaintext_size);
 616                        goto sendpage_end;
 617                }
 618
 619                if (tls_is_pending_closed_record(tls_ctx))
 620                        goto push_record;
 621
 622                goto alloc_payload;
 623        }
 624
 625sendpage_end:
 626        if (orig_size > size)
 627                ret = orig_size - size;
 628        else
 629                ret = sk_stream_error(sk, flags, ret);
 630
 631        release_sock(sk);
 632        return ret;
 633}
 634
 635static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
 636                                     long timeo, int *err)
 637{
 638        struct tls_context *tls_ctx = tls_get_ctx(sk);
 639        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 640        struct sk_buff *skb;
 641        DEFINE_WAIT_FUNC(wait, woken_wake_function);
 642
 643        while (!(skb = ctx->recv_pkt)) {
 644                if (sk->sk_err) {
 645                        *err = sock_error(sk);
 646                        return NULL;
 647                }
 648
 649                if (sock_flag(sk, SOCK_DONE))
 650                        return NULL;
 651
 652                if ((flags & MSG_DONTWAIT) || !timeo) {
 653                        *err = -EAGAIN;
 654                        return NULL;
 655                }
 656
 657                add_wait_queue(sk_sleep(sk), &wait);
 658                sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 659                sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
 660                sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
 661                remove_wait_queue(sk_sleep(sk), &wait);
 662
 663                /* Handle signals */
 664                if (signal_pending(current)) {
 665                        *err = sock_intr_errno(timeo);
 666                        return NULL;
 667                }
 668        }
 669
 670        return skb;
 671}
 672
 673static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
 674                       struct scatterlist *sgout)
 675{
 676        struct tls_context *tls_ctx = tls_get_ctx(sk);
 677        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 678        char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
 679        struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
 680        struct scatterlist *sgin = &sgin_arr[0];
 681        struct strp_msg *rxm = strp_msg(skb);
 682        int ret, nsg = ARRAY_SIZE(sgin_arr);
 683        struct sk_buff *unused;
 684
 685        ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
 686                            iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 687                            tls_ctx->rx.iv_size);
 688        if (ret < 0)
 689                return ret;
 690
 691        memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 692        if (!sgout) {
 693                nsg = skb_cow_data(skb, 0, &unused) + 1;
 694                sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
 695                if (!sgout)
 696                        sgout = sgin;
 697        }
 698
 699        sg_init_table(sgin, nsg);
 700        sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);
 701
 702        nsg = skb_to_sgvec(skb, &sgin[1],
 703                           rxm->offset + tls_ctx->rx.prepend_size,
 704                           rxm->full_len - tls_ctx->rx.prepend_size);
 705
 706        tls_make_aad(ctx->rx_aad_ciphertext,
 707                     rxm->full_len - tls_ctx->rx.overhead_size,
 708                     tls_ctx->rx.rec_seq,
 709                     tls_ctx->rx.rec_seq_size,
 710                     ctx->control);
 711
 712        ret = tls_do_decryption(sk, sgin, sgout, iv,
 713                                rxm->full_len - tls_ctx->rx.overhead_size,
 714                                skb, sk->sk_allocation);
 715
 716        if (sgin != &sgin_arr[0])
 717                kfree(sgin);
 718
 719        return ret;
 720}
 721
 722static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 723                               unsigned int len)
 724{
 725        struct tls_context *tls_ctx = tls_get_ctx(sk);
 726        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 727        struct strp_msg *rxm = strp_msg(skb);
 728
 729        if (len < rxm->full_len) {
 730                rxm->offset += len;
 731                rxm->full_len -= len;
 732
 733                return false;
 734        }
 735
 736        /* Finished with message */
 737        ctx->recv_pkt = NULL;
 738        kfree_skb(skb);
 739        strp_unpause(&ctx->strp);
 740
 741        return true;
 742}
 743
 744int tls_sw_recvmsg(struct sock *sk,
 745                   struct msghdr *msg,
 746                   size_t len,
 747                   int nonblock,
 748                   int flags,
 749                   int *addr_len)
 750{
 751        struct tls_context *tls_ctx = tls_get_ctx(sk);
 752        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 753        unsigned char control;
 754        struct strp_msg *rxm;
 755        struct sk_buff *skb;
 756        ssize_t copied = 0;
 757        bool cmsg = false;
 758        int err = 0;
 759        long timeo;
 760
 761        flags |= nonblock;
 762
 763        if (unlikely(flags & MSG_ERRQUEUE))
 764                return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 765
 766        lock_sock(sk);
 767
 768        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 769        do {
 770                bool zc = false;
 771                int chunk = 0;
 772
 773                skb = tls_wait_data(sk, flags, timeo, &err);
 774                if (!skb)
 775                        goto recv_end;
 776
 777                rxm = strp_msg(skb);
 778                if (!cmsg) {
 779                        int cerr;
 780
 781                        cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
 782                                        sizeof(ctx->control), &ctx->control);
 783                        cmsg = true;
 784                        control = ctx->control;
 785                        if (ctx->control != TLS_RECORD_TYPE_DATA) {
 786                                if (cerr || msg->msg_flags & MSG_CTRUNC) {
 787                                        err = -EIO;
 788                                        goto recv_end;
 789                                }
 790                        }
 791                } else if (control != ctx->control) {
 792                        goto recv_end;
 793                }
 794
 795                if (!ctx->decrypted) {
 796                        int page_count;
 797                        int to_copy;
 798
 799                        page_count = iov_iter_npages(&msg->msg_iter,
 800                                                     MAX_SKB_FRAGS);
 801                        to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
 802                        if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
 803                            likely(!(flags & MSG_PEEK)))  {
 804                                struct scatterlist sgin[MAX_SKB_FRAGS + 1];
 805                                int pages = 0;
 806
 807                                zc = true;
 808                                sg_init_table(sgin, MAX_SKB_FRAGS + 1);
 809                                sg_set_buf(&sgin[0], ctx->rx_aad_plaintext,
 810                                           TLS_AAD_SPACE_SIZE);
 811
 812                                err = zerocopy_from_iter(sk, &msg->msg_iter,
 813                                                         to_copy, &pages,
 814                                                         &chunk, &sgin[1],
 815                                                         MAX_SKB_FRAGS, false);
 816                                if (err < 0)
 817                                        goto fallback_to_reg_recv;
 818
 819                                err = decrypt_skb(sk, skb, sgin);
 820                                for (; pages > 0; pages--)
 821                                        put_page(sg_page(&sgin[pages]));
 822                                if (err < 0) {
 823                                        tls_err_abort(sk, EBADMSG);
 824                                        goto recv_end;
 825                                }
 826                        } else {
 827fallback_to_reg_recv:
 828                                err = decrypt_skb(sk, skb, NULL);
 829                                if (err < 0) {
 830                                        tls_err_abort(sk, EBADMSG);
 831                                        goto recv_end;
 832                                }
 833                        }
 834                        ctx->decrypted = true;
 835                }
 836
 837                if (!zc) {
 838                        chunk = min_t(unsigned int, rxm->full_len, len);
 839                        err = skb_copy_datagram_msg(skb, rxm->offset, msg,
 840                                                    chunk);
 841                        if (err < 0)
 842                                goto recv_end;
 843                }
 844
 845                copied += chunk;
 846                len -= chunk;
 847                if (likely(!(flags & MSG_PEEK))) {
 848                        u8 control = ctx->control;
 849
 850                        if (tls_sw_advance_skb(sk, skb, chunk)) {
 851                                /* Return full control message to
 852                                 * userspace before trying to parse
 853                                 * another message type
 854                                 */
 855                                msg->msg_flags |= MSG_EOR;
 856                                if (control != TLS_RECORD_TYPE_DATA)
 857                                        goto recv_end;
 858                        }
 859                }
 860        } while (len);
 861
 862recv_end:
 863        release_sock(sk);
 864        return copied ? : err;
 865}
 866
 867ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 868                           struct pipe_inode_info *pipe,
 869                           size_t len, unsigned int flags)
 870{
 871        struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
 872        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 873        struct strp_msg *rxm = NULL;
 874        struct sock *sk = sock->sk;
 875        struct sk_buff *skb;
 876        ssize_t copied = 0;
 877        int err = 0;
 878        long timeo;
 879        int chunk;
 880
 881        lock_sock(sk);
 882
 883        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 884
 885        skb = tls_wait_data(sk, flags, timeo, &err);
 886        if (!skb)
 887                goto splice_read_end;
 888
 889        /* splice does not support reading control messages */
 890        if (ctx->control != TLS_RECORD_TYPE_DATA) {
 891                err = -ENOTSUPP;
 892                goto splice_read_end;
 893        }
 894
 895        if (!ctx->decrypted) {
 896                err = decrypt_skb(sk, skb, NULL);
 897
 898                if (err < 0) {
 899                        tls_err_abort(sk, EBADMSG);
 900                        goto splice_read_end;
 901                }
 902                ctx->decrypted = true;
 903        }
 904        rxm = strp_msg(skb);
 905
 906        chunk = min_t(unsigned int, rxm->full_len, len);
 907        copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
 908        if (copied < 0)
 909                goto splice_read_end;
 910
 911        if (likely(!(flags & MSG_PEEK)))
 912                tls_sw_advance_skb(sk, skb, copied);
 913
 914splice_read_end:
 915        release_sock(sk);
 916        return copied ? : err;
 917}
 918
 919unsigned int tls_sw_poll(struct file *file, struct socket *sock,
 920                         struct poll_table_struct *wait)
 921{
 922        unsigned int ret;
 923        struct sock *sk = sock->sk;
 924        struct tls_context *tls_ctx = tls_get_ctx(sk);
 925        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 926
 927        /* Grab POLLOUT and POLLHUP from the underlying socket */
 928        ret = ctx->sk_poll(file, sock, wait);
 929
 930        /* Clear POLLIN bits, and set based on recv_pkt */
 931        ret &= ~(POLLIN | POLLRDNORM);
 932        if (ctx->recv_pkt)
 933                ret |= POLLIN | POLLRDNORM;
 934
 935        return ret;
 936}
 937
 938static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 939{
 940        struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 941        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 942        char header[tls_ctx->rx.prepend_size];
 943        struct strp_msg *rxm = strp_msg(skb);
 944        size_t cipher_overhead;
 945        size_t data_len = 0;
 946        int ret;
 947
 948        /* Verify that we have a full TLS header, or wait for more data */
 949        if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
 950                return 0;
 951
 952        /* Linearize header to local buffer */
 953        ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
 954
 955        if (ret < 0)
 956                goto read_failure;
 957
 958        ctx->control = header[0];
 959
 960        data_len = ((header[4] & 0xFF) | (header[3] << 8));
 961
 962        cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
 963
 964        if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
 965                ret = -EMSGSIZE;
 966                goto read_failure;
 967        }
 968        if (data_len < cipher_overhead) {
 969                ret = -EBADMSG;
 970                goto read_failure;
 971        }
 972
 973        if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
 974            header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
 975                ret = -EINVAL;
 976                goto read_failure;
 977        }
 978
 979        return data_len + TLS_HEADER_SIZE;
 980
 981read_failure:
 982        tls_err_abort(strp->sk, ret);
 983
 984        return ret;
 985}
 986
 987static void tls_queue(struct strparser *strp, struct sk_buff *skb)
 988{
 989        struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 990        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
 991        struct strp_msg *rxm;
 992
 993        rxm = strp_msg(skb);
 994
 995        ctx->decrypted = false;
 996
 997        ctx->recv_pkt = skb;
 998        strp_pause(strp);
 999
1000        strp->sk->sk_state_change(strp->sk);
1001}
1002
1003static void tls_data_ready(struct sock *sk)
1004{
1005        struct tls_context *tls_ctx = tls_get_ctx(sk);
1006        struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
1007
1008        strp_data_ready(&ctx->strp);
1009}
1010
/*
 * tls_sw_free_resources - tear down the software TLS state of @sk.
 * @sk: socket whose TLS context is released
 *
 * Frees the transmit AEAD if one was allocated. If the receive side was
 * configured (aead_recv set), it also:
 *  - drops any record still parked in recv_pkt;
 *  - frees the receive AEAD;
 *  - stops the strparser;
 *  - restores the socket's saved sk_data_ready callback.
 * Finally it calls tls_free_both_sg() and frees both the software context
 * and the tls_context itself. The caller must not use either afterwards.
 *
 * NOTE(review): the release_sock()/lock_sock() pair around strp_done()
 * implies the caller enters with the socket lock held. Presumably the lock
 * is dropped so strp_done() can wait for strparser work that itself takes
 * the socket lock. Confirm before reordering anything here.
 */
void tls_sw_free_resources(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);

	if (ctx->aead_send)
		crypto_free_aead(ctx->aead_send);
	if (ctx->aead_recv) {
		/* Drop a record queued by tls_queue() that was never read. */
		if (ctx->recv_pkt) {
			kfree_skb(ctx->recv_pkt);
			ctx->recv_pkt = NULL;
		}
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		/* Undo the sk_data_ready hook installed by tls_set_sw_offload(). */
		write_lock_bh(&sk->sk_callback_lock);
		sk->sk_data_ready = ctx->saved_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);
		release_sock(sk);
		strp_done(&ctx->strp);
		lock_sock(sk);
	}

	tls_free_both_sg(sk);

	kfree(ctx);
	kfree(tls_ctx);
}
1038
1039int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1040{
1041        char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
1042        struct tls_crypto_info *crypto_info;
1043        struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1044        struct tls_sw_context *sw_ctx;
1045        struct cipher_context *cctx;
1046        struct crypto_aead **aead;
1047        struct strp_callbacks cb;
1048        u16 nonce_size, tag_size, iv_size, rec_seq_size;
1049        char *iv, *rec_seq;
1050        int rc = 0;
1051
1052        if (!ctx) {
1053                rc = -EINVAL;
1054                goto out;
1055        }
1056
1057        if (!ctx->priv_ctx) {
1058                sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
1059                if (!sw_ctx) {
1060                        rc = -ENOMEM;
1061                        goto out;
1062                }
1063                crypto_init_wait(&sw_ctx->async_wait);
1064        } else {
1065                sw_ctx = ctx->priv_ctx;
1066        }
1067
1068        ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
1069
1070        if (tx) {
1071                crypto_info = &ctx->crypto_send;
1072                cctx = &ctx->tx;
1073                aead = &sw_ctx->aead_send;
1074        } else {
1075                crypto_info = &ctx->crypto_recv;
1076                cctx = &ctx->rx;
1077                aead = &sw_ctx->aead_recv;
1078        }
1079
1080        switch (crypto_info->cipher_type) {
1081        case TLS_CIPHER_AES_GCM_128: {
1082                nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1083                tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1084                iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1085                iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1086                rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1087                rec_seq =
1088                 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1089                gcm_128_info =
1090                        (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1091                break;
1092        }
1093        default:
1094                rc = -EINVAL;
1095                goto free_priv;
1096        }
1097
1098        /* Sanity-check the IV size for stack allocations. */
1099        if (iv_size > MAX_IV_SIZE) {
1100                rc = -EINVAL;
1101                goto free_priv;
1102        }
1103
1104        cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1105        cctx->tag_size = tag_size;
1106        cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1107        cctx->iv_size = iv_size;
1108        cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1109                           GFP_KERNEL);
1110        if (!cctx->iv) {
1111                rc = -ENOMEM;
1112                goto free_priv;
1113        }
1114        memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1115        memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1116        cctx->rec_seq_size = rec_seq_size;
1117        cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
1118        if (!cctx->rec_seq) {
1119                rc = -ENOMEM;
1120                goto free_iv;
1121        }
1122        memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
1123
1124        if (tx) {
1125                sg_init_table(sw_ctx->sg_encrypted_data,
1126                              ARRAY_SIZE(sw_ctx->sg_encrypted_data));
1127                sg_init_table(sw_ctx->sg_plaintext_data,
1128                              ARRAY_SIZE(sw_ctx->sg_plaintext_data));
1129
1130                sg_init_table(sw_ctx->sg_aead_in, 2);
1131                sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
1132                           sizeof(sw_ctx->aad_space));
1133                sg_unmark_end(&sw_ctx->sg_aead_in[1]);
1134                sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
1135                sg_init_table(sw_ctx->sg_aead_out, 2);
1136                sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
1137                           sizeof(sw_ctx->aad_space));
1138                sg_unmark_end(&sw_ctx->sg_aead_out[1]);
1139                sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
1140        }
1141
1142        if (!*aead) {
1143                *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1144                if (IS_ERR(*aead)) {
1145                        rc = PTR_ERR(*aead);
1146                        *aead = NULL;
1147                        goto free_rec_seq;
1148                }
1149        }
1150
1151        ctx->push_pending_record = tls_sw_push_pending_record;
1152
1153        memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1154
1155        rc = crypto_aead_setkey(*aead, keyval,
1156                                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1157        if (rc)
1158                goto free_aead;
1159
1160        rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1161        if (rc)
1162                goto free_aead;
1163
1164        if (!tx) {
1165                /* Set up strparser */
1166                memset(&cb, 0, sizeof(cb));
1167                cb.rcv_msg = tls_queue;
1168                cb.parse_msg = tls_read_size;
1169
1170                strp_init(&sw_ctx->strp, sk, &cb);
1171
1172                write_lock_bh(&sk->sk_callback_lock);
1173                sw_ctx->saved_data_ready = sk->sk_data_ready;
1174                sk->sk_data_ready = tls_data_ready;
1175                write_unlock_bh(&sk->sk_callback_lock);
1176
1177                sw_ctx->sk_poll = sk->sk_socket->ops->poll;
1178
1179                strp_check_rcv(&sw_ctx->strp);
1180        }
1181
1182        goto out;
1183
1184free_aead:
1185        crypto_free_aead(*aead);
1186        *aead = NULL;
1187free_rec_seq:
1188        kfree(cctx->rec_seq);
1189        cctx->rec_seq = NULL;
1190free_iv:
1191        kfree(ctx->tx.iv);
1192        ctx->tx.iv = NULL;
1193free_priv:
1194        kfree(ctx->priv_ctx);
1195        ctx->priv_ctx = NULL;
1196out:
1197        return rc;
1198}
1199