linux/net/tls/tls_device_fallback.c
/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <net/tls.h>
#include <crypto/aead.h>
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>

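/* Start the chained scatterlist at the walk's current position: reuse
 * the walk's current page, trimmed to the bytes not yet consumed, and
 * chain the rest of the source list after it.
 */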
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
        struct scatterlist *src = walk->sg;
        int diff = walk->offset - src->offset;

        sg_set_page(sg, sg_page(src),
                    src->length - diff, walk->offset);

        scatterwalk_crypto_chain(sg, sg_next(src), 2);
}

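/* Re-encrypt one TLS record: copy the record header and explicit nonce
 * from @in to @out unmodified, rebuild the AAD and per-record IV from
 * them, then AEAD-encrypt the record payload from @in into @out.
 * *in_len is reduced by the amount of input consumed; a partial
 * trailing record is encrypted only as far as the input goes.
 */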
static int tls_enc_record(struct aead_request *aead_req,
                          struct crypto_aead *aead, char *aad,
                          char *iv, __be64 rcd_sn,
                          struct scatter_walk *in,
                          struct scatter_walk *out, int *in_len)
{
        unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
        struct scatterlist sg_in[3];
        struct scatterlist sg_out[3];
        u16 len;
        int rc;

        len = min_t(int, *in_len, ARRAY_SIZE(buf));

        scatterwalk_copychunks(buf, in, len, 0);
        scatterwalk_copychunks(buf, out, len, 1);

        *in_len -= len;
        if (!*in_len)
                return 0;

        scatterwalk_pagedone(in, 0, 1);
        scatterwalk_pagedone(out, 1, 1);

        len = buf[4] | (buf[3] << 8);
        len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;

        tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
                     (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);

        memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
               TLS_CIPHER_AES_GCM_128_IV_SIZE);

        sg_init_table(sg_in, ARRAY_SIZE(sg_in));
        sg_init_table(sg_out, ARRAY_SIZE(sg_out));
        sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE);
        sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE);
        chain_to_walk(sg_in + 1, in);
        chain_to_walk(sg_out + 1, out);

        *in_len -= len;
        if (*in_len < 0) {
                *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
                /* The input buffer doesn't contain the entire record;
                 * trim len accordingly.  The resulting authentication tag
                 * will contain garbage, but we don't care, so we won't
                 * include any of it in the output skb.
                 * Note that we assume the output buffer length
                 * is larger than the input buffer length + tag size.
                 */
                if (*in_len < 0)
                        len += *in_len;

                *in_len = 0;
        }

        if (*in_len) {
                scatterwalk_copychunks(NULL, in, len, 2);
                scatterwalk_pagedone(in, 0, 1);
                scatterwalk_copychunks(NULL, out, len, 2);
                scatterwalk_pagedone(out, 1, 1);
        }

        len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
        aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);

        rc = crypto_aead_encrypt(aead_req);

        return rc;
}

static void tls_init_aead_request(struct aead_request *aead_req,
                                  struct crypto_aead *aead)
{
        aead_request_set_tfm(aead_req, aead);
        aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
}

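/* Allocate an aead_request big enough for @aead's private context and
 * preset the tfm and AAD length used for every fallback record.
 */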
static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead,
                                                   gfp_t flags)
{
        unsigned int req_size = sizeof(struct aead_request) +
                crypto_aead_reqsize(aead);
        struct aead_request *aead_req;

        aead_req = kzalloc(req_size, flags);
        if (aead_req)
                tls_init_aead_request(aead_req, aead);
        return aead_req;
}

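/* Walk @sg_in/@sg_out and re-encrypt records one by one, bumping the
 * record sequence number after each, until @len bytes of input have
 * been consumed or an encryption error occurs.
 */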
static int tls_enc_records(struct aead_request *aead_req,
                           struct crypto_aead *aead, struct scatterlist *sg_in,
                           struct scatterlist *sg_out, char *aad, char *iv,
                           u64 rcd_sn, int len)
{
        struct scatter_walk out, in;
        int rc;

        scatterwalk_start(&in, sg_in);
        scatterwalk_start(&out, sg_out);

        do {
                rc = tls_enc_record(aead_req, aead, aad, iv,
                                    cpu_to_be64(rcd_sn), &in, &out, &len);
                rcd_sn++;

        } while (rc == 0 && len);

        scatterwalk_done(&in, 0, 0);
        scatterwalk_done(&out, 1, 0);

        return rc;
}

/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses
 * might have been changed by NAT.
 */
static void update_chksum(struct sk_buff *skb, int headln)
{
        struct tcphdr *th = tcp_hdr(skb);
        int datalen = skb->len - headln;
        const struct ipv6hdr *ipv6h;
        const struct iphdr *iph;

        /* We only changed the payload so if we are using partial we don't
         * need to update anything.
         */
        if (likely(skb->ip_summed == CHECKSUM_PARTIAL))
                return;

        skb->ip_summed = CHECKSUM_PARTIAL;
        skb->csum_start = skb_transport_header(skb) - skb->head;
        skb->csum_offset = offsetof(struct tcphdr, check);

        if (skb->sk->sk_family == AF_INET6) {
                ipv6h = ipv6_hdr(skb);
                th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr,
                                             datalen, IPPROTO_TCP, 0);
        } else {
                iph = ip_hdr(skb);
                th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen,
                                               IPPROTO_TCP, 0);
        }
}

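/* Finish the cloned skb: copy the headers from the original, recompute
 * the TCP checksum fields and transfer socket ownership (destructor,
 * sk and wmem accounting) from @skb to @nskb.
 */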
static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
{
        skb_copy_header(nskb, skb);

        skb_put(nskb, skb->len);
        memcpy(nskb->data, skb->data, headln);
        update_chksum(nskb, headln);

        nskb->destructor = skb->destructor;
        nskb->sk = skb->sk;
        skb->destructor = NULL;
        skb->sk = NULL;
        refcount_add(nskb->truesize - skb->truesize,
                     &nskb->sk->sk_wmem_alloc);
}

/* This function may be called after the user socket is already
 * closed, so make sure we don't use anything freed during
 * tls_sk_proto_close here.
 */

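/* Map the data needed to re-encrypt @skb into @sg_in: first the frags
 * of the offload record that precede this packet's payload (encryption
 * must be replayed from the record boundary), then the packet's own TCP
 * payload.  Returns the record sequence number, the number of "sync"
 * bytes preceding the packet and how many record frags were referenced.
 */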
static int fill_sg_in(struct scatterlist *sg_in,
                      struct sk_buff *skb,
                      struct tls_offload_context_tx *ctx,
                      u64 *rcd_sn,
                      s32 *sync_size,
                      int *resync_sgs)
{
        int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
        int payload_len = skb->len - tcp_payload_offset;
        u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
        struct tls_record_info *record;
        unsigned long flags;
        int remaining;
        int i;

        spin_lock_irqsave(&ctx->lock, flags);
        record = tls_get_record(ctx, tcp_seq, rcd_sn);
        if (!record) {
                spin_unlock_irqrestore(&ctx->lock, flags);
                WARN(1, "Record not found for seq %u\n", tcp_seq);
                return -EINVAL;
        }

        *sync_size = tcp_seq - tls_record_start_seq(record);
        if (*sync_size < 0) {
                int is_start_marker = tls_record_is_start_marker(record);

                spin_unlock_irqrestore(&ctx->lock, flags);
                /* This should only occur if the relevant record was
                 * already acked. In that case it should be ok
                 * to drop the packet and avoid retransmission.
                 *
                 * There is a corner case where the packet contains
                 * both an acked and a non-acked record.
                 * We currently don't handle that case and rely
                 * on TCP to retransmit a packet that doesn't contain
                 * already acked payload.
                 */
                if (!is_start_marker)
                        *sync_size = 0;
                return -EINVAL;
        }

        remaining = *sync_size;
        for (i = 0; remaining > 0; i++) {
                skb_frag_t *frag = &record->frags[i];

                __skb_frag_ref(frag);
                sg_set_page(sg_in + i, skb_frag_page(frag),
                            skb_frag_size(frag), frag->page_offset);

                remaining -= skb_frag_size(frag);

                if (remaining < 0)
                        sg_in[i].length += remaining;
        }
        *resync_sgs = i;

        spin_unlock_irqrestore(&ctx->lock, flags);
        if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0)
                return -EINVAL;

        return 0;
}

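/* Point the output scatterlist at a throwaway buffer for the "sync"
 * bytes that precede the packet, at the new skb for the payload itself,
 * and at spare room for the authentication tag.
 */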
static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
                        struct tls_context *tls_ctx,
                        struct sk_buff *nskb,
                        int tcp_payload_offset,
                        int payload_len,
                        int sync_size,
                        void *dummy_buf)
{
        sg_set_buf(&sg_out[0], dummy_buf, sync_size);
        sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
        /* Add room for authentication tag produced by crypto */
        dummy_buf += sync_size;
        sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
}

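/* Allocate a linear copy of @skb and software-encrypt its TLS payload
 * into it, replaying @sync_size bytes of preceding record data so the
 * AEAD state matches what the device would have produced.  Returns the
 * encrypted skb, or NULL on failure.
 */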
static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
                                   struct scatterlist sg_out[3],
                                   struct scatterlist *sg_in,
                                   struct sk_buff *skb,
                                   s32 sync_size, u64 rcd_sn)
{
        int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        int payload_len = skb->len - tcp_payload_offset;
        void *buf, *iv, *aad, *dummy_buf;
        struct aead_request *aead_req;
        struct sk_buff *nskb = NULL;
        int buf_len;

        aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC);
        if (!aead_req)
                return NULL;

        buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
                  TLS_CIPHER_AES_GCM_128_IV_SIZE +
                  TLS_AAD_SPACE_SIZE +
                  sync_size +
                  TLS_CIPHER_AES_GCM_128_TAG_SIZE;
        buf = kmalloc(buf_len, GFP_ATOMIC);
        if (!buf)
                goto free_req;

        iv = buf;
        memcpy(iv, tls_ctx->crypto_send.aes_gcm_128.salt,
               TLS_CIPHER_AES_GCM_128_SALT_SIZE);
        aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
              TLS_CIPHER_AES_GCM_128_IV_SIZE;
        dummy_buf = aad + TLS_AAD_SPACE_SIZE;

        nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
        if (!nskb)
                goto free_buf;

        skb_reserve(nskb, skb_headroom(skb));

        fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset,
                    payload_len, sync_size, dummy_buf);

        if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv,
                            rcd_sn, sync_size + payload_len) < 0)
                goto free_nskb;

        complete_skb(nskb, skb, tcp_payload_offset);

        /* validate_xmit_skb_list assumes that if the skb wasn't segmented
         * nskb->prev will point to the skb itself
         */
        nskb->prev = nskb;

free_buf:
        kfree(buf);
free_req:
        kfree(aead_req);
        return nskb;
free_nskb:
        kfree_skb(nskb);
        nskb = NULL;
        goto free_buf;
}

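/* Software fallback for the TLS TX offload path: rebuild and encrypt
 * the record data covering this skb in software and hand back a new,
 * already-encrypted skb.  The original skb is consumed.
 */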
static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
        int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        int payload_len = skb->len - tcp_payload_offset;
        struct scatterlist *sg_in, sg_out[3];
        struct sk_buff *nskb = NULL;
        int sg_in_max_elements;
        int resync_sgs = 0;
        s32 sync_size = 0;
        u64 rcd_sn;

        /* worst case is:
         * MAX_SKB_FRAGS in tls_record_info
         * MAX_SKB_FRAGS + 1 in SKB head and frags.
         */
        sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1;

        if (!payload_len)
                return skb;

        sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC);
        if (!sg_in)
                goto free_orig;

        sg_init_table(sg_in, sg_in_max_elements);
        sg_init_table(sg_out, ARRAY_SIZE(sg_out));

        if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) {
                /* bypass packets before kernel TLS socket option was set */
                if (sync_size < 0 && payload_len <= -sync_size)
                        nskb = skb_get(skb);
                goto put_sg;
        }

        nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn);

put_sg:
        while (resync_sgs)
                put_page(sg_page(&sg_in[--resync_sgs]));
        kfree(sg_in);
free_orig:
        kfree_skb(skb);
        return nskb;
}

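/* Called on the transmit validation path: if the skb is leaving through
 * a device other than the one that holds the TLS offload state, fall
 * back to software encryption.
 */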
struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
                                      struct net_device *dev,
                                      struct sk_buff *skb)
{
        if (dev == tls_get_ctx(sk)->netdev)
                return skb;

        return tls_sw_fallback(sk, skb);
}
EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);

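/* Set up the software AEAD ("gcm(aes)") used by the fallback path,
 * keyed and sized for TLS 1.2 AES-GCM-128.
 */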
int tls_sw_fallback_init(struct sock *sk,
                         struct tls_offload_context_tx *offload_ctx,
                         struct tls_crypto_info *crypto_info)
{
        const u8 *key;
        int rc;

        offload_ctx->aead_send =
            crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC);
        if (IS_ERR(offload_ctx->aead_send)) {
                rc = PTR_ERR(offload_ctx->aead_send);
                pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc);
                offload_ctx->aead_send = NULL;
                goto err_out;
        }

        key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;

        rc = crypto_aead_setkey(offload_ctx->aead_send, key,
                                TLS_CIPHER_AES_GCM_128_KEY_SIZE);
        if (rc)
                goto free_aead;

        rc = crypto_aead_setauthsize(offload_ctx->aead_send,
                                     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
        if (rc)
                goto free_aead;

        return 0;
free_aead:
        crypto_free_aead(offload_ctx->aead_send);
err_out:
        return rc;
}