linux/net/strparser/strparser.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * Stream Parser
   4 *
   5 * Copyright (c) 2016 Tom Herbert <tom@herbertland.com>
   6 */
   7
   8#include <linux/bpf.h>
   9#include <linux/errno.h>
  10#include <linux/errqueue.h>
  11#include <linux/file.h>
  12#include <linux/in.h>
  13#include <linux/kernel.h>
  14#include <linux/export.h>
  15#include <linux/init.h>
  16#include <linux/net.h>
  17#include <linux/netdevice.h>
  18#include <linux/poll.h>
  19#include <linux/rculist.h>
  20#include <linux/skbuff.h>
  21#include <linux/socket.h>
  22#include <linux/uaccess.h>
  23#include <linux/workqueue.h>
  24#include <net/strparser.h>
  25#include <net/netns/generic.h>
  26#include <net/sock.h>
  27
/* Workqueue on which this file queues strp->work and the per-parser
 * message timer work. It is created by strp_dev_init().
 */
static struct workqueue_struct *strp_wq;
  29
/* Per-message parser state kept in the control buffer of the head skb
 * (see the _strp_msg() accessor).
 */
struct _strp_msg {
        /* Internal cb structure. struct strp_msg must be the first member
         * because it is the part that is passed to the upper layer.
         */
        struct strp_msg strp;
        /* Number of bytes accumulated so far for the message in progress */
        int accum_len;
};
  37
  38static inline struct _strp_msg *_strp_msg(struct sk_buff *skb)
  39{
  40        return (struct _strp_msg *)((void *)skb->cb +
  41                offsetof(struct qdisc_skb_cb, data));
  42}
  43
  44/* Lower lock held */
  45static void strp_abort_strp(struct strparser *strp, int err)
  46{
  47        /* Unrecoverable error in receive */
  48
  49        cancel_delayed_work(&strp->msg_timer_work);
  50
  51        if (strp->stopped)
  52                return;
  53
  54        strp->stopped = 1;
  55
  56        if (strp->sk) {
  57                struct sock *sk = strp->sk;
  58
  59                /* Report an error on the lower socket */
  60                sk->sk_err = -err;
  61                sk_error_report(sk);
  62        }
  63}
  64
  65static void strp_start_timer(struct strparser *strp, long timeo)
  66{
  67        if (timeo && timeo != LONG_MAX)
  68                mod_delayed_work(strp_wq, &strp->msg_timer_work, timeo);
  69}
  70
  71/* Lower lock held */
  72static void strp_parser_err(struct strparser *strp, int err,
  73                            read_descriptor_t *desc)
  74{
  75        desc->error = err;
  76        kfree_skb(strp->skb_head);
  77        strp->skb_head = NULL;
  78        strp->cb.abort_parser(strp, err);
  79}
  80
  81static inline int strp_peek_len(struct strparser *strp)
  82{
  83        if (strp->sk) {
  84                struct socket *sock = strp->sk->sk_socket;
  85
  86                return sock->ops->peek_len(sock);
  87        }
  88
  89        /* If we don't have an associated socket there's nothing to peek.
  90         * Return int max to avoid stopping the strparser.
  91         */
  92
  93        return INT_MAX;
  94}
  95
  96/* Lower socket lock held */
  97static int __strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
  98                       unsigned int orig_offset, size_t orig_len,
  99                       size_t max_msg_size, long timeo)
 100{
 101        struct strparser *strp = (struct strparser *)desc->arg.data;
 102        struct _strp_msg *stm;
 103        struct sk_buff *head, *skb;
 104        size_t eaten = 0, cand_len;
 105        ssize_t extra;
 106        int err;
 107        bool cloned_orig = false;
 108
 109        if (strp->paused)
 110                return 0;
 111
 112        head = strp->skb_head;
 113        if (head) {
 114                /* Message already in progress */
 115                if (unlikely(orig_offset)) {
 116                        /* Getting data with a non-zero offset when a message is
 117                         * in progress is not expected. If it does happen, we
 118                         * need to clone and pull since we can't deal with
 119                         * offsets in the skbs for a message expect in the head.
 120                         */
 121                        orig_skb = skb_clone(orig_skb, GFP_ATOMIC);
 122                        if (!orig_skb) {
 123                                STRP_STATS_INCR(strp->stats.mem_fail);
 124                                desc->error = -ENOMEM;
 125                                return 0;
 126                        }
 127                        if (!pskb_pull(orig_skb, orig_offset)) {
 128                                STRP_STATS_INCR(strp->stats.mem_fail);
 129                                kfree_skb(orig_skb);
 130                                desc->error = -ENOMEM;
 131                                return 0;
 132                        }
 133                        cloned_orig = true;
 134                        orig_offset = 0;
 135                }
 136
 137                if (!strp->skb_nextp) {
 138                        /* We are going to append to the frags_list of head.
 139                         * Need to unshare the frag_list.
 140                         */
 141                        err = skb_unclone(head, GFP_ATOMIC);
 142                        if (err) {
 143                                STRP_STATS_INCR(strp->stats.mem_fail);
 144                                desc->error = err;
 145                                return 0;
 146                        }
 147
 148                        if (unlikely(skb_shinfo(head)->frag_list)) {
 149                                /* We can't append to an sk_buff that already
 150                                 * has a frag_list. We create a new head, point
 151                                 * the frag_list of that to the old head, and
 152                                 * then are able to use the old head->next for
 153                                 * appending to the message.
 154                                 */
 155                                if (WARN_ON(head->next)) {
 156                                        desc->error = -EINVAL;
 157                                        return 0;
 158                                }
 159
 160                                skb = alloc_skb_for_msg(head);
 161                                if (!skb) {
 162                                        STRP_STATS_INCR(strp->stats.mem_fail);
 163                                        desc->error = -ENOMEM;
 164                                        return 0;
 165                                }
 166
 167                                strp->skb_nextp = &head->next;
 168                                strp->skb_head = skb;
 169                                head = skb;
 170                        } else {
 171                                strp->skb_nextp =
 172                                    &skb_shinfo(head)->frag_list;
 173                        }
 174                }
 175        }
 176
 177        while (eaten < orig_len) {
 178                /* Always clone since we will consume something */
 179                skb = skb_clone(orig_skb, GFP_ATOMIC);
 180                if (!skb) {
 181                        STRP_STATS_INCR(strp->stats.mem_fail);
 182                        desc->error = -ENOMEM;
 183                        break;
 184                }
 185
 186                cand_len = orig_len - eaten;
 187
 188                head = strp->skb_head;
 189                if (!head) {
 190                        head = skb;
 191                        strp->skb_head = head;
 192                        /* Will set skb_nextp on next packet if needed */
 193                        strp->skb_nextp = NULL;
 194                        stm = _strp_msg(head);
 195                        memset(stm, 0, sizeof(*stm));
 196                        stm->strp.offset = orig_offset + eaten;
 197                } else {
 198                        /* Unclone if we are appending to an skb that we
 199                         * already share a frag_list with.
 200                         */
 201                        if (skb_has_frag_list(skb)) {
 202                                err = skb_unclone(skb, GFP_ATOMIC);
 203                                if (err) {
 204                                        STRP_STATS_INCR(strp->stats.mem_fail);
 205                                        desc->error = err;
 206                                        break;
 207                                }
 208                        }
 209
 210                        stm = _strp_msg(head);
 211                        *strp->skb_nextp = skb;
 212                        strp->skb_nextp = &skb->next;
 213                        head->data_len += skb->len;
 214                        head->len += skb->len;
 215                        head->truesize += skb->truesize;
 216                }
 217
 218                if (!stm->strp.full_len) {
 219                        ssize_t len;
 220
 221                        len = (*strp->cb.parse_msg)(strp, head);
 222
 223                        if (!len) {
 224                                /* Need more header to determine length */
 225                                if (!stm->accum_len) {
 226                                        /* Start RX timer for new message */
 227                                        strp_start_timer(strp, timeo);
 228                                }
 229                                stm->accum_len += cand_len;
 230                                eaten += cand_len;
 231                                STRP_STATS_INCR(strp->stats.need_more_hdr);
 232                                WARN_ON(eaten != orig_len);
 233                                break;
 234                        } else if (len < 0) {
 235                                if (len == -ESTRPIPE && stm->accum_len) {
 236                                        len = -ENODATA;
 237                                        strp->unrecov_intr = 1;
 238                                } else {
 239                                        strp->interrupted = 1;
 240                                }
 241                                strp_parser_err(strp, len, desc);
 242                                break;
 243                        } else if (len > max_msg_size) {
 244                                /* Message length exceeds maximum allowed */
 245                                STRP_STATS_INCR(strp->stats.msg_too_big);
 246                                strp_parser_err(strp, -EMSGSIZE, desc);
 247                                break;
 248                        } else if (len <= (ssize_t)head->len -
 249                                          skb->len - stm->strp.offset) {
 250                                /* Length must be into new skb (and also
 251                                 * greater than zero)
 252                                 */
 253                                STRP_STATS_INCR(strp->stats.bad_hdr_len);
 254                                strp_parser_err(strp, -EPROTO, desc);
 255                                break;
 256                        }
 257
 258                        stm->strp.full_len = len;
 259                }
 260
 261                extra = (ssize_t)(stm->accum_len + cand_len) -
 262                        stm->strp.full_len;
 263
 264                if (extra < 0) {
 265                        /* Message not complete yet. */
 266                        if (stm->strp.full_len - stm->accum_len >
 267                            strp_peek_len(strp)) {
 268                                /* Don't have the whole message in the socket
 269                                 * buffer. Set strp->need_bytes to wait for
 270                                 * the rest of the message. Also, set "early
 271                                 * eaten" since we've already buffered the skb
 272                                 * but don't consume yet per strp_read_sock.
 273                                 */
 274
 275                                if (!stm->accum_len) {
 276                                        /* Start RX timer for new message */
 277                                        strp_start_timer(strp, timeo);
 278                                }
 279
 280                                stm->accum_len += cand_len;
 281                                eaten += cand_len;
 282                                strp->need_bytes = stm->strp.full_len -
 283                                                       stm->accum_len;
 284                                STRP_STATS_ADD(strp->stats.bytes, cand_len);
 285                                desc->count = 0; /* Stop reading socket */
 286                                break;
 287                        }
 288                        stm->accum_len += cand_len;
 289                        eaten += cand_len;
 290                        WARN_ON(eaten != orig_len);
 291                        break;
 292                }
 293
 294                /* Positive extra indicates more bytes than needed for the
 295                 * message
 296                 */
 297
 298                WARN_ON(extra > cand_len);
 299
 300                eaten += (cand_len - extra);
 301
 302                /* Hurray, we have a new message! */
 303                cancel_delayed_work(&strp->msg_timer_work);
 304                strp->skb_head = NULL;
 305                strp->need_bytes = 0;
 306                STRP_STATS_INCR(strp->stats.msgs);
 307
 308                /* Give skb to upper layer */
 309                strp->cb.rcv_msg(strp, head);
 310
 311                if (unlikely(strp->paused)) {
 312                        /* Upper layer paused strp */
 313                        break;
 314                }
 315        }
 316
 317        if (cloned_orig)
 318                kfree_skb(orig_skb);
 319
 320        STRP_STATS_ADD(strp->stats.bytes, eaten);
 321
 322        return eaten;
 323}
 324
 325int strp_process(struct strparser *strp, struct sk_buff *orig_skb,
 326                 unsigned int orig_offset, size_t orig_len,
 327                 size_t max_msg_size, long timeo)
 328{
 329        read_descriptor_t desc; /* Dummy arg to strp_recv */
 330
 331        desc.arg.data = strp;
 332
 333        return __strp_recv(&desc, orig_skb, orig_offset, orig_len,
 334                           max_msg_size, timeo);
 335}
 336EXPORT_SYMBOL_GPL(strp_process);
 337
 338static int strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb,
 339                     unsigned int orig_offset, size_t orig_len)
 340{
 341        struct strparser *strp = (struct strparser *)desc->arg.data;
 342
 343        return __strp_recv(desc, orig_skb, orig_offset, orig_len,
 344                           strp->sk->sk_rcvbuf, strp->sk->sk_rcvtimeo);
 345}
 346
 347static int default_read_sock_done(struct strparser *strp, int err)
 348{
 349        return err;
 350}
 351
 352/* Called with lock held on lower socket */
 353static int strp_read_sock(struct strparser *strp)
 354{
 355        struct socket *sock = strp->sk->sk_socket;
 356        read_descriptor_t desc;
 357
 358        if (unlikely(!sock || !sock->ops || !sock->ops->read_sock))
 359                return -EBUSY;
 360
 361        desc.arg.data = strp;
 362        desc.error = 0;
 363        desc.count = 1; /* give more than one skb per call */
 364
 365        /* sk should be locked here, so okay to do read_sock */
 366        sock->ops->read_sock(strp->sk, &desc, strp_recv);
 367
 368        desc.error = strp->cb.read_sock_done(strp, desc.error);
 369
 370        return desc.error;
 371}
 372
 373/* Lower sock lock held */
 374void strp_data_ready(struct strparser *strp)
 375{
 376        if (unlikely(strp->stopped) || strp->paused)
 377                return;
 378
 379        /* This check is needed to synchronize with do_strp_work.
 380         * do_strp_work acquires a process lock (lock_sock) whereas
 381         * the lock held here is bh_lock_sock. The two locks can be
 382         * held by different threads at the same time, but bh_lock_sock
 383         * allows a thread in BH context to safely check if the process
 384         * lock is held. In this case, if the lock is held, queue work.
 385         */
 386        if (sock_owned_by_user_nocheck(strp->sk)) {
 387                queue_work(strp_wq, &strp->work);
 388                return;
 389        }
 390
 391        if (strp->need_bytes) {
 392                if (strp_peek_len(strp) < strp->need_bytes)
 393                        return;
 394        }
 395
 396        if (strp_read_sock(strp) == -ENOMEM)
 397                queue_work(strp_wq, &strp->work);
 398}
 399EXPORT_SYMBOL_GPL(strp_data_ready);
 400
 401static void do_strp_work(struct strparser *strp)
 402{
 403        /* We need the read lock to synchronize with strp_data_ready. We
 404         * need the socket lock for calling strp_read_sock.
 405         */
 406        strp->cb.lock(strp);
 407
 408        if (unlikely(strp->stopped))
 409                goto out;
 410
 411        if (strp->paused)
 412                goto out;
 413
 414        if (strp_read_sock(strp) == -ENOMEM)
 415                queue_work(strp_wq, &strp->work);
 416
 417out:
 418        strp->cb.unlock(strp);
 419}
 420
 421static void strp_work(struct work_struct *w)
 422{
 423        do_strp_work(container_of(w, struct strparser, work));
 424}
 425
 426static void strp_msg_timeout(struct work_struct *w)
 427{
 428        struct strparser *strp = container_of(w, struct strparser,
 429                                              msg_timer_work.work);
 430
 431        /* Message assembly timed out */
 432        STRP_STATS_INCR(strp->stats.msg_timeouts);
 433        strp->cb.lock(strp);
 434        strp->cb.abort_parser(strp, -ETIMEDOUT);
 435        strp->cb.unlock(strp);
 436}
 437
 438static void strp_sock_lock(struct strparser *strp)
 439{
 440        lock_sock(strp->sk);
 441}
 442
 443static void strp_sock_unlock(struct strparser *strp)
 444{
 445        release_sock(strp->sk);
 446}
 447
 448int strp_init(struct strparser *strp, struct sock *sk,
 449              const struct strp_callbacks *cb)
 450{
 451
 452        if (!cb || !cb->rcv_msg || !cb->parse_msg)
 453                return -EINVAL;
 454
 455        /* The sk (sock) arg determines the mode of the stream parser.
 456         *
 457         * If the sock is set then the strparser is in receive callback mode.
 458         * The upper layer calls strp_data_ready to kick receive processing
 459         * and strparser calls the read_sock function on the socket to
 460         * get packets.
 461         *
 462         * If the sock is not set then the strparser is in general mode.
 463         * The upper layer calls strp_process for each skb to be parsed.
 464         */
 465
 466        if (!sk) {
 467                if (!cb->lock || !cb->unlock)
 468                        return -EINVAL;
 469        }
 470
 471        memset(strp, 0, sizeof(*strp));
 472
 473        strp->sk = sk;
 474
 475        strp->cb.lock = cb->lock ? : strp_sock_lock;
 476        strp->cb.unlock = cb->unlock ? : strp_sock_unlock;
 477        strp->cb.rcv_msg = cb->rcv_msg;
 478        strp->cb.parse_msg = cb->parse_msg;
 479        strp->cb.read_sock_done = cb->read_sock_done ? : default_read_sock_done;
 480        strp->cb.abort_parser = cb->abort_parser ? : strp_abort_strp;
 481
 482        INIT_DELAYED_WORK(&strp->msg_timer_work, strp_msg_timeout);
 483        INIT_WORK(&strp->work, strp_work);
 484
 485        return 0;
 486}
 487EXPORT_SYMBOL_GPL(strp_init);
 488
 489/* Sock process lock held (lock_sock) */
 490void __strp_unpause(struct strparser *strp)
 491{
 492        strp->paused = 0;
 493
 494        if (strp->need_bytes) {
 495                if (strp_peek_len(strp) < strp->need_bytes)
 496                        return;
 497        }
 498        strp_read_sock(strp);
 499}
 500EXPORT_SYMBOL_GPL(__strp_unpause);
 501
 502void strp_unpause(struct strparser *strp)
 503{
 504        strp->paused = 0;
 505
 506        /* Sync setting paused with RX work */
 507        smp_mb();
 508
 509        queue_work(strp_wq, &strp->work);
 510}
 511EXPORT_SYMBOL_GPL(strp_unpause);
 512
 513/* strp must already be stopped so that strp_recv will no longer be called.
 514 * Note that strp_done is not called with the lower socket held.
 515 */
 516void strp_done(struct strparser *strp)
 517{
 518        WARN_ON(!strp->stopped);
 519
 520        cancel_delayed_work_sync(&strp->msg_timer_work);
 521        cancel_work_sync(&strp->work);
 522
 523        if (strp->skb_head) {
 524                kfree_skb(strp->skb_head);
 525                strp->skb_head = NULL;
 526        }
 527}
 528EXPORT_SYMBOL_GPL(strp_done);
 529
 530void strp_stop(struct strparser *strp)
 531{
 532        strp->stopped = 1;
 533}
 534EXPORT_SYMBOL_GPL(strp_stop);
 535
 536void strp_check_rcv(struct strparser *strp)
 537{
 538        queue_work(strp_wq, &strp->work);
 539}
 540EXPORT_SYMBOL_GPL(strp_check_rcv);
 541
 542static int __init strp_dev_init(void)
 543{
 544        strp_wq = create_singlethread_workqueue("kstrp");
 545        if (unlikely(!strp_wq))
 546                return -ENOMEM;
 547
 548        return 0;
 549}
 550device_initcall(strp_dev_init);
 551