linux/net/mctp/af_mctp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * Management Component Transport Protocol (MCTP)
   4 *
   5 * Copyright (c) 2021 Code Construct
   6 * Copyright (c) 2021 Google
   7 */
   8
   9#include <linux/if_arp.h>
  10#include <linux/net.h>
  11#include <linux/mctp.h>
  12#include <linux/module.h>
  13#include <linux/socket.h>
  14
  15#include <net/mctp.h>
  16#include <net/mctpdevice.h>
  17#include <net/sock.h>
  18
  19/* socket implementation */
  20
  21static int mctp_release(struct socket *sock)
  22{
  23        struct sock *sk = sock->sk;
  24
  25        if (sk) {
  26                sock->sk = NULL;
  27                sk->sk_prot->close(sk, 0);
  28        }
  29
  30        return 0;
  31}
  32
  33static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen)
  34{
  35        struct sock *sk = sock->sk;
  36        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
  37        struct sockaddr_mctp *smctp;
  38        int rc;
  39
  40        if (addrlen < sizeof(*smctp))
  41                return -EINVAL;
  42
  43        if (addr->sa_family != AF_MCTP)
  44                return -EAFNOSUPPORT;
  45
  46        if (!capable(CAP_NET_BIND_SERVICE))
  47                return -EACCES;
  48
  49        /* it's a valid sockaddr for MCTP, cast and do protocol checks */
  50        smctp = (struct sockaddr_mctp *)addr;
  51
  52        lock_sock(sk);
  53
  54        /* TODO: allow rebind */
  55        if (sk_hashed(sk)) {
  56                rc = -EADDRINUSE;
  57                goto out_release;
  58        }
  59        msk->bind_net = smctp->smctp_network;
  60        msk->bind_addr = smctp->smctp_addr.s_addr;
  61        msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */
  62
  63        rc = sk->sk_prot->hash(sk);
  64
  65out_release:
  66        release_sock(sk);
  67
  68        return rc;
  69}
  70
  71static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
  72{
  73        DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
  74        const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
  75        int rc, addrlen = msg->msg_namelen;
  76        struct sock *sk = sock->sk;
  77        struct mctp_skb_cb *cb;
  78        struct mctp_route *rt;
  79        struct sk_buff *skb;
  80
  81        if (addr) {
  82                if (addrlen < sizeof(struct sockaddr_mctp))
  83                        return -EINVAL;
  84                if (addr->smctp_family != AF_MCTP)
  85                        return -EINVAL;
  86                if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
  87                        return -EINVAL;
  88
  89        } else {
  90                /* TODO: connect()ed sockets */
  91                return -EDESTADDRREQ;
  92        }
  93
  94        if (!capable(CAP_NET_RAW))
  95                return -EACCES;
  96
  97        if (addr->smctp_network == MCTP_NET_ANY)
  98                addr->smctp_network = mctp_default_net(sock_net(sk));
  99
 100        rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
 101                               addr->smctp_addr.s_addr);
 102        if (!rt)
 103                return -EHOSTUNREACH;
 104
 105        skb = sock_alloc_send_skb(sk, hlen + 1 + len,
 106                                  msg->msg_flags & MSG_DONTWAIT, &rc);
 107        if (!skb)
 108                return rc;
 109
 110        skb_reserve(skb, hlen);
 111
 112        /* set type as fist byte in payload */
 113        *(u8 *)skb_put(skb, 1) = addr->smctp_type;
 114
 115        rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len);
 116        if (rc < 0) {
 117                kfree_skb(skb);
 118                return rc;
 119        }
 120
 121        /* set up cb */
 122        cb = __mctp_cb(skb);
 123        cb->net = addr->smctp_network;
 124
 125        rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
 126                               addr->smctp_tag);
 127
 128        return rc ? : len;
 129}
 130
 131static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 132                        int flags)
 133{
 134        DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
 135        struct sock *sk = sock->sk;
 136        struct sk_buff *skb;
 137        size_t msglen;
 138        u8 type;
 139        int rc;
 140
 141        if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
 142                return -EOPNOTSUPP;
 143
 144        skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
 145        if (!skb)
 146                return rc;
 147
 148        if (!skb->len) {
 149                rc = 0;
 150                goto out_free;
 151        }
 152
 153        /* extract message type, remove from data */
 154        type = *((u8 *)skb->data);
 155        msglen = skb->len - 1;
 156
 157        if (len < msglen)
 158                msg->msg_flags |= MSG_TRUNC;
 159        else
 160                len = msglen;
 161
 162        rc = skb_copy_datagram_msg(skb, 1, msg, len);
 163        if (rc < 0)
 164                goto out_free;
 165
 166        sock_recv_ts_and_drops(msg, sk, skb);
 167
 168        if (addr) {
 169                struct mctp_skb_cb *cb = mctp_cb(skb);
 170                /* TODO: expand mctp_skb_cb for header fields? */
 171                struct mctp_hdr *hdr = mctp_hdr(skb);
 172
 173                addr = msg->msg_name;
 174                addr->smctp_family = AF_MCTP;
 175                addr->smctp_network = cb->net;
 176                addr->smctp_addr.s_addr = hdr->src;
 177                addr->smctp_type = type;
 178                addr->smctp_tag = hdr->flags_seq_tag &
 179                                        (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
 180                msg->msg_namelen = sizeof(*addr);
 181        }
 182
 183        rc = len;
 184
 185        if (flags & MSG_TRUNC)
 186                rc = msglen;
 187
 188out_free:
 189        skb_free_datagram(sk, skb);
 190        return rc;
 191}
 192
 193static int mctp_setsockopt(struct socket *sock, int level, int optname,
 194                           sockptr_t optval, unsigned int optlen)
 195{
 196        return -EINVAL;
 197}
 198
 199static int mctp_getsockopt(struct socket *sock, int level, int optname,
 200                           char __user *optval, int __user *optlen)
 201{
 202        return -EINVAL;
 203}
 204
 205static const struct proto_ops mctp_dgram_ops = {
 206        .family         = PF_MCTP,
 207        .release        = mctp_release,
 208        .bind           = mctp_bind,
 209        .connect        = sock_no_connect,
 210        .socketpair     = sock_no_socketpair,
 211        .accept         = sock_no_accept,
 212        .getname        = sock_no_getname,
 213        .poll           = datagram_poll,
 214        .ioctl          = sock_no_ioctl,
 215        .gettstamp      = sock_gettstamp,
 216        .listen         = sock_no_listen,
 217        .shutdown       = sock_no_shutdown,
 218        .setsockopt     = mctp_setsockopt,
 219        .getsockopt     = mctp_getsockopt,
 220        .sendmsg        = mctp_sendmsg,
 221        .recvmsg        = mctp_recvmsg,
 222        .mmap           = sock_no_mmap,
 223        .sendpage       = sock_no_sendpage,
 224};
 225
 226static int mctp_sk_init(struct sock *sk)
 227{
 228        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 229
 230        INIT_HLIST_HEAD(&msk->keys);
 231        return 0;
 232}
 233
 234static void mctp_sk_close(struct sock *sk, long timeout)
 235{
 236        sk_common_release(sk);
 237}
 238
 239static int mctp_sk_hash(struct sock *sk)
 240{
 241        struct net *net = sock_net(sk);
 242
 243        mutex_lock(&net->mctp.bind_lock);
 244        sk_add_node_rcu(sk, &net->mctp.binds);
 245        mutex_unlock(&net->mctp.bind_lock);
 246
 247        return 0;
 248}
 249
 250static void mctp_sk_unhash(struct sock *sk)
 251{
 252        struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 253        struct net *net = sock_net(sk);
 254        struct mctp_sk_key *key;
 255        struct hlist_node *tmp;
 256        unsigned long flags;
 257
 258        /* remove from any type-based binds */
 259        mutex_lock(&net->mctp.bind_lock);
 260        sk_del_node_init_rcu(sk);
 261        mutex_unlock(&net->mctp.bind_lock);
 262
 263        /* remove tag allocations */
 264        spin_lock_irqsave(&net->mctp.keys_lock, flags);
 265        hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
 266                hlist_del_rcu(&key->sklist);
 267                hlist_del_rcu(&key->hlist);
 268
 269                spin_lock(&key->reasm_lock);
 270                if (key->reasm_head)
 271                        kfree_skb(key->reasm_head);
 272                key->reasm_head = NULL;
 273                key->reasm_dead = true;
 274                spin_unlock(&key->reasm_lock);
 275
 276                kfree_rcu(key, rcu);
 277        }
 278        spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 279
 280        synchronize_rcu();
 281}
 282
 283static struct proto mctp_proto = {
 284        .name           = "MCTP",
 285        .owner          = THIS_MODULE,
 286        .obj_size       = sizeof(struct mctp_sock),
 287        .init           = mctp_sk_init,
 288        .close          = mctp_sk_close,
 289        .hash           = mctp_sk_hash,
 290        .unhash         = mctp_sk_unhash,
 291};
 292
 293static int mctp_pf_create(struct net *net, struct socket *sock,
 294                          int protocol, int kern)
 295{
 296        const struct proto_ops *ops;
 297        struct proto *proto;
 298        struct sock *sk;
 299        int rc;
 300
 301        if (protocol)
 302                return -EPROTONOSUPPORT;
 303
 304        /* only datagram sockets are supported */
 305        if (sock->type != SOCK_DGRAM)
 306                return -ESOCKTNOSUPPORT;
 307
 308        proto = &mctp_proto;
 309        ops = &mctp_dgram_ops;
 310
 311        sock->state = SS_UNCONNECTED;
 312        sock->ops = ops;
 313
 314        sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern);
 315        if (!sk)
 316                return -ENOMEM;
 317
 318        sock_init_data(sock, sk);
 319
 320        rc = 0;
 321        if (sk->sk_prot->init)
 322                rc = sk->sk_prot->init(sk);
 323
 324        if (rc)
 325                goto err_sk_put;
 326
 327        return 0;
 328
 329err_sk_put:
 330        sock_orphan(sk);
 331        sock_put(sk);
 332        return rc;
 333}
 334
 335static struct net_proto_family mctp_pf = {
 336        .family = PF_MCTP,
 337        .create = mctp_pf_create,
 338        .owner = THIS_MODULE,
 339};
 340
 341static __init int mctp_init(void)
 342{
 343        int rc;
 344
 345        /* ensure our uapi tag definitions match the header format */
 346        BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO);
 347        BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK);
 348
 349        pr_info("mctp: management component transport protocol core\n");
 350
 351        rc = sock_register(&mctp_pf);
 352        if (rc)
 353                return rc;
 354
 355        rc = proto_register(&mctp_proto, 0);
 356        if (rc)
 357                goto err_unreg_sock;
 358
 359        rc = mctp_routes_init();
 360        if (rc)
 361                goto err_unreg_proto;
 362
 363        rc = mctp_neigh_init();
 364        if (rc)
 365                goto err_unreg_proto;
 366
 367        mctp_device_init();
 368
 369        return 0;
 370
 371err_unreg_proto:
 372        proto_unregister(&mctp_proto);
 373err_unreg_sock:
 374        sock_unregister(PF_MCTP);
 375
 376        return rc;
 377}
 378
 379static __exit void mctp_exit(void)
 380{
 381        mctp_device_exit();
 382        mctp_neigh_exit();
 383        mctp_routes_exit();
 384        proto_unregister(&mctp_proto);
 385        sock_unregister(PF_MCTP);
 386}
 387
 388module_init(mctp_init);
 389module_exit(mctp_exit);
 390
 391MODULE_DESCRIPTION("MCTP core");
 392MODULE_LICENSE("GPL v2");
 393MODULE_AUTHOR("Jeremy Kerr <jk@codeconstruct.com.au>");
 394
 395MODULE_ALIAS_NETPROTO(PF_MCTP);
 396