linux/fs/ksmbd/transport_tcp.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-or-later
   2/*
   3 *   Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org>
   4 *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
   5 */
   6
   7#include <linux/freezer.h>
   8
   9#include "smb_common.h"
  10#include "server.h"
  11#include "auth.h"
  12#include "connection.h"
  13#include "transport_tcp.h"
  14
  15#define IFACE_STATE_DOWN                BIT(0)
  16#define IFACE_STATE_CONFIGURED          BIT(1)
  17
  18struct interface {
  19        struct task_struct      *ksmbd_kthread;
  20        struct socket           *ksmbd_socket;
  21        struct list_head        entry;
  22        char                    *name;
  23        struct mutex            sock_release_lock;
  24        int                     state;
  25};
  26
  27static LIST_HEAD(iface_list);
  28
  29static int bind_additional_ifaces;
  30
  31struct tcp_transport {
  32        struct ksmbd_transport          transport;
  33        struct socket                   *sock;
  34        struct kvec                     *iov;
  35        unsigned int                    nr_iov;
  36};
  37
  38static struct ksmbd_transport_ops ksmbd_tcp_transport_ops;
  39
  40static void tcp_stop_kthread(struct task_struct *kthread);
  41static struct interface *alloc_iface(char *ifname);
  42
  43#define KSMBD_TRANS(t)  (&(t)->transport)
  44#define TCP_TRANS(t)    ((struct tcp_transport *)container_of(t, \
  45                                struct tcp_transport, transport))
  46
  47static inline void ksmbd_tcp_nodelay(struct socket *sock)
  48{
  49        tcp_sock_set_nodelay(sock->sk);
  50}
  51
  52static inline void ksmbd_tcp_reuseaddr(struct socket *sock)
  53{
  54        sock_set_reuseaddr(sock->sk);
  55}
  56
  57static inline void ksmbd_tcp_rcv_timeout(struct socket *sock, s64 secs)
  58{
  59        lock_sock(sock->sk);
  60        if (secs && secs < MAX_SCHEDULE_TIMEOUT / HZ - 1)
  61                sock->sk->sk_rcvtimeo = secs * HZ;
  62        else
  63                sock->sk->sk_rcvtimeo = MAX_SCHEDULE_TIMEOUT;
  64        release_sock(sock->sk);
  65}
  66
  67static inline void ksmbd_tcp_snd_timeout(struct socket *sock, s64 secs)
  68{
  69        sock_set_sndtimeo(sock->sk, secs);
  70}
  71
  72static struct tcp_transport *alloc_transport(struct socket *client_sk)
  73{
  74        struct tcp_transport *t;
  75        struct ksmbd_conn *conn;
  76
  77        t = kzalloc(sizeof(*t), GFP_KERNEL);
  78        if (!t)
  79                return NULL;
  80        t->sock = client_sk;
  81
  82        conn = ksmbd_conn_alloc();
  83        if (!conn) {
  84                kfree(t);
  85                return NULL;
  86        }
  87
  88        conn->transport = KSMBD_TRANS(t);
  89        KSMBD_TRANS(t)->conn = conn;
  90        KSMBD_TRANS(t)->ops = &ksmbd_tcp_transport_ops;
  91        return t;
  92}
  93
  94static void free_transport(struct tcp_transport *t)
  95{
  96        kernel_sock_shutdown(t->sock, SHUT_RDWR);
  97        sock_release(t->sock);
  98        t->sock = NULL;
  99
 100        ksmbd_conn_free(KSMBD_TRANS(t)->conn);
 101        kfree(t->iov);
 102        kfree(t);
 103}
 104
 105/**
 106 * kvec_array_init() - initialize a IO vector segment
 107 * @new:        IO vector to be initialized
 108 * @iov:        base IO vector
 109 * @nr_segs:    number of segments in base iov
 110 * @bytes:      total iovec length so far for read
 111 *
 112 * Return:      Number of IO segments
 113 */
 114static unsigned int kvec_array_init(struct kvec *new, struct kvec *iov,
 115                                    unsigned int nr_segs, size_t bytes)
 116{
 117        size_t base = 0;
 118
 119        while (bytes || !iov->iov_len) {
 120                int copy = min(bytes, iov->iov_len);
 121
 122                bytes -= copy;
 123                base += copy;
 124                if (iov->iov_len == base) {
 125                        iov++;
 126                        nr_segs--;
 127                        base = 0;
 128                }
 129        }
 130
 131        memcpy(new, iov, sizeof(*iov) * nr_segs);
 132        new->iov_base += base;
 133        new->iov_len -= base;
 134        return nr_segs;
 135}
 136
 137/**
 138 * get_conn_iovec() - get connection iovec for reading from socket
 139 * @t:          TCP transport instance
 140 * @nr_segs:    number of segments in iov
 141 *
 142 * Return:      return existing or newly allocate iovec
 143 */
 144static struct kvec *get_conn_iovec(struct tcp_transport *t, unsigned int nr_segs)
 145{
 146        struct kvec *new_iov;
 147
 148        if (t->iov && nr_segs <= t->nr_iov)
 149                return t->iov;
 150
 151        /* not big enough -- allocate a new one and release the old */
 152        new_iov = kmalloc_array(nr_segs, sizeof(*new_iov), GFP_KERNEL);
 153        if (new_iov) {
 154                kfree(t->iov);
 155                t->iov = new_iov;
 156                t->nr_iov = nr_segs;
 157        }
 158        return new_iov;
 159}
 160
 161static unsigned short ksmbd_tcp_get_port(const struct sockaddr *sa)
 162{
 163        switch (sa->sa_family) {
 164        case AF_INET:
 165                return ntohs(((struct sockaddr_in *)sa)->sin_port);
 166        case AF_INET6:
 167                return ntohs(((struct sockaddr_in6 *)sa)->sin6_port);
 168        }
 169        return 0;
 170}
 171
 172/**
 173 * ksmbd_tcp_new_connection() - create a new tcp session on mount
 174 * @client_sk:  socket associated with new connection
 175 *
 176 * whenever a new connection is requested, create a conn thread
 177 * (session thread) to handle new incoming smb requests from the connection
 178 *
 179 * Return:      0 on success, otherwise error
 180 */
 181static int ksmbd_tcp_new_connection(struct socket *client_sk)
 182{
 183        struct sockaddr *csin;
 184        int rc = 0;
 185        struct tcp_transport *t;
 186
 187        t = alloc_transport(client_sk);
 188        if (!t)
 189                return -ENOMEM;
 190
 191        csin = KSMBD_TCP_PEER_SOCKADDR(KSMBD_TRANS(t)->conn);
 192        if (kernel_getpeername(client_sk, csin) < 0) {
 193                pr_err("client ip resolution failed\n");
 194                rc = -EINVAL;
 195                goto out_error;
 196        }
 197
 198        KSMBD_TRANS(t)->handler = kthread_run(ksmbd_conn_handler_loop,
 199                                              KSMBD_TRANS(t)->conn,
 200                                              "ksmbd:%u",
 201                                              ksmbd_tcp_get_port(csin));
 202        if (IS_ERR(KSMBD_TRANS(t)->handler)) {
 203                pr_err("cannot start conn thread\n");
 204                rc = PTR_ERR(KSMBD_TRANS(t)->handler);
 205                free_transport(t);
 206        }
 207        return rc;
 208
 209out_error:
 210        free_transport(t);
 211        return rc;
 212}
 213
 214/**
 215 * ksmbd_kthread_fn() - listen to new SMB connections and callback server
 216 * @p:          arguments to forker thread
 217 *
 218 * Return:      0 on success, error number otherwise
 219 */
 220static int ksmbd_kthread_fn(void *p)
 221{
 222        struct socket *client_sk = NULL;
 223        struct interface *iface = (struct interface *)p;
 224        int ret;
 225
 226        while (!kthread_should_stop()) {
 227                mutex_lock(&iface->sock_release_lock);
 228                if (!iface->ksmbd_socket) {
 229                        mutex_unlock(&iface->sock_release_lock);
 230                        break;
 231                }
 232                ret = kernel_accept(iface->ksmbd_socket, &client_sk,
 233                                    O_NONBLOCK);
 234                mutex_unlock(&iface->sock_release_lock);
 235                if (ret) {
 236                        if (ret == -EAGAIN)
 237                                /* check for new connections every 100 msecs */
 238                                schedule_timeout_interruptible(HZ / 10);
 239                        continue;
 240                }
 241
 242                ksmbd_debug(CONN, "connect success: accepted new connection\n");
 243                client_sk->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT;
 244                client_sk->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT;
 245
 246                ksmbd_tcp_new_connection(client_sk);
 247        }
 248
 249        ksmbd_debug(CONN, "releasing socket\n");
 250        return 0;
 251}
 252
 253/**
 254 * ksmbd_tcp_run_kthread() - start forker thread
 255 * @iface: pointer to struct interface
 256 *
 257 * start forker thread(ksmbd/0) at module init time to listen
 258 * on port 445 for new SMB connection requests. It creates per connection
 259 * server threads(ksmbd/x)
 260 *
 261 * Return:      0 on success or error number
 262 */
 263static int ksmbd_tcp_run_kthread(struct interface *iface)
 264{
 265        int rc;
 266        struct task_struct *kthread;
 267
 268        kthread = kthread_run(ksmbd_kthread_fn, (void *)iface, "ksmbd-%s",
 269                              iface->name);
 270        if (IS_ERR(kthread)) {
 271                rc = PTR_ERR(kthread);
 272                return rc;
 273        }
 274        iface->ksmbd_kthread = kthread;
 275
 276        return 0;
 277}
 278
 279/**
 280 * ksmbd_tcp_readv() - read data from socket in given iovec
 281 * @t:          TCP transport instance
 282 * @iov_orig:   base IO vector
 283 * @nr_segs:    number of segments in base iov
 284 * @to_read:    number of bytes to read from socket
 285 *
 286 * Return:      on success return number of bytes read from socket,
 287 *              otherwise return error number
 288 */
 289static int ksmbd_tcp_readv(struct tcp_transport *t, struct kvec *iov_orig,
 290                           unsigned int nr_segs, unsigned int to_read)
 291{
 292        int length = 0;
 293        int total_read;
 294        unsigned int segs;
 295        struct msghdr ksmbd_msg;
 296        struct kvec *iov;
 297        struct ksmbd_conn *conn = KSMBD_TRANS(t)->conn;
 298
 299        iov = get_conn_iovec(t, nr_segs);
 300        if (!iov)
 301                return -ENOMEM;
 302
 303        ksmbd_msg.msg_control = NULL;
 304        ksmbd_msg.msg_controllen = 0;
 305
 306        for (total_read = 0; to_read; total_read += length, to_read -= length) {
 307                try_to_freeze();
 308
 309                if (!ksmbd_conn_alive(conn)) {
 310                        total_read = -ESHUTDOWN;
 311                        break;
 312                }
 313                segs = kvec_array_init(iov, iov_orig, nr_segs, total_read);
 314
 315                length = kernel_recvmsg(t->sock, &ksmbd_msg,
 316                                        iov, segs, to_read, 0);
 317
 318                if (length == -EINTR) {
 319                        total_read = -ESHUTDOWN;
 320                        break;
 321                } else if (conn->status == KSMBD_SESS_NEED_RECONNECT) {
 322                        total_read = -EAGAIN;
 323                        break;
 324                } else if (length == -ERESTARTSYS || length == -EAGAIN) {
 325                        usleep_range(1000, 2000);
 326                        length = 0;
 327                        continue;
 328                } else if (length <= 0) {
 329                        total_read = -EAGAIN;
 330                        break;
 331                }
 332        }
 333        return total_read;
 334}
 335
 336/**
 337 * ksmbd_tcp_read() - read data from socket in given buffer
 338 * @t:          TCP transport instance
 339 * @buf:        buffer to store read data from socket
 340 * @to_read:    number of bytes to read from socket
 341 *
 342 * Return:      on success return number of bytes read from socket,
 343 *              otherwise return error number
 344 */
 345static int ksmbd_tcp_read(struct ksmbd_transport *t, char *buf, unsigned int to_read)
 346{
 347        struct kvec iov;
 348
 349        iov.iov_base = buf;
 350        iov.iov_len = to_read;
 351
 352        return ksmbd_tcp_readv(TCP_TRANS(t), &iov, 1, to_read);
 353}
 354
 355static int ksmbd_tcp_writev(struct ksmbd_transport *t, struct kvec *iov,
 356                            int nvecs, int size, bool need_invalidate,
 357                            unsigned int remote_key)
 358
 359{
 360        struct msghdr smb_msg = {.msg_flags = MSG_NOSIGNAL};
 361
 362        return kernel_sendmsg(TCP_TRANS(t)->sock, &smb_msg, iov, nvecs, size);
 363}
 364
 365static void ksmbd_tcp_disconnect(struct ksmbd_transport *t)
 366{
 367        free_transport(TCP_TRANS(t));
 368}
 369
 370static void tcp_destroy_socket(struct socket *ksmbd_socket)
 371{
 372        int ret;
 373
 374        if (!ksmbd_socket)
 375                return;
 376
 377        /* set zero to timeout */
 378        ksmbd_tcp_rcv_timeout(ksmbd_socket, 0);
 379        ksmbd_tcp_snd_timeout(ksmbd_socket, 0);
 380
 381        ret = kernel_sock_shutdown(ksmbd_socket, SHUT_RDWR);
 382        if (ret)
 383                pr_err("Failed to shutdown socket: %d\n", ret);
 384        sock_release(ksmbd_socket);
 385}
 386
 387/**
 388 * create_socket - create socket for ksmbd/0
 389 *
 390 * Return:      0 on success, error number otherwise
 391 */
 392static int create_socket(struct interface *iface)
 393{
 394        int ret;
 395        struct sockaddr_in6 sin6;
 396        struct sockaddr_in sin;
 397        struct socket *ksmbd_socket;
 398        bool ipv4 = false;
 399
 400        ret = sock_create(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &ksmbd_socket);
 401        if (ret) {
 402                pr_err("Can't create socket for ipv6, try ipv4: %d\n", ret);
 403                ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP,
 404                                  &ksmbd_socket);
 405                if (ret) {
 406                        pr_err("Can't create socket for ipv4: %d\n", ret);
 407                        goto out_clear;
 408                }
 409
 410                sin.sin_family = PF_INET;
 411                sin.sin_addr.s_addr = htonl(INADDR_ANY);
 412                sin.sin_port = htons(server_conf.tcp_port);
 413                ipv4 = true;
 414        } else {
 415                sin6.sin6_family = PF_INET6;
 416                sin6.sin6_addr = in6addr_any;
 417                sin6.sin6_port = htons(server_conf.tcp_port);
 418        }
 419
 420        ksmbd_tcp_nodelay(ksmbd_socket);
 421        ksmbd_tcp_reuseaddr(ksmbd_socket);
 422
 423        ret = sock_setsockopt(ksmbd_socket,
 424                              SOL_SOCKET,
 425                              SO_BINDTODEVICE,
 426                              KERNEL_SOCKPTR(iface->name),
 427                              strlen(iface->name));
 428        if (ret != -ENODEV && ret < 0) {
 429                pr_err("Failed to set SO_BINDTODEVICE: %d\n", ret);
 430                goto out_error;
 431        }
 432
 433        if (ipv4)
 434                ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin,
 435                                  sizeof(sin));
 436        else
 437                ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin6,
 438                                  sizeof(sin6));
 439        if (ret) {
 440                pr_err("Failed to bind socket: %d\n", ret);
 441                goto out_error;
 442        }
 443
 444        ksmbd_socket->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT;
 445        ksmbd_socket->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT;
 446
 447        ret = kernel_listen(ksmbd_socket, KSMBD_SOCKET_BACKLOG);
 448        if (ret) {
 449                pr_err("Port listen() error: %d\n", ret);
 450                goto out_error;
 451        }
 452
 453        iface->ksmbd_socket = ksmbd_socket;
 454        ret = ksmbd_tcp_run_kthread(iface);
 455        if (ret) {
 456                pr_err("Can't start ksmbd main kthread: %d\n", ret);
 457                goto out_error;
 458        }
 459        iface->state = IFACE_STATE_CONFIGURED;
 460
 461        return 0;
 462
 463out_error:
 464        tcp_destroy_socket(ksmbd_socket);
 465out_clear:
 466        iface->ksmbd_socket = NULL;
 467        return ret;
 468}
 469
 470static int ksmbd_netdev_event(struct notifier_block *nb, unsigned long event,
 471                              void *ptr)
 472{
 473        struct net_device *netdev = netdev_notifier_info_to_dev(ptr);
 474        struct interface *iface;
 475        int ret, found = 0;
 476
 477        switch (event) {
 478        case NETDEV_UP:
 479                if (netif_is_bridge_port(netdev))
 480                        return NOTIFY_OK;
 481
 482                list_for_each_entry(iface, &iface_list, entry) {
 483                        if (!strcmp(iface->name, netdev->name)) {
 484                                found = 1;
 485                                if (iface->state != IFACE_STATE_DOWN)
 486                                        break;
 487                                ret = create_socket(iface);
 488                                if (ret)
 489                                        return NOTIFY_OK;
 490                                break;
 491                        }
 492                }
 493                if (!found && bind_additional_ifaces) {
 494                        iface = alloc_iface(kstrdup(netdev->name, GFP_KERNEL));
 495                        if (!iface)
 496                                return NOTIFY_OK;
 497                        ret = create_socket(iface);
 498                        if (ret)
 499                                break;
 500                }
 501                break;
 502        case NETDEV_DOWN:
 503                list_for_each_entry(iface, &iface_list, entry) {
 504                        if (!strcmp(iface->name, netdev->name) &&
 505                            iface->state == IFACE_STATE_CONFIGURED) {
 506                                tcp_stop_kthread(iface->ksmbd_kthread);
 507                                iface->ksmbd_kthread = NULL;
 508                                mutex_lock(&iface->sock_release_lock);
 509                                tcp_destroy_socket(iface->ksmbd_socket);
 510                                iface->ksmbd_socket = NULL;
 511                                mutex_unlock(&iface->sock_release_lock);
 512
 513                                iface->state = IFACE_STATE_DOWN;
 514                                break;
 515                        }
 516                }
 517                break;
 518        }
 519
 520        return NOTIFY_DONE;
 521}
 522
 523static struct notifier_block ksmbd_netdev_notifier = {
 524        .notifier_call = ksmbd_netdev_event,
 525};
 526
 527int ksmbd_tcp_init(void)
 528{
 529        register_netdevice_notifier(&ksmbd_netdev_notifier);
 530
 531        return 0;
 532}
 533
 534static void tcp_stop_kthread(struct task_struct *kthread)
 535{
 536        int ret;
 537
 538        if (!kthread)
 539                return;
 540
 541        ret = kthread_stop(kthread);
 542        if (ret)
 543                pr_err("failed to stop forker thread\n");
 544}
 545
 546void ksmbd_tcp_destroy(void)
 547{
 548        struct interface *iface, *tmp;
 549
 550        unregister_netdevice_notifier(&ksmbd_netdev_notifier);
 551
 552        list_for_each_entry_safe(iface, tmp, &iface_list, entry) {
 553                list_del(&iface->entry);
 554                kfree(iface->name);
 555                kfree(iface);
 556        }
 557}
 558
 559static struct interface *alloc_iface(char *ifname)
 560{
 561        struct interface *iface;
 562
 563        if (!ifname)
 564                return NULL;
 565
 566        iface = kzalloc(sizeof(struct interface), GFP_KERNEL);
 567        if (!iface) {
 568                kfree(ifname);
 569                return NULL;
 570        }
 571
 572        iface->name = ifname;
 573        iface->state = IFACE_STATE_DOWN;
 574        list_add(&iface->entry, &iface_list);
 575        mutex_init(&iface->sock_release_lock);
 576        return iface;
 577}
 578
 579int ksmbd_tcp_set_interfaces(char *ifc_list, int ifc_list_sz)
 580{
 581        int sz = 0;
 582
 583        if (!ifc_list_sz) {
 584                struct net_device *netdev;
 585
 586                rtnl_lock();
 587                for_each_netdev(&init_net, netdev) {
 588                        if (netif_is_bridge_port(netdev))
 589                                continue;
 590                        if (!alloc_iface(kstrdup(netdev->name, GFP_KERNEL)))
 591                                return -ENOMEM;
 592                }
 593                rtnl_unlock();
 594                bind_additional_ifaces = 1;
 595                return 0;
 596        }
 597
 598        while (ifc_list_sz > 0) {
 599                if (!alloc_iface(kstrdup(ifc_list, GFP_KERNEL)))
 600                        return -ENOMEM;
 601
 602                sz = strlen(ifc_list);
 603                if (!sz)
 604                        break;
 605
 606                ifc_list += sz + 1;
 607                ifc_list_sz -= (sz + 1);
 608        }
 609
 610        bind_additional_ifaces = 0;
 611
 612        return 0;
 613}
 614
 615static struct ksmbd_transport_ops ksmbd_tcp_transport_ops = {
 616        .read           = ksmbd_tcp_read,
 617        .writev         = ksmbd_tcp_writev,
 618        .disconnect     = ksmbd_tcp_disconnect,
 619};
 620