linux/net/mctp/af_mctp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * Management Component Transport Protocol (MCTP)
   4 *
   5 * Copyright (c) 2021 Code Construct
   6 * Copyright (c) 2021 Google
   7 */
   8
   9#include <linux/if_arp.h>
  10#include <linux/net.h>
  11#include <linux/mctp.h>
  12#include <linux/module.h>
  13#include <linux/socket.h>
  14
  15#include <net/mctp.h>
  16#include <net/mctpdevice.h>
  17#include <net/sock.h>
  18
  19#define CREATE_TRACE_POINTS
  20#include <trace/events/mctp.h>
  21
  22/* socket implementation */
  23
  24static int mctp_release(struct socket *sock)
  25{
  26        struct sock *sk = sock->sk;
  27
  28        if (sk) {
  29                sock->sk = NULL;
  30                sk->sk_prot->close(sk, 0);
  31        }
  32
  33        return 0;
  34}
  35
  36/* Generic sockaddr checks, padding checks only so far */
  37static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr)
  38{
  39        return !addr->__smctp_pad0 && !addr->__smctp_pad1;
  40}
  41
  42static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr)
  43{
  44        return !addr->__smctp_pad0[0] &&
  45               !addr->__smctp_pad0[1] &&
  46               !addr->__smctp_pad0[2];
  47}
  48
  49static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
  50{
  51        struct sock *sk = sock->sk;
  52        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
  53        struct sockaddr_mctp *smctp;
  54        int rc;
  55
  56        if (addrlen < sizeof(*smctp))
  57                return -EINVAL;
  58
  59        if (addr->sa_family != AF_MCTP)
  60                return -EAFNOSUPPORT;
  61
  62        if (!capable(CAP_NET_BIND_SERVICE))
  63                return -EACCES;
  64
  65        /* it's a valid sockaddr for MCTP, cast and do protocol checks */
  66        smctp = (struct sockaddr_mctp *)addr;
  67
  68        if (!mctp_sockaddr_is_ok(smctp))
  69                return -EINVAL;
  70
  71        lock_sock(sk);
  72
  73        /* TODO: allow rebind */
  74        if (sk_hashed(sk)) {
  75                rc = -EADDRINUSE;
  76                goto out_release;
  77        }
  78        msk->bind_net = smctp->smctp_network;
  79        msk->bind_addr = smctp->smctp_addr.s_addr;
  80        msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
  81
  82        rc = sk->sk_prot->hash(sk);
  83
  84out_release:
  85        release_sock(sk);
  86
  87        return rc;
  88}
  89
  90static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
  91{
  92        DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
  93        const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
  94        int rc, addrlen = msg->msg_namelen;
  95        struct sock *sk = sock->sk;
  96        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
  97        struct mctp_skb_cb *cb;
  98        struct mctp_route *rt;
  99        struct sk_buff *skb;
 100
 101        if (addr) {
 102                if (addrlen < sizeof(struct sockaddr_mctp))
 103                        return -EINVAL;
 104                if (addr->smctp_family != AF_MCTP)
 105                        return -EINVAL;
 106                if (!mctp_sockaddr_is_ok(addr))
 107                        return -EINVAL;
 108                if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
 109                        return -EINVAL;
 110
 111        } else {
 112                /* TODO: connect()ed sockets */
 113                return -EDESTADDRREQ;
 114        }
 115
 116        if (!capable(CAP_NET_RAW))
 117                return -EACCES;
 118
 119        if (addr->smctp_network == MCTP_NET_ANY)
 120                addr->smctp_network = mctp_default_net(sock_net(sk));
 121
 122        skb = sock_alloc_send_skb(sk, hlen + 1 + len,
 123                                  msg->msg_flags & MSG_DONTWAIT, &rc);
 124        if (!skb)
 125                return rc;
 126
 127        skb_reserve(skb, hlen);
 128
 129        /* set type as fist byte in payload */
 130        *(u8 *)skb_put(skb, 1) = addr->smctp_type;
 131
 132        rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
 133        if (rc < 0)
 134                goto err_free;
 135
 136        /* set up cb */
 137        cb = __mctp_cb(skb);
 138        cb->net = addr->smctp_network;
 139
 140        /* direct addressing */
 141        if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
 142                DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
 143                                 extaddr, msg->msg_name);
 144
 145                if (!mctp_sockaddr_ext_is_ok(extaddr) ||
 146                    extaddr->smctp_halen > sizeof(cb->haddr)) {
 147                        rc = -EINVAL;
 148                        goto err_free;
 149                }
 150
 151                cb->ifindex = extaddr->smctp_ifindex;
 152                cb->halen = extaddr->smctp_halen;
 153                memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
 154
 155                rt = NULL;
 156        } else {
 157                rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
 158                                       addr->smctp_addr.s_addr);
 159                if (!rt) {
 160                        rc = -EHOSTUNREACH;
 161                        goto err_free;
 162                }
 163        }
 164
 165        rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
 166                               addr->smctp_tag);
 167
 168        return rc ? : len;
 169
 170err_free:
 171        kfree_skb(skb);
 172        return rc;
 173}
 174
 175static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 176                        int flags)
 177{
 178        DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
 179        struct sock *sk = sock->sk;
 180        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 181        struct sk_buff *skb;
 182        size_t msglen;
 183        u8 type;
 184        int rc;
 185
 186        if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
 187                return -EOPNOTSUPP;
 188
 189        skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
 190        if (!skb)
 191                return rc;
 192
 193        if (!skb->len) {
 194                rc = 0;
 195                goto out_free;
 196        }
 197
 198        /* extract message type, remove from data */
 199        type = *((u8 *)skb->data);
 200        msglen = skb->len - 1;
 201
 202        if (len < msglen)
 203                msg->msg_flags |= MSG_TRUNC;
 204        else
 205                len = msglen;
 206
 207        rc = skb_copy_datagram_msg(skb, 1, msg, len);
 208        if (rc < 0)
 209                goto out_free;
 210
 211        sock_recv_ts_and_drops(msg, sk, skb);
 212
 213        if (addr) {
 214                struct mctp_skb_cb *cb = mctp_cb(skb);
 215                /* TODO: expand mctp_skb_cb for header fields? */
 216                struct mctp_hdr *hdr = mctp_hdr(skb);
 217
 218                addr = msg->msg_name;
 219                addr->smctp_family = AF_MCTP;
 220                addr->__smctp_pad0 = 0;
 221                addr->smctp_network = cb->net;
 222                addr->smctp_addr.s_addr = hdr->src;
 223                addr->smctp_type = type;
 224                addr->smctp_tag = hdr->flags_seq_tag &
 225                                        (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
 226                addr->__smctp_pad1 = 0;
 227                msg->msg_namelen = sizeof(*addr);
 228
 229                if (msk->addr_ext) {
 230                        DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae,
 231                                         msg->msg_name);
 232                        msg->msg_namelen = sizeof(*ae);
 233                        ae->smctp_ifindex = cb->ifindex;
 234                        ae->smctp_halen = cb->halen;
 235                        memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0));
 236                        memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr));
 237                        memcpy(ae->smctp_haddr, cb->haddr, cb->halen);
 238                }
 239        }
 240
 241        rc = len;
 242
 243        if (flags & MSG_TRUNC)
 244                rc = msglen;
 245
 246out_free:
 247        skb_free_datagram(sk, skb);
 248        return rc;
 249}
 250
 251static int mctp_setsockopt(struct socket *sock, int level, int optname,
 252                           sockptr_t optval, unsigned int optlen)
 253{
 254        struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
 255        int val;
 256
 257        if (level != SOL_MCTP)
 258                return -EINVAL;
 259
 260        if (optname == MCTP_OPT_ADDR_EXT) {
 261                if (optlen != sizeof(int))
 262                        return -EINVAL;
 263                if (copy_from_sockptr(&val, optval, sizeof(int)))
 264                        return -EFAULT;
 265                msk->addr_ext = val;
 266                return 0;
 267        }
 268
 269        return -ENOPROTOOPT;
 270}
 271
 272static int mctp_getsockopt(struct socket *sock, int level, int optname,
 273                           char __user *optval, int __user *optlen)
 274{
 275        struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
 276        int len, val;
 277
 278        if (level != SOL_MCTP)
 279                return -EINVAL;
 280
 281        if (get_user(len, optlen))
 282                return -EFAULT;
 283
 284        if (optname == MCTP_OPT_ADDR_EXT) {
 285                if (len != sizeof(int))
 286                        return -EINVAL;
 287                val = !!msk->addr_ext;
 288                if (copy_to_user(optval, &val, len))
 289                        return -EFAULT;
 290                return 0;
 291        }
 292
 293        return -EINVAL;
 294}
 295
 296static const struct proto_ops mctp_dgram_ops = {
 297        .family         = PF_MCTP,
 298        .release        = mctp_release,
 299        .bind           = mctp_bind,
 300        .connect        = sock_no_connect,
 301        .socketpair     = sock_no_socketpair,
 302        .accept         = sock_no_accept,
 303        .getname        = sock_no_getname,
 304        .poll           = datagram_poll,
 305        .ioctl          = sock_no_ioctl,
 306        .gettstamp      = sock_gettstamp,
 307        .listen         = sock_no_listen,
 308        .shutdown       = sock_no_shutdown,
 309        .setsockopt     = mctp_setsockopt,
 310        .getsockopt     = mctp_getsockopt,
 311        .sendmsg        = mctp_sendmsg,
 312        .recvmsg        = mctp_recvmsg,
 313        .mmap           = sock_no_mmap,
 314        .sendpage       = sock_no_sendpage,
 315};
 316
 317static void mctp_sk_expire_keys(struct timer_list *timer)
 318{
 319        struct mctp_sock *msk = container_of(timer, struct mctp_sock,
 320                                             key_expiry);
 321        struct net *net = sock_net(&msk->sk);
 322        unsigned long next_expiry, flags;
 323        struct mctp_sk_key *key;
 324        struct hlist_node *tmp;
 325        bool next_expiry_valid = false;
 326
 327        spin_lock_irqsave(&net->mctp.keys_lock, flags);
 328
 329        hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
 330                spin_lock(&key->lock);
 331
 332                if (!time_after_eq(key->expiry, jiffies)) {
 333                        trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
 334                        key->valid = false;
 335                        hlist_del_rcu(&key->hlist);
 336                        hlist_del_rcu(&key->sklist);
 337                        spin_unlock(&key->lock);
 338                        mctp_key_unref(key);
 339                        continue;
 340                }
 341
 342                if (next_expiry_valid) {
 343                        if (time_before(key->expiry, next_expiry))
 344                                next_expiry = key->expiry;
 345                } else {
 346                        next_expiry = key->expiry;
 347                        next_expiry_valid = true;
 348                }
 349                spin_unlock(&key->lock);
 350        }
 351
 352        spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 353
 354        if (next_expiry_valid)
 355                mod_timer(timer, next_expiry);
 356}
 357
 358static int mctp_sk_init(struct sock *sk)
 359{
 360        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 361
 362        INIT_HLIST_HEAD(&msk->keys);
 363        timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0);
 364        return 0;
 365}
 366
 367static void mctp_sk_close(struct sock *sk, long timeout)
 368{
 369        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 370
 371        del_timer_sync(&msk->key_expiry);
 372        sk_common_release(sk);
 373}
 374
 375static int mctp_sk_hash(struct sock *sk)
 376{
 377        struct net *net = sock_net(sk);
 378
 379        mutex_lock(&net->mctp.bind_lock);
 380        sk_add_node_rcu(sk, &net->mctp.binds);
 381        mutex_unlock(&net->mctp.bind_lock);
 382
 383        return 0;
 384}
 385
 386static void mctp_sk_unhash(struct sock *sk)
 387{
 388        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 389        struct net *net = sock_net(sk);
 390        struct mctp_sk_key *key;
 391        struct hlist_node *tmp;
 392        unsigned long flags;
 393
 394        /* remove from any type-based binds */
 395        mutex_lock(&net->mctp.bind_lock);
 396        sk_del_node_init_rcu(sk);
 397        mutex_unlock(&net->mctp.bind_lock);
 398
 399        /* remove tag allocations */
 400        spin_lock_irqsave(&net->mctp.keys_lock, flags);
 401        hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
 402                hlist_del(&key->sklist);
 403                hlist_del(&key->hlist);
 404
 405                trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
 406
 407                spin_lock(&key->lock);
 408                if (key->reasm_head)
 409                        kfree_skb(key->reasm_head);
 410                key->reasm_head = NULL;
 411                key->reasm_dead = true;
 412                key->valid = false;
 413                spin_unlock(&key->lock);
 414
 415                /* key is no longer on the lookup lists, unref */
 416                mctp_key_unref(key);
 417        }
 418        spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 419}
 420
 421static struct proto mctp_proto = {
 422        .name           = "MCTP",
 423        .owner          = THIS_MODULE,
 424        .obj_size       = sizeof(struct mctp_sock),
 425        .init           = mctp_sk_init,
 426        .close          = mctp_sk_close,
 427        .hash           = mctp_sk_hash,
 428        .unhash         = mctp_sk_unhash,
 429};
 430
 431static int mctp_pf_create(struct net *net, struct socket *sock,
 432                          int protocol, int kern)
 433{
 434        const struct proto_ops *ops;
 435        struct proto *proto;
 436        struct sock *sk;
 437        int rc;
 438
 439        if (protocol)
 440                return -EPROTONOSUPPORT;
 441
 442        /* only datagram sockets are supported */
 443        if (sock->type != SOCK_DGRAM)
 444                return -ESOCKTNOSUPPORT;
 445
 446        proto = &mctp_proto;
 447        ops = &mctp_dgram_ops;
 448
 449        sock->state = SS_UNCONNECTED;
 450        sock->ops = ops;
 451
 452        sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
 453        if (!sk)
 454                return -ENOMEM;
 455
 456        sock_init_data(sock, sk);
 457
 458        rc = 0;
 459        if (sk->sk_prot->init)
 460                rc = sk->sk_prot->init(sk);
 461
 462        if (rc)
 463                goto err_sk_put;
 464
 465        return 0;
 466
 467err_sk_put:
 468        sock_orphan(sk);
 469        sock_put(sk);
 470        return rc;
 471}
 472
 473static struct net_proto_family mctp_pf = {
 474        .family = PF_MCTP,
 475        .create = mctp_pf_create,
 476        .owner = THIS_MODULE,
 477};
 478
 479static __init int mctp_init(void)
 480{
 481        int rc;
 482
 483        /* ensure our uapi tag definitions match the header format */
 484        BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
 485        BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
 486
 487        pr_info("mctp: management component transport protocol core\n");
 488
 489        rc = sock_register(&mctp_pf);
 490        if (rc)
 491                return rc;
 492
 493        rc = proto_register(&mctp_proto, 0);
 494        if (rc)
 495                goto err_unreg_sock;
 496
 497        rc = mctp_routes_init();
 498        if (rc)
 499                goto err_unreg_proto;
 500
 501        rc = mctp_neigh_init();
 502        if (rc)
 503                goto err_unreg_proto;
 504
 505        mctp_device_init();
 506
 507        return 0;
 508
 509err_unreg_proto:
 510        proto_unregister(&mctp_proto);
 511err_unreg_sock:
 512        sock_unregister(PF_MCTP);
 513
 514        return rc;
 515}
 516
 517static __exit void mctp_exit(void)
 518{
 519        mctp_device_exit();
 520        mctp_neigh_exit();
 521        mctp_routes_exit();
 522        proto_unregister(&mctp_proto);
 523        sock_unregister(PF_MCTP);
 524}
 525
 526subsys_initcall(mctp_init);
 527module_exit(mctp_exit);
 528
 529MODULE_DESCRIPTION("MCTP core");
 530MODULE_LICENSE("GPL v2");
 531MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
 532
 533MODULE_ALIAS_NETPROTO(PF_MCTP);
 534