linux/net/xfrm/espintcp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2#include <net/tcp.h>
   3#include <net/strparser.h>
   4#include <net/xfrm.h>
   5#include <net/esp.h>
   6#include <net/espintcp.h>
   7#include <linux/skmsg.h>
   8#include <net/inet_common.h>
   9#if IS_ENABLED(CONFIG_IPV6)
  10#include <net/ipv6_stubs.h>
  11#endif
  12
  13static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb,
  14                          struct sock *sk)
  15{
  16        if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf ||
  17            !sk_rmem_schedule(sk, skb, skb->truesize)) {
  18                XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR);
  19                kfree_skb(skb);
  20                return;
  21        }
  22
  23        skb_set_owner_r(skb, sk);
  24
  25        memset(skb->cb, 0, sizeof(skb->cb));
  26        skb_queue_tail(&ctx->ike_queue, skb);
  27        ctx->saved_data_ready(sk);
  28}
  29
  30static void handle_esp(struct sk_buff *skb, struct sock *sk)
  31{
  32        struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb;
  33
  34        skb_reset_transport_header(skb);
  35
  36        /* restore IP CB, we need at least IP6CB->nhoff */
  37        memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header));
  38
  39        rcu_read_lock();
  40        skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif);
  41        local_bh_disable();
  42#if IS_ENABLED(CONFIG_IPV6)
  43        if (sk->sk_family == AF_INET6)
  44                ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
  45        else
  46#endif
  47                xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP);
  48        local_bh_enable();
  49        rcu_read_unlock();
  50}
  51
  52static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb)
  53{
  54        struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx,
  55                                                strp);
  56        struct strp_msg *rxm = strp_msg(skb);
  57        int len = rxm->full_len - 2;
  58        u32 nonesp_marker;
  59        int err;
  60
  61        /* keepalive packet? */
  62        if (unlikely(len == 1)) {
  63                u8 data;
  64
  65                err = skb_copy_bits(skb, rxm->offset + 2, &data, 1);
  66                if (err < 0) {
  67                        XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
  68                        kfree_skb(skb);
  69                        return;
  70                }
  71
  72                if (data == 0xff) {
  73                        kfree_skb(skb);
  74                        return;
  75                }
  76        }
  77
  78        /* drop other short messages */
  79        if (unlikely(len <= sizeof(nonesp_marker))) {
  80                XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
  81                kfree_skb(skb);
  82                return;
  83        }
  84
  85        err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker,
  86                            sizeof(nonesp_marker));
  87        if (err < 0) {
  88                XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR);
  89                kfree_skb(skb);
  90                return;
  91        }
  92
  93        /* remove header, leave non-ESP marker/SPI */
  94        if (!__pskb_pull(skb, rxm->offset + 2)) {
  95                XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
  96                kfree_skb(skb);
  97                return;
  98        }
  99
 100        if (pskb_trim(skb, rxm->full_len - 2) != 0) {
 101                XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR);
 102                kfree_skb(skb);
 103                return;
 104        }
 105
 106        if (nonesp_marker == 0)
 107                handle_nonesp(ctx, skb, strp->sk);
 108        else
 109                handle_esp(skb, strp->sk);
 110}
 111
 112static int espintcp_parse(struct strparser *strp, struct sk_buff *skb)
 113{
 114        struct strp_msg *rxm = strp_msg(skb);
 115        __be16 blen;
 116        u16 len;
 117        int err;
 118
 119        if (skb->len < rxm->offset + 2)
 120                return 0;
 121
 122        err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen));
 123        if (err < 0)
 124                return err;
 125
 126        len = be16_to_cpu(blen);
 127        if (len < 2)
 128                return -EINVAL;
 129
 130        return len;
 131}
 132
 133static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
 134                            int nonblock, int flags, int *addr_len)
 135{
 136        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 137        struct sk_buff *skb;
 138        int err = 0;
 139        int copied;
 140        int off = 0;
 141
 142        flags |= nonblock ? MSG_DONTWAIT : 0;
 143
 144        skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err);
 145        if (!skb) {
 146                if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN)
 147                        return 0;
 148                return err;
 149        }
 150
 151        copied = len;
 152        if (copied > skb->len)
 153                copied = skb->len;
 154        else if (copied < skb->len)
 155                msg->msg_flags |= MSG_TRUNC;
 156
 157        err = skb_copy_datagram_msg(skb, 0, msg, copied);
 158        if (unlikely(err)) {
 159                kfree_skb(skb);
 160                return err;
 161        }
 162
 163        if (flags & MSG_TRUNC)
 164                copied = skb->len;
 165        kfree_skb(skb);
 166        return copied;
 167}
 168
 169int espintcp_queue_out(struct sock *sk, struct sk_buff *skb)
 170{
 171        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 172
 173        if (skb_queue_len(&ctx->out_queue) >= netdev_max_backlog)
 174                return -ENOBUFS;
 175
 176        __skb_queue_tail(&ctx->out_queue, skb);
 177
 178        return 0;
 179}
 180EXPORT_SYMBOL_GPL(espintcp_queue_out);
 181
 182/* espintcp length field is 2B and length includes the length field's size */
 183#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2)
 184
 185static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg,
 186                                   int flags)
 187{
 188        do {
 189                int ret;
 190
 191                ret = skb_send_sock_locked(sk, emsg->skb,
 192                                           emsg->offset, emsg->len);
 193                if (ret < 0)
 194                        return ret;
 195
 196                emsg->len -= ret;
 197                emsg->offset += ret;
 198        } while (emsg->len > 0);
 199
 200        kfree_skb(emsg->skb);
 201        memset(emsg, 0, sizeof(*emsg));
 202
 203        return 0;
 204}
 205
 206static int espintcp_sendskmsg_locked(struct sock *sk,
 207                                     struct espintcp_msg *emsg, int flags)
 208{
 209        struct sk_msg *skmsg = &emsg->skmsg;
 210        struct scatterlist *sg;
 211        int done = 0;
 212        int ret;
 213
 214        flags |= MSG_SENDPAGE_NOTLAST;
 215        sg = &skmsg->sg.data[skmsg->sg.start];
 216        do {
 217                size_t size = sg->length - emsg->offset;
 218                int offset = sg->offset + emsg->offset;
 219                struct page *p;
 220
 221                emsg->offset = 0;
 222
 223                if (sg_is_last(sg))
 224                        flags &= ~MSG_SENDPAGE_NOTLAST;
 225
 226                p = sg_page(sg);
 227retry:
 228                ret = do_tcp_sendpages(sk, p, offset, size, flags);
 229                if (ret < 0) {
 230                        emsg->offset = offset - sg->offset;
 231                        skmsg->sg.start += done;
 232                        return ret;
 233                }
 234
 235                if (ret != size) {
 236                        offset += ret;
 237                        size -= ret;
 238                        goto retry;
 239                }
 240
 241                done++;
 242                put_page(p);
 243                sk_mem_uncharge(sk, sg->length);
 244                sg = sg_next(sg);
 245        } while (sg);
 246
 247        memset(emsg, 0, sizeof(*emsg));
 248
 249        return 0;
 250}
 251
 252static int espintcp_push_msgs(struct sock *sk, int flags)
 253{
 254        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 255        struct espintcp_msg *emsg = &ctx->partial;
 256        int err;
 257
 258        if (!emsg->len)
 259                return 0;
 260
 261        if (ctx->tx_running)
 262                return -EAGAIN;
 263        ctx->tx_running = 1;
 264
 265        if (emsg->skb)
 266                err = espintcp_sendskb_locked(sk, emsg, flags);
 267        else
 268                err = espintcp_sendskmsg_locked(sk, emsg, flags);
 269        if (err == -EAGAIN) {
 270                ctx->tx_running = 0;
 271                return flags & MSG_DONTWAIT ? -EAGAIN : 0;
 272        }
 273        if (!err)
 274                memset(emsg, 0, sizeof(*emsg));
 275
 276        ctx->tx_running = 0;
 277
 278        return err;
 279}
 280
 281int espintcp_push_skb(struct sock *sk, struct sk_buff *skb)
 282{
 283        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 284        struct espintcp_msg *emsg = &ctx->partial;
 285        unsigned int len;
 286        int offset;
 287
 288        if (sk->sk_state != TCP_ESTABLISHED) {
 289                kfree_skb(skb);
 290                return -ECONNRESET;
 291        }
 292
 293        offset = skb_transport_offset(skb);
 294        len = skb->len - offset;
 295
 296        espintcp_push_msgs(sk, 0);
 297
 298        if (emsg->len) {
 299                kfree_skb(skb);
 300                return -ENOBUFS;
 301        }
 302
 303        skb_set_owner_w(skb, sk);
 304
 305        emsg->offset = offset;
 306        emsg->len = len;
 307        emsg->skb = skb;
 308
 309        espintcp_push_msgs(sk, 0);
 310
 311        return 0;
 312}
 313EXPORT_SYMBOL_GPL(espintcp_push_skb);
 314
 315static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 316{
 317        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 318        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 319        struct espintcp_msg *emsg = &ctx->partial;
 320        struct iov_iter pfx_iter;
 321        struct kvec pfx_iov = {};
 322        size_t msglen = size + 2;
 323        char buf[2] = {0};
 324        int err, end;
 325
 326        if (msg->msg_flags & ~MSG_DONTWAIT)
 327                return -EOPNOTSUPP;
 328
 329        if (size > MAX_ESPINTCP_MSG)
 330                return -EMSGSIZE;
 331
 332        if (msg->msg_controllen)
 333                return -EOPNOTSUPP;
 334
 335        lock_sock(sk);
 336
 337        err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
 338        if (err < 0) {
 339                if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT))
 340                        err = -ENOBUFS;
 341                goto unlock;
 342        }
 343
 344        sk_msg_init(&emsg->skmsg);
 345        while (1) {
 346                /* only -ENOMEM is possible since we don't coalesce */
 347                err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0);
 348                if (!err)
 349                        break;
 350
 351                err = sk_stream_wait_memory(sk, &timeo);
 352                if (err)
 353                        goto fail;
 354        }
 355
 356        *((__be16 *)buf) = cpu_to_be16(msglen);
 357        pfx_iov.iov_base = buf;
 358        pfx_iov.iov_len = sizeof(buf);
 359        iov_iter_kvec(&pfx_iter, WRITE, &pfx_iov, 1, pfx_iov.iov_len);
 360
 361        err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg,
 362                                       pfx_iov.iov_len);
 363        if (err < 0)
 364                goto fail;
 365
 366        err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size);
 367        if (err < 0)
 368                goto fail;
 369
 370        end = emsg->skmsg.sg.end;
 371        emsg->len = size;
 372        sk_msg_iter_var_prev(end);
 373        sg_mark_end(sk_msg_elem(&emsg->skmsg, end));
 374
 375        tcp_rate_check_app_limited(sk);
 376
 377        err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT);
 378        /* this message could be partially sent, keep it */
 379
 380        release_sock(sk);
 381
 382        return size;
 383
 384fail:
 385        sk_msg_free(sk, &emsg->skmsg);
 386        memset(emsg, 0, sizeof(*emsg));
 387unlock:
 388        release_sock(sk);
 389        return err;
 390}
 391
 392static struct proto espintcp_prot __ro_after_init;
 393static struct proto_ops espintcp_ops __ro_after_init;
 394static struct proto espintcp6_prot;
 395static struct proto_ops espintcp6_ops;
 396static DEFINE_MUTEX(tcpv6_prot_mutex);
 397
 398static void espintcp_data_ready(struct sock *sk)
 399{
 400        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 401
 402        strp_data_ready(&ctx->strp);
 403}
 404
 405static void espintcp_tx_work(struct work_struct *work)
 406{
 407        struct espintcp_ctx *ctx = container_of(work,
 408                                                struct espintcp_ctx, work);
 409        struct sock *sk = ctx->strp.sk;
 410
 411        lock_sock(sk);
 412        if (!ctx->tx_running)
 413                espintcp_push_msgs(sk, 0);
 414        release_sock(sk);
 415}
 416
 417static void espintcp_write_space(struct sock *sk)
 418{
 419        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 420
 421        schedule_work(&ctx->work);
 422        ctx->saved_write_space(sk);
 423}
 424
 425static void espintcp_destruct(struct sock *sk)
 426{
 427        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 428
 429        ctx->saved_destruct(sk);
 430        kfree(ctx);
 431}
 432
 433bool tcp_is_ulp_esp(struct sock *sk)
 434{
 435        return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot;
 436}
 437EXPORT_SYMBOL_GPL(tcp_is_ulp_esp);
 438
 439static void build_protos(struct proto *espintcp_prot,
 440                         struct proto_ops *espintcp_ops,
 441                         const struct proto *orig_prot,
 442                         const struct proto_ops *orig_ops);
 443static int espintcp_init_sk(struct sock *sk)
 444{
 445        struct inet_connection_sock *icsk = inet_csk(sk);
 446        struct strp_callbacks cb = {
 447                .rcv_msg = espintcp_rcv,
 448                .parse_msg = espintcp_parse,
 449        };
 450        struct espintcp_ctx *ctx;
 451        int err;
 452
 453        /* sockmap is not compatible with espintcp */
 454        if (sk->sk_user_data)
 455                return -EBUSY;
 456
 457        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
 458        if (!ctx)
 459                return -ENOMEM;
 460
 461        err = strp_init(&ctx->strp, sk, &cb);
 462        if (err)
 463                goto free;
 464
 465        __sk_dst_reset(sk);
 466
 467        strp_check_rcv(&ctx->strp);
 468        skb_queue_head_init(&ctx->ike_queue);
 469        skb_queue_head_init(&ctx->out_queue);
 470
 471        if (sk->sk_family == AF_INET) {
 472                sk->sk_prot = &espintcp_prot;
 473                sk->sk_socket->ops = &espintcp_ops;
 474        } else {
 475                mutex_lock(&tcpv6_prot_mutex);
 476                if (!espintcp6_prot.recvmsg)
 477                        build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops);
 478                mutex_unlock(&tcpv6_prot_mutex);
 479
 480                sk->sk_prot = &espintcp6_prot;
 481                sk->sk_socket->ops = &espintcp6_ops;
 482        }
 483        ctx->saved_data_ready = sk->sk_data_ready;
 484        ctx->saved_write_space = sk->sk_write_space;
 485        ctx->saved_destruct = sk->sk_destruct;
 486        sk->sk_data_ready = espintcp_data_ready;
 487        sk->sk_write_space = espintcp_write_space;
 488        sk->sk_destruct = espintcp_destruct;
 489        rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
 490        INIT_WORK(&ctx->work, espintcp_tx_work);
 491
 492        /* avoid using task_frag */
 493        sk->sk_allocation = GFP_ATOMIC;
 494
 495        return 0;
 496
 497free:
 498        kfree(ctx);
 499        return err;
 500}
 501
 502static void espintcp_release(struct sock *sk)
 503{
 504        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 505        struct sk_buff_head queue;
 506        struct sk_buff *skb;
 507
 508        __skb_queue_head_init(&queue);
 509        skb_queue_splice_init(&ctx->out_queue, &queue);
 510
 511        while ((skb = __skb_dequeue(&queue)))
 512                espintcp_push_skb(sk, skb);
 513
 514        tcp_release_cb(sk);
 515}
 516
 517static void espintcp_close(struct sock *sk, long timeout)
 518{
 519        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 520        struct espintcp_msg *emsg = &ctx->partial;
 521
 522        strp_stop(&ctx->strp);
 523
 524        sk->sk_prot = &tcp_prot;
 525        barrier();
 526
 527        cancel_work_sync(&ctx->work);
 528        strp_done(&ctx->strp);
 529
 530        skb_queue_purge(&ctx->out_queue);
 531        skb_queue_purge(&ctx->ike_queue);
 532
 533        if (emsg->len) {
 534                if (emsg->skb)
 535                        kfree_skb(emsg->skb);
 536                else
 537                        sk_msg_free(sk, &emsg->skmsg);
 538        }
 539
 540        tcp_close(sk, timeout);
 541}
 542
 543static __poll_t espintcp_poll(struct file *file, struct socket *sock,
 544                              poll_table *wait)
 545{
 546        __poll_t mask = datagram_poll(file, sock, wait);
 547        struct sock *sk = sock->sk;
 548        struct espintcp_ctx *ctx = espintcp_getctx(sk);
 549
 550        if (!skb_queue_empty(&ctx->ike_queue))
 551                mask |= EPOLLIN | EPOLLRDNORM;
 552
 553        return mask;
 554}
 555
 556static void build_protos(struct proto *espintcp_prot,
 557                         struct proto_ops *espintcp_ops,
 558                         const struct proto *orig_prot,
 559                         const struct proto_ops *orig_ops)
 560{
 561        memcpy(espintcp_prot, orig_prot, sizeof(struct proto));
 562        memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops));
 563        espintcp_prot->sendmsg = espintcp_sendmsg;
 564        espintcp_prot->recvmsg = espintcp_recvmsg;
 565        espintcp_prot->close = espintcp_close;
 566        espintcp_prot->release_cb = espintcp_release;
 567        espintcp_ops->poll = espintcp_poll;
 568}
 569
 570static struct tcp_ulp_ops espintcp_ulp __read_mostly = {
 571        .name = "espintcp",
 572        .owner = THIS_MODULE,
 573        .init = espintcp_init_sk,
 574};
 575
 576void __init espintcp_init(void)
 577{
 578        build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops);
 579
 580        tcp_register_ulp(&espintcp_ulp);
 581}
 582