linux/net/sunrpc/auth_gss/auth_gss.c
<<
>>
Prefs
   1/*
   2 * linux/net/sunrpc/auth_gss/auth_gss.c
   3 *
   4 * RPCSEC_GSS client authentication.
   5 *
   6 *  Copyright (c) 2000 The Regents of the University of Michigan.
   7 *  All rights reserved.
   8 *
   9 *  Dug Song       <dugsong@monkey.org>
  10 *  Andy Adamson   <andros@umich.edu>
  11 *
  12 *  Redistribution and use in source and binary forms, with or without
  13 *  modification, are permitted provided that the following conditions
  14 *  are met:
  15 *
  16 *  1. Redistributions of source code must retain the above copyright
  17 *     notice, this list of conditions and the following disclaimer.
  18 *  2. Redistributions in binary form must reproduce the above copyright
  19 *     notice, this list of conditions and the following disclaimer in the
  20 *     documentation and/or other materials provided with the distribution.
  21 *  3. Neither the name of the University nor the names of its
  22 *     contributors may be used to endorse or promote products derived
  23 *     from this software without specific prior written permission.
  24 *
  25 *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
  26 *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  27 *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  28 *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
  29 *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  30 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  31 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
  32 *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  33 *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  34 *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  35 *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  36 */
  37
  38
  39#include <linux/module.h>
  40#include <linux/init.h>
  41#include <linux/types.h>
  42#include <linux/slab.h>
  43#include <linux/sched.h>
  44#include <linux/pagemap.h>
  45#include <linux/sunrpc/clnt.h>
  46#include <linux/sunrpc/auth.h>
  47#include <linux/sunrpc/auth_gss.h>
  48#include <linux/sunrpc/svcauth_gss.h>
  49#include <linux/sunrpc/gss_err.h>
  50#include <linux/workqueue.h>
  51#include <linux/sunrpc/rpc_pipe_fs.h>
  52#include <linux/sunrpc/gss_api.h>
  53#include <asm/uaccess.h>
  54
   /*
    * Forward declarations of the operation tables used by this file.
    * gss_create() installs &authgss_ops as the au_ops of every authenticator.
    */
   55static const struct rpc_authops authgss_ops;
   56
   57static const struct rpc_credops gss_credops;
   58static const struct rpc_credops gss_nullops;
   59
   /* NOTE(review): presumably a retry delay in seconds -- confirm at the point of use. */
   60#define GSS_RETRY_EXPIRED 5
   61static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED;
   62
   63#ifdef RPC_DEBUG
   64# define RPCDBG_FACILITY        RPCDBG_AUTH
   65#endif
   66
   /* gss_create() stores GSS_CRED_SLACK >> 2 in au_cslack. */
   67#define GSS_CRED_SLACK          (RPC_MAX_AUTH_SIZE * 2)
   68/* length of a krb5 verifier (48), plus data added before arguments when
   69 * using integrity (two 4-byte integers): */
   /* gss_create() stores GSS_VERF_SLACK >> 2 in au_rslack. */
   70#define GSS_VERF_SLACK          100
  71
   /*
    * Per-authenticator state: allocated and filled in by gss_create(),
    * torn down by gss_free().
    */
   72struct gss_auth {
   73        struct kref kref;               /* set up with kref_init() in gss_create() */
   74        struct rpc_auth rpc_auth;       /* embedded generic authenticator returned to callers */
   75        struct gss_api_mech *mech;      /* from gss_mech_get_by_pseudoflavor(); dropped with gss_mech_put() */
   76        enum rpc_gss_svc service;       /* from gss_pseudoflavor_to_service(); 0 is rejected by gss_create() */
   77        struct rpc_clnt *client;        /* RPC client this authenticator was created for */
   78        /*
   79         * There are two upcall pipes; pipe[1], named "gssd", is used
   80         * for the new text-based upcall; pipe[0] is named after the
   81         * mechanism (for example, "krb5") and exists for
   82         * backwards-compatibility with older gssd's.
   83         */
   84        struct rpc_pipe *pipe[2];
   85};
  86
   87/* pipe_version >= 0 if and only if someone has a pipe open. */
   /*
    * pipe_version is chosen by the first gss_pipe_open() (0 or 1) and reset
    * to -1 by put_pipe_version() once pipe_users drops to zero.  Updates are
    * made under pipe_version_lock.
    */
   88static int pipe_version = -1;
   /* Counts pipe opens plus upcall messages that hold a version reference. */
   89static atomic_t pipe_users = ATOMIC_INIT(0);
   90static DEFINE_SPINLOCK(pipe_version_lock);
   /* RPC tasks put to sleep by gss_refresh_upcall(); woken by gss_pipe_open(). */
   91static struct rpc_wait_queue pipe_version_rpc_waitqueue;
   /* Processes waiting in gss_create_upcall(); woken by gss_pipe_open(). */
   92static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
   93
   /* Forward declarations needed by the helpers below. */
   94static void gss_free_ctx(struct gss_cl_ctx *);
   95static const struct rpc_pipe_ops gss_upcall_ops_v0;
   96static const struct rpc_pipe_ops gss_upcall_ops_v1;
  97
  98static inline struct gss_cl_ctx *
  99gss_get_ctx(struct gss_cl_ctx *ctx)
 100{
 101        atomic_inc(&ctx->count);
 102        return ctx;
 103}
 104
 105static inline void
 106gss_put_ctx(struct gss_cl_ctx *ctx)
 107{
 108        if (atomic_dec_and_test(&ctx->count))
 109                gss_free_ctx(ctx);
 110}
 111
 112/* gss_cred_set_ctx:
 113 * called by gss_upcall_callback and gss_create_upcall in order
 114 * to set the gss context. The actual exchange of an old context
 115 * and a new one is protected by the pipe->lock.
 116 */
 117static void
 118gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
 119{
 120        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 121
 122        if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
 123                return;
 124        gss_get_ctx(ctx);
 125        rcu_assign_pointer(gss_cred->gc_ctx, ctx);
 126        set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 127        smp_mb__before_clear_bit();
 128        clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
 129}
 130
 131static const void *
 132simple_get_bytes(const void *p, const void *end, void *res, size_t len)
 133{
 134        const void *q = (const void *)((const char *)p + len);
 135        if (unlikely(q > end || q < p))
 136                return ERR_PTR(-EFAULT);
 137        memcpy(res, p, len);
 138        return q;
 139}
 140
 141static inline const void *
 142simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
 143{
 144        const void *q;
 145        unsigned int len;
 146
 147        p = simple_get_bytes(p, end, &len, sizeof(len));
 148        if (IS_ERR(p))
 149                return p;
 150        q = (const void *)((const char *)p + len);
 151        if (unlikely(q > end || q < p))
 152                return ERR_PTR(-EFAULT);
 153        dest->data = kmemdup(p, len, GFP_NOFS);
 154        if (unlikely(dest->data == NULL))
 155                return ERR_PTR(-ENOMEM);
 156        dest->len = len;
 157        return q;
 158}
 159
 160static struct gss_cl_ctx *
 161gss_cred_get_ctx(struct rpc_cred *cred)
 162{
 163        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 164        struct gss_cl_ctx *ctx = NULL;
 165
 166        rcu_read_lock();
 167        if (gss_cred->gc_ctx)
 168                ctx = gss_get_ctx(gss_cred->gc_ctx);
 169        rcu_read_unlock();
 170        return ctx;
 171}
 172
 173static struct gss_cl_ctx *
 174gss_alloc_context(void)
 175{
 176        struct gss_cl_ctx *ctx;
 177
 178        ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
 179        if (ctx != NULL) {
 180                ctx->gc_proc = RPC_GSS_PROC_DATA;
 181                ctx->gc_seq = 1;        /* NetApp 6.4R1 doesn't accept seq. no. 0 */
 182                spin_lock_init(&ctx->gc_seq_lock);
 183                atomic_set(&ctx->count,1);
 184        }
 185        return ctx;
 186}
 187
 188#define GSSD_MIN_TIMEOUT (60 * 60)
 189static const void *
 190gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
 191{
 192        const void *q;
 193        unsigned int seclen;
 194        unsigned int timeout;
 195        u32 window_size;
 196        int ret;
 197
 198        /* First unsigned int gives the lifetime (in seconds) of the cred */
 199        p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
 200        if (IS_ERR(p))
 201                goto err;
 202        if (timeout == 0)
 203                timeout = GSSD_MIN_TIMEOUT;
 204        ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4;
 205        /* Sequence number window. Determines the maximum number of simultaneous requests */
 206        p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
 207        if (IS_ERR(p))
 208                goto err;
 209        ctx->gc_win = window_size;
 210        /* gssd signals an error by passing ctx->gc_win = 0: */
 211        if (ctx->gc_win == 0) {
 212                /*
 213                 * in which case, p points to an error code. Anything other
 214                 * than -EKEYEXPIRED gets converted to -EACCES.
 215                 */
 216                p = simple_get_bytes(p, end, &ret, sizeof(ret));
 217                if (!IS_ERR(p))
 218                        p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
 219                                                    ERR_PTR(-EACCES);
 220                goto err;
 221        }
 222        /* copy the opaque wire context */
 223        p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
 224        if (IS_ERR(p))
 225                goto err;
 226        /* import the opaque security context */
 227        p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
 228        if (IS_ERR(p))
 229                goto err;
 230        q = (const void *)((const char *)p + seclen);
 231        if (unlikely(q > end || q < p)) {
 232                p = ERR_PTR(-EFAULT);
 233                goto err;
 234        }
 235        ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, GFP_NOFS);
 236        if (ret < 0) {
 237                p = ERR_PTR(ret);
 238                goto err;
 239        }
 240        return q;
 241err:
 242        dprintk("RPC:       gss_fill_context returning %ld\n", -PTR_ERR(p));
 243        return p;
 244}
 245
  /* Size of the text buffer that holds a version-1 upcall (see databuf). */
  246#define UPCALL_BUF_LEN 128
  247
  /*
   * One pending upcall to the user-space daemon.  Allocated by
   * gss_alloc_msg() and freed by gss_release_msg() when count hits zero.
   */
  248struct gss_upcall_msg {
  249        atomic_t count;                 /* reference count */
  250        uid_t   uid;                    /* user the upcall is for; lookup key in __gss_find_upcall() */
  251        struct rpc_pipe_msg msg;        /* message handed to rpc_queue_upcall(); msg.errno carries the result */
  252        struct list_head list;          /* link on pipe->in_downcall while a reply is awaited */
  253        struct gss_auth *auth;          /* authenticator that issued the upcall */
  254        struct rpc_pipe *pipe;          /* pipe selected by version in gss_alloc_msg() */
  255        struct rpc_wait_queue rpc_waitqueue;    /* RPC tasks sleeping in gss_refresh_upcall() */
  256        wait_queue_head_t waitqueue;    /* processes waiting in gss_create_upcall() */
  257        struct gss_cl_ctx *ctx;         /* set by gss_pipe_downcall() on success, NULL until then */
  258        char databuf[UPCALL_BUF_LEN];   /* text of a version-1 upcall, built by gss_encode_v1_msg() */
  259};
 260
 261static int get_pipe_version(void)
 262{
 263        int ret;
 264
 265        spin_lock(&pipe_version_lock);
 266        if (pipe_version >= 0) {
 267                atomic_inc(&pipe_users);
 268                ret = pipe_version;
 269        } else
 270                ret = -EAGAIN;
 271        spin_unlock(&pipe_version_lock);
 272        return ret;
 273}
 274
 275static void put_pipe_version(void)
 276{
 277        if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
 278                pipe_version = -1;
 279                spin_unlock(&pipe_version_lock);
 280        }
 281}
 282
 283static void
 284gss_release_msg(struct gss_upcall_msg *gss_msg)
 285{
 286        if (!atomic_dec_and_test(&gss_msg->count))
 287                return;
 288        put_pipe_version();
 289        BUG_ON(!list_empty(&gss_msg->list));
 290        if (gss_msg->ctx != NULL)
 291                gss_put_ctx(gss_msg->ctx);
 292        rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
 293        kfree(gss_msg);
 294}
 295
 296static struct gss_upcall_msg *
 297__gss_find_upcall(struct rpc_pipe *pipe, uid_t uid)
 298{
 299        struct gss_upcall_msg *pos;
 300        list_for_each_entry(pos, &pipe->in_downcall, list) {
 301                if (pos->uid != uid)
 302                        continue;
 303                atomic_inc(&pos->count);
 304                dprintk("RPC:       gss_find_upcall found msg %p\n", pos);
 305                return pos;
 306        }
 307        dprintk("RPC:       gss_find_upcall found nothing\n");
 308        return NULL;
 309}
 310
 311/* Try to add an upcall to the pipefs queue.
 312 * If an upcall owned by our uid already exists, then we return a reference
 313 * to that upcall instead of adding the new upcall.
 314 */
 315static inline struct gss_upcall_msg *
 316gss_add_msg(struct gss_upcall_msg *gss_msg)
 317{
 318        struct rpc_pipe *pipe = gss_msg->pipe;
 319        struct gss_upcall_msg *old;
 320
 321        spin_lock(&pipe->lock);
 322        old = __gss_find_upcall(pipe, gss_msg->uid);
 323        if (old == NULL) {
 324                atomic_inc(&gss_msg->count);
 325                list_add(&gss_msg->list, &pipe->in_downcall);
 326        } else
 327                gss_msg = old;
 328        spin_unlock(&pipe->lock);
 329        return gss_msg;
 330}
 331
 332static void
 333__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 334{
 335        list_del_init(&gss_msg->list);
 336        rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
 337        wake_up_all(&gss_msg->waitqueue);
 338        atomic_dec(&gss_msg->count);
 339}
 340
 341static void
 342gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 343{
 344        struct rpc_pipe *pipe = gss_msg->pipe;
 345
 346        if (list_empty(&gss_msg->list))
 347                return;
 348        spin_lock(&pipe->lock);
 349        if (!list_empty(&gss_msg->list))
 350                __gss_unhash_msg(gss_msg);
 351        spin_unlock(&pipe->lock);
 352}
 353
 354static void
 355gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg)
 356{
 357        switch (gss_msg->msg.errno) {
 358        case 0:
 359                if (gss_msg->ctx == NULL)
 360                        break;
 361                clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
 362                gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx);
 363                break;
 364        case -EKEYEXPIRED:
 365                set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
 366        }
 367        gss_cred->gc_upcall_timestamp = jiffies;
 368        gss_cred->gc_upcall = NULL;
 369        rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
 370}
 371
 372static void
 373gss_upcall_callback(struct rpc_task *task)
 374{
 375        struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred,
 376                        struct gss_cred, gc_base);
 377        struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
 378        struct rpc_pipe *pipe = gss_msg->pipe;
 379
 380        spin_lock(&pipe->lock);
 381        gss_handle_downcall_result(gss_cred, gss_msg);
 382        spin_unlock(&pipe->lock);
 383        task->tk_status = gss_msg->msg.errno;
 384        gss_release_msg(gss_msg);
 385}
 386
 387static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
 388{
 389        gss_msg->msg.data = &gss_msg->uid;
 390        gss_msg->msg.len = sizeof(gss_msg->uid);
 391}
 392
 393static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
 394                                struct rpc_clnt *clnt,
 395                                const char *service_name)
 396{
 397        struct gss_api_mech *mech = gss_msg->auth->mech;
 398        char *p = gss_msg->databuf;
 399        int len = 0;
 400
 401        gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
 402                                   mech->gm_name,
 403                                   gss_msg->uid);
 404        p += gss_msg->msg.len;
 405        if (clnt->cl_principal) {
 406                len = sprintf(p, "target=%s ", clnt->cl_principal);
 407                p += len;
 408                gss_msg->msg.len += len;
 409        }
 410        if (service_name != NULL) {
 411                len = sprintf(p, "service=%s ", service_name);
 412                p += len;
 413                gss_msg->msg.len += len;
 414        }
 415        if (mech->gm_upcall_enctypes) {
 416                len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes);
 417                p += len;
 418                gss_msg->msg.len += len;
 419        }
 420        len = sprintf(p, "\n");
 421        gss_msg->msg.len += len;
 422
 423        gss_msg->msg.data = gss_msg->databuf;
 424        BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
 425}
 426
 427static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
 428                                struct rpc_clnt *clnt,
 429                                const char *service_name)
 430{
 431        if (pipe_version == 0)
 432                gss_encode_v0_msg(gss_msg);
 433        else /* pipe_version == 1 */
 434                gss_encode_v1_msg(gss_msg, clnt, service_name);
 435}
 436
 437static struct gss_upcall_msg *
 438gss_alloc_msg(struct gss_auth *gss_auth, struct rpc_clnt *clnt,
 439                uid_t uid, const char *service_name)
 440{
 441        struct gss_upcall_msg *gss_msg;
 442        int vers;
 443
 444        gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
 445        if (gss_msg == NULL)
 446                return ERR_PTR(-ENOMEM);
 447        vers = get_pipe_version();
 448        if (vers < 0) {
 449                kfree(gss_msg);
 450                return ERR_PTR(vers);
 451        }
 452        gss_msg->pipe = gss_auth->pipe[vers];
 453        INIT_LIST_HEAD(&gss_msg->list);
 454        rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
 455        init_waitqueue_head(&gss_msg->waitqueue);
 456        atomic_set(&gss_msg->count, 1);
 457        gss_msg->uid = uid;
 458        gss_msg->auth = gss_auth;
 459        gss_encode_msg(gss_msg, clnt, service_name);
 460        return gss_msg;
 461}
 462
 463static struct gss_upcall_msg *
 464gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
 465{
 466        struct gss_cred *gss_cred = container_of(cred,
 467                        struct gss_cred, gc_base);
 468        struct gss_upcall_msg *gss_new, *gss_msg;
 469        uid_t uid = cred->cr_uid;
 470
 471        gss_new = gss_alloc_msg(gss_auth, clnt, uid, gss_cred->gc_principal);
 472        if (IS_ERR(gss_new))
 473                return gss_new;
 474        gss_msg = gss_add_msg(gss_new);
 475        if (gss_msg == gss_new) {
 476                int res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg);
 477                if (res) {
 478                        gss_unhash_msg(gss_new);
 479                        gss_msg = ERR_PTR(res);
 480                }
 481        } else
 482                gss_release_msg(gss_new);
 483        return gss_msg;
 484}
 485
 486static void warn_gssd(void)
 487{
 488        static unsigned long ratelimit;
 489        unsigned long now = jiffies;
 490
 491        if (time_after(now, ratelimit)) {
 492                printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
 493                                "Please check user daemon is running.\n");
 494                ratelimit = now + 15*HZ;
 495        }
 496}
 497
 498static inline int
 499gss_refresh_upcall(struct rpc_task *task)
 500{
 501        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 502        struct gss_auth *gss_auth = container_of(cred->cr_auth,
 503                        struct gss_auth, rpc_auth);
 504        struct gss_cred *gss_cred = container_of(cred,
 505                        struct gss_cred, gc_base);
 506        struct gss_upcall_msg *gss_msg;
 507        struct rpc_pipe *pipe;
 508        int err = 0;
 509
 510        dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
 511                                                                cred->cr_uid);
 512        gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
 513        if (PTR_ERR(gss_msg) == -EAGAIN) {
 514                /* XXX: warning on the first, under the assumption we
 515                 * shouldn't normally hit this case on a refresh. */
 516                warn_gssd();
 517                task->tk_timeout = 15*HZ;
 518                rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
 519                return -EAGAIN;
 520        }
 521        if (IS_ERR(gss_msg)) {
 522                err = PTR_ERR(gss_msg);
 523                goto out;
 524        }
 525        pipe = gss_msg->pipe;
 526        spin_lock(&pipe->lock);
 527        if (gss_cred->gc_upcall != NULL)
 528                rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
 529        else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
 530                task->tk_timeout = 0;
 531                gss_cred->gc_upcall = gss_msg;
 532                /* gss_upcall_callback will release the reference to gss_upcall_msg */
 533                atomic_inc(&gss_msg->count);
 534                rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
 535        } else {
 536                gss_handle_downcall_result(gss_cred, gss_msg);
 537                err = gss_msg->msg.errno;
 538        }
 539        spin_unlock(&pipe->lock);
 540        gss_release_msg(gss_msg);
 541out:
 542        dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n",
 543                        task->tk_pid, cred->cr_uid, err);
 544        return err;
 545}
 546
 547static inline int
 548gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
 549{
 550        struct rpc_pipe *pipe;
 551        struct rpc_cred *cred = &gss_cred->gc_base;
 552        struct gss_upcall_msg *gss_msg;
 553        DEFINE_WAIT(wait);
 554        int err = 0;
 555
 556        dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
 557retry:
 558        gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
 559        if (PTR_ERR(gss_msg) == -EAGAIN) {
 560                err = wait_event_interruptible_timeout(pipe_version_waitqueue,
 561                                pipe_version >= 0, 15*HZ);
 562                if (pipe_version < 0) {
 563                        warn_gssd();
 564                        err = -EACCES;
 565                }
 566                if (err)
 567                        goto out;
 568                goto retry;
 569        }
 570        if (IS_ERR(gss_msg)) {
 571                err = PTR_ERR(gss_msg);
 572                goto out;
 573        }
 574        pipe = gss_msg->pipe;
 575        for (;;) {
 576                prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE);
 577                spin_lock(&pipe->lock);
 578                if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
 579                        break;
 580                }
 581                spin_unlock(&pipe->lock);
 582                if (fatal_signal_pending(current)) {
 583                        err = -ERESTARTSYS;
 584                        goto out_intr;
 585                }
 586                schedule();
 587        }
 588        if (gss_msg->ctx)
 589                gss_cred_set_ctx(cred, gss_msg->ctx);
 590        else
 591                err = gss_msg->msg.errno;
 592        spin_unlock(&pipe->lock);
 593out_intr:
 594        finish_wait(&gss_msg->waitqueue, &wait);
 595        gss_release_msg(gss_msg);
 596out:
 597        dprintk("RPC:       gss_create_upcall for uid %u result %d\n",
 598                        cred->cr_uid, err);
 599        return err;
 600}
 601
 602#define MSG_BUF_MAXSIZE 1024
 603
 604static ssize_t
 605gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
 606{
 607        const void *p, *end;
 608        void *buf;
 609        struct gss_upcall_msg *gss_msg;
 610        struct rpc_pipe *pipe = RPC_I(filp->f_dentry->d_inode)->pipe;
 611        struct gss_cl_ctx *ctx;
 612        uid_t uid;
 613        ssize_t err = -EFBIG;
 614
 615        if (mlen > MSG_BUF_MAXSIZE)
 616                goto out;
 617        err = -ENOMEM;
 618        buf = kmalloc(mlen, GFP_NOFS);
 619        if (!buf)
 620                goto out;
 621
 622        err = -EFAULT;
 623        if (copy_from_user(buf, src, mlen))
 624                goto err;
 625
 626        end = (const void *)((char *)buf + mlen);
 627        p = simple_get_bytes(buf, end, &uid, sizeof(uid));
 628        if (IS_ERR(p)) {
 629                err = PTR_ERR(p);
 630                goto err;
 631        }
 632
 633        err = -ENOMEM;
 634        ctx = gss_alloc_context();
 635        if (ctx == NULL)
 636                goto err;
 637
 638        err = -ENOENT;
 639        /* Find a matching upcall */
 640        spin_lock(&pipe->lock);
 641        gss_msg = __gss_find_upcall(pipe, uid);
 642        if (gss_msg == NULL) {
 643                spin_unlock(&pipe->lock);
 644                goto err_put_ctx;
 645        }
 646        list_del_init(&gss_msg->list);
 647        spin_unlock(&pipe->lock);
 648
 649        p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
 650        if (IS_ERR(p)) {
 651                err = PTR_ERR(p);
 652                switch (err) {
 653                case -EACCES:
 654                case -EKEYEXPIRED:
 655                        gss_msg->msg.errno = err;
 656                        err = mlen;
 657                        break;
 658                case -EFAULT:
 659                case -ENOMEM:
 660                case -EINVAL:
 661                case -ENOSYS:
 662                        gss_msg->msg.errno = -EAGAIN;
 663                        break;
 664                default:
 665                        printk(KERN_CRIT "%s: bad return from "
 666                                "gss_fill_context: %zd\n", __func__, err);
 667                        BUG();
 668                }
 669                goto err_release_msg;
 670        }
 671        gss_msg->ctx = gss_get_ctx(ctx);
 672        err = mlen;
 673
 674err_release_msg:
 675        spin_lock(&pipe->lock);
 676        __gss_unhash_msg(gss_msg);
 677        spin_unlock(&pipe->lock);
 678        gss_release_msg(gss_msg);
 679err_put_ctx:
 680        gss_put_ctx(ctx);
 681err:
 682        kfree(buf);
 683out:
 684        dprintk("RPC:       gss_pipe_downcall returning %Zd\n", err);
 685        return err;
 686}
 687
 688static int gss_pipe_open(struct inode *inode, int new_version)
 689{
 690        int ret = 0;
 691
 692        spin_lock(&pipe_version_lock);
 693        if (pipe_version < 0) {
 694                /* First open of any gss pipe determines the version: */
 695                pipe_version = new_version;
 696                rpc_wake_up(&pipe_version_rpc_waitqueue);
 697                wake_up(&pipe_version_waitqueue);
 698        } else if (pipe_version != new_version) {
 699                /* Trying to open a pipe of a different version */
 700                ret = -EBUSY;
 701                goto out;
 702        }
 703        atomic_inc(&pipe_users);
 704out:
 705        spin_unlock(&pipe_version_lock);
 706        return ret;
 707
 708}
 709
 710static int gss_pipe_open_v0(struct inode *inode)
 711{
 712        return gss_pipe_open(inode, 0);
 713}
 714
 715static int gss_pipe_open_v1(struct inode *inode)
 716{
 717        return gss_pipe_open(inode, 1);
 718}
 719
 720static void
 721gss_pipe_release(struct inode *inode)
 722{
 723        struct rpc_pipe *pipe = RPC_I(inode)->pipe;
 724        struct gss_upcall_msg *gss_msg;
 725
 726restart:
 727        spin_lock(&pipe->lock);
 728        list_for_each_entry(gss_msg, &pipe->in_downcall, list) {
 729
 730                if (!list_empty(&gss_msg->msg.list))
 731                        continue;
 732                gss_msg->msg.errno = -EPIPE;
 733                atomic_inc(&gss_msg->count);
 734                __gss_unhash_msg(gss_msg);
 735                spin_unlock(&pipe->lock);
 736                gss_release_msg(gss_msg);
 737                goto restart;
 738        }
 739        spin_unlock(&pipe->lock);
 740
 741        put_pipe_version();
 742}
 743
 744static void
 745gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
 746{
 747        struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
 748
 749        if (msg->errno < 0) {
 750                dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
 751                                gss_msg);
 752                atomic_inc(&gss_msg->count);
 753                gss_unhash_msg(gss_msg);
 754                if (msg->errno == -ETIMEDOUT)
 755                        warn_gssd();
 756                gss_release_msg(gss_msg);
 757        }
 758}
 759
 760static void gss_pipes_dentries_destroy(struct rpc_auth *auth)
 761{
 762        struct gss_auth *gss_auth;
 763
 764        gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 765        if (gss_auth->pipe[0]->dentry)
 766                rpc_unlink(gss_auth->pipe[0]->dentry);
 767        if (gss_auth->pipe[1]->dentry)
 768                rpc_unlink(gss_auth->pipe[1]->dentry);
 769}
 770
 771static int gss_pipes_dentries_create(struct rpc_auth *auth)
 772{
 773        int err;
 774        struct gss_auth *gss_auth;
 775        struct rpc_clnt *clnt;
 776
 777        gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 778        clnt = gss_auth->client;
 779
 780        gss_auth->pipe[1]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
 781                                                      "gssd",
 782                                                      clnt, gss_auth->pipe[1]);
 783        if (IS_ERR(gss_auth->pipe[1]->dentry))
 784                return PTR_ERR(gss_auth->pipe[1]->dentry);
 785        gss_auth->pipe[0]->dentry = rpc_mkpipe_dentry(clnt->cl_dentry,
 786                                                      gss_auth->mech->gm_name,
 787                                                      clnt, gss_auth->pipe[0]);
 788        if (IS_ERR(gss_auth->pipe[0]->dentry)) {
 789                err = PTR_ERR(gss_auth->pipe[0]->dentry);
 790                goto err_unlink_pipe_1;
 791        }
 792        return 0;
 793
 794err_unlink_pipe_1:
 795        rpc_unlink(gss_auth->pipe[1]->dentry);
 796        return err;
 797}
 798
 799static void gss_pipes_dentries_destroy_net(struct rpc_clnt *clnt,
 800                                           struct rpc_auth *auth)
 801{
 802        struct net *net = rpc_net_ns(clnt);
 803        struct super_block *sb;
 804
 805        sb = rpc_get_sb_net(net);
 806        if (sb) {
 807                if (clnt->cl_dentry)
 808                        gss_pipes_dentries_destroy(auth);
 809                rpc_put_sb_net(net);
 810        }
 811}
 812
 813static int gss_pipes_dentries_create_net(struct rpc_clnt *clnt,
 814                                         struct rpc_auth *auth)
 815{
 816        struct net *net = rpc_net_ns(clnt);
 817        struct super_block *sb;
 818        int err = 0;
 819
 820        sb = rpc_get_sb_net(net);
 821        if (sb) {
 822                if (clnt->cl_dentry)
 823                        err = gss_pipes_dentries_create(auth);
 824                rpc_put_sb_net(net);
 825        }
 826        return err;
 827}
 828
/*
 * Create an RPCSEC_GSS authenticator for @clnt using pseudoflavor @flavor.
 *
 * NOTE: we have the opportunity to use different
 * parameters based on the input flavor (which must be a pseudoflavor)
 *
 * The function pins this module, allocates the gss_auth, resolves the
 * mechanism and service for @flavor, creates both upcall pipes and their
 * dentries and finally initialises the credential cache.  Every failure
 * jumps into the unwind ladder at the bottom, which releases what was
 * acquired so far in reverse order.
 *
 * Returns the embedded rpc_auth on success or an ERR_PTR(): -ENOMEM if
 * the module reference or the allocation fails, -EINVAL if the mechanism
 * or the service cannot be resolved, otherwise the error reported by the
 * helper that failed.
 */
static struct rpc_auth *
gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
{
        struct gss_auth *gss_auth;
        struct rpc_auth * auth;
        int err = -ENOMEM; /* XXX? */

        dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);

        /* Pin the module for the lifetime of the authenticator */
        if (!try_module_get(THIS_MODULE))
                return ERR_PTR(err);
        if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
                goto out_dec;
        gss_auth->client = clnt;
        /* From here on, lookup failures are reported as -EINVAL */
        err = -EINVAL;
        gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
        if (!gss_auth->mech) {
                printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
                                __func__, flavor);
                goto err_free;
        }
        gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
        if (gss_auth->service == 0)
                goto err_put_mech;
        auth = &gss_auth->rpc_auth;
        /* Slack sizes are stored in 32-bit words, hence the shift by 2 */
        auth->au_cslack = GSS_CRED_SLACK >> 2;
        auth->au_rslack = GSS_VERF_SLACK >> 2;
        auth->au_ops = &authgss_ops;
        auth->au_flavor = flavor;
        atomic_set(&auth->au_count, 1);
        kref_init(&gss_auth->kref);

        /*
         * Note: if we created the old pipe first, then someone who
         * examined the directory at the right moment might conclude
         * that we supported only the old pipe.  So we instead create
         * the new pipe first.
         */
        gss_auth->pipe[1] = rpc_mkpipe_data(&gss_upcall_ops_v1,
                                            RPC_PIPE_WAIT_FOR_OPEN);
        if (IS_ERR(gss_auth->pipe[1])) {
                err = PTR_ERR(gss_auth->pipe[1]);
                goto err_put_mech;
        }

        gss_auth->pipe[0] = rpc_mkpipe_data(&gss_upcall_ops_v0,
                                            RPC_PIPE_WAIT_FOR_OPEN);
        if (IS_ERR(gss_auth->pipe[0])) {
                err = PTR_ERR(gss_auth->pipe[0]);
                goto err_destroy_pipe_1;
        }
        err = gss_pipes_dentries_create_net(clnt, auth);
        if (err)
                goto err_destroy_pipe_0;
        err = rpcauth_init_credcache(auth);
        if (err)
                goto err_unlink_pipes;

        return auth;
        /* Unwind ladder: each label undoes one acquisition, newest first */
err_unlink_pipes:
        gss_pipes_dentries_destroy_net(clnt, auth);
err_destroy_pipe_0:
        rpc_destroy_pipe_data(gss_auth->pipe[0]);
err_destroy_pipe_1:
        rpc_destroy_pipe_data(gss_auth->pipe[1]);
err_put_mech:
        gss_mech_put(gss_auth->mech);
err_free:
        kfree(gss_auth);
out_dec:
        module_put(THIS_MODULE);
        return ERR_PTR(err);
}
 906
 907static void
 908gss_free(struct gss_auth *gss_auth)
 909{
 910        gss_pipes_dentries_destroy_net(gss_auth->client, &gss_auth->rpc_auth);
 911        rpc_destroy_pipe_data(gss_auth->pipe[0]);
 912        rpc_destroy_pipe_data(gss_auth->pipe[1]);
 913        gss_mech_put(gss_auth->mech);
 914
 915        kfree(gss_auth);
 916        module_put(THIS_MODULE);
 917}
 918
 919static void
 920gss_free_callback(struct kref *kref)
 921{
 922        struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
 923
 924        gss_free(gss_auth);
 925}
 926
 927static void
 928gss_destroy(struct rpc_auth *auth)
 929{
 930        struct gss_auth *gss_auth;
 931
 932        dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
 933                        auth, auth->au_flavor);
 934
 935        rpcauth_destroy_credcache(auth);
 936
 937        gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 938        kref_put(&gss_auth->kref, gss_free_callback);
 939}
 940
 941/*
 942 * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
 943 * to the server with the GSS control procedure field set to
 944 * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
 945 * all RPCSEC_GSS state associated with that context.
 946 */
 947static int
 948gss_destroying_context(struct rpc_cred *cred)
 949{
 950        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 951        struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
 952        struct rpc_task *task;
 953
 954        if (gss_cred->gc_ctx == NULL ||
 955            test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
 956                return 0;
 957
 958        gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
 959        cred->cr_ops = &gss_nullops;
 960
 961        /* Take a reference to ensure the cred will be destroyed either
 962         * by the RPC call or by the put_rpccred() below */
 963        get_rpccred(cred);
 964
 965        task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
 966        if (!IS_ERR(task))
 967                rpc_put_task(task);
 968
 969        put_rpccred(cred);
 970        return 1;
 971}
 972
/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
 * to create a new cred or context, so they check that things have been
 * allocated before freeing them. */
/*
 * Immediately release a client context: delete the GSS security context,
 * free the wire context buffer and then the gss_cl_ctx itself.
 */
static void
gss_do_free_ctx(struct gss_cl_ctx *ctx)
{
        dprintk("RPC:       gss_free_ctx\n");

        gss_delete_sec_context(&ctx->gc_gss_ctx);
        kfree(ctx->gc_wire_ctx.data);
        kfree(ctx);
}
 985
 986static void
 987gss_free_ctx_callback(struct rcu_head *head)
 988{
 989        struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
 990        gss_do_free_ctx(ctx);
 991}
 992
/*
 * Defer freeing of @ctx: queue gss_free_ctx_callback() through call_rcu()
 * using the context's embedded gc_rcu head.
 */
static void
gss_free_ctx(struct gss_cl_ctx *ctx)
{
        call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
}
 998
/* Free the memory of a gss_cred; only the structure itself is released. */
static void
gss_free_cred(struct gss_cred *gss_cred)
{
        dprintk("RPC:       gss_free_cred %p\n", gss_cred);
        kfree(gss_cred);
}
1005
1006static void
1007gss_free_cred_callback(struct rcu_head *head)
1008{
1009        struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
1010        gss_free_cred(gss_cred);
1011}
1012
/*
 * Final teardown of a GSS credential.
 *
 * Detaches the context pointer from the cred, queues the cred for freeing
 * via call_rcu(), hands the detached context (if any) to gss_put_ctx()
 * and drops the gss_auth kref that gss_create_cred() took.
 */
static void
gss_destroy_nullcred(struct rpc_cred *cred)
{
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
        struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
        /* Remember the context before the pointer in the cred is cleared */
        struct gss_cl_ctx *ctx = gss_cred->gc_ctx;

        RCU_INIT_POINTER(gss_cred->gc_ctx, NULL);
        call_rcu(&cred->cr_rcu, gss_free_cred_callback);
        if (ctx)
                gss_put_ctx(ctx);
        kref_put(&gss_auth->kref, gss_free_callback);
}
1026
/*
 * Destroy a credential.  If gss_destroying_context() reports that it
 * attempted a context-destroy call, nothing more is done here; otherwise
 * the cred is torn down immediately.
 */
static void
gss_destroy_cred(struct rpc_cred *cred)
{
        if (!gss_destroying_context(cred))
                gss_destroy_nullcred(cred);
}
1035
/*
 * Lookup RPCSEC_GSS cred for the current process.
 * Thin wrapper: the work is done by the generic credential cache.
 */
static struct rpc_cred *
gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
{
        struct rpc_cred *found;

        found = rpcauth_lookup_credcache(auth, acred, flags);
        return found;
}
1044
1045static struct rpc_cred *
1046gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
1047{
1048        struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1049        struct gss_cred *cred = NULL;
1050        int err = -ENOMEM;
1051
1052        dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
1053                acred->uid, auth->au_flavor);
1054
1055        if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
1056                goto out_err;
1057
1058        rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
1059        /*
1060         * Note: in order to force a call to call_refresh(), we deliberately
1061         * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
1062         */
1063        cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
1064        cred->gc_service = gss_auth->service;
1065        cred->gc_principal = NULL;
1066        if (acred->machine_cred)
1067                cred->gc_principal = acred->principal;
1068        kref_get(&gss_auth->kref);
1069        return &cred->gc_base;
1070
1071out_err:
1072        dprintk("RPC:       gss_create_cred failed with error %d\n", err);
1073        return ERR_PTR(err);
1074}
1075
1076static int
1077gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
1078{
1079        struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1080        struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
1081        int err;
1082
1083        do {
1084                err = gss_create_upcall(gss_auth, gss_cred);
1085        } while (err == -EAGAIN);
1086        return err;
1087}
1088
1089static int
1090gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
1091{
1092        struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
1093
1094        if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
1095                goto out;
1096        /* Don't match with creds that have expired. */
1097        if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
1098                return 0;
1099        if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
1100                return 0;
1101out:
1102        if (acred->principal != NULL) {
1103                if (gss_cred->gc_principal == NULL)
1104                        return 0;
1105                return strcmp(acred->principal, gss_cred->gc_principal) == 0;
1106        }
1107        if (gss_cred->gc_principal != NULL)
1108                return 0;
1109        return rc->cr_uid == acred->uid;
1110}
1111
/*
* Marshal credentials.
* Maybe we should keep a cached credential for performance reasons.
*
* Writes the RPCSEC_GSS credential and verifier at @p.  The credential
* body holds the version, the context's gc_proc, a freshly allocated
* sequence number, the service and the wire context handle; the verifier
* is a MIC over the request from just after the transport header up to
* the end of the credential.
*
* Returns the position following the verifier, or NULL if gss_get_mic()
* fails with anything other than GSS_S_CONTEXT_EXPIRED.
*/
static __be32 *
gss_marshal(struct rpc_task *task, __be32 *p)
{
        struct rpc_rqst *req = task->tk_rqstp;
        struct rpc_cred *cred = req->rq_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                                                 gc_base);
        struct gss_cl_ctx       *ctx = gss_cred_get_ctx(cred);
        __be32          *cred_len;
        u32             maj_stat = 0;
        struct xdr_netobj mic;
        struct kvec     iov;
        struct xdr_buf  verf_buf;

        dprintk("RPC: %5u gss_marshal\n", task->tk_pid);

        *p++ = htonl(RPC_AUTH_GSS);
        /* Reserve the length word; it is filled in once the body is known */
        cred_len = p++;

        /* Sequence numbers are handed out under the context's lock */
        spin_lock(&ctx->gc_seq_lock);
        req->rq_seqno = ctx->gc_seq++;
        spin_unlock(&ctx->gc_seq_lock);

        *p++ = htonl((u32) RPC_GSS_VERSION);
        *p++ = htonl((u32) ctx->gc_proc);
        *p++ = htonl((u32) req->rq_seqno);
        *p++ = htonl((u32) gss_cred->gc_service);
        p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
        /* Credential length in bytes: words after the length word, times 4 */
        *cred_len = htonl((p - (cred_len + 1)) << 2);

        /* We compute the checksum for the verifier over the xdr-encoded bytes
         * starting with the xid and ending at the end of the credential: */
        iov.iov_base = xprt_skip_transport_header(task->tk_xprt,
                                        req->rq_snd_buf.head[0].iov_base);
        iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
        xdr_buf_from_iov(&iov, &verf_buf);

        /* set verifier flavor*/
        *p++ = htonl(RPC_AUTH_GSS);

        /* The MIC is written one word past p, leaving room for its length */
        mic.data = (u8 *)(p + 1);
        maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
        if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
                /* Expired context: still send, but mark the cred stale */
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        } else if (maj_stat != 0) {
                printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
                goto out_put_ctx;
        }
        /* NULL source: called for the length word and advancing p only */
        p = xdr_encode_opaque(p, NULL, mic.len);
        gss_put_ctx(ctx);
        return p;
out_put_ctx:
        gss_put_ctx(ctx);
        return NULL;
}
1171
1172static int gss_renew_cred(struct rpc_task *task)
1173{
1174        struct rpc_cred *oldcred = task->tk_rqstp->rq_cred;
1175        struct gss_cred *gss_cred = container_of(oldcred,
1176                                                 struct gss_cred,
1177                                                 gc_base);
1178        struct rpc_auth *auth = oldcred->cr_auth;
1179        struct auth_cred acred = {
1180                .uid = oldcred->cr_uid,
1181                .principal = gss_cred->gc_principal,
1182                .machine_cred = (gss_cred->gc_principal != NULL ? 1 : 0),
1183        };
1184        struct rpc_cred *new;
1185
1186        new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1187        if (IS_ERR(new))
1188                return PTR_ERR(new);
1189        task->tk_rqstp->rq_cred = new;
1190        put_rpccred(oldcred);
1191        return 0;
1192}
1193
1194static int gss_cred_is_negative_entry(struct rpc_cred *cred)
1195{
1196        if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) {
1197                unsigned long now = jiffies;
1198                unsigned long begin, expire;
1199                struct gss_cred *gss_cred; 
1200
1201                gss_cred = container_of(cred, struct gss_cred, gc_base);
1202                begin = gss_cred->gc_upcall_timestamp;
1203                expire = begin + gss_expired_cred_retry_delay * HZ;
1204
1205                if (time_in_range_open(now, begin, expire))
1206                        return 1;
1207        }
1208        return 0;
1209}
1210
1211/*
1212* Refresh credentials. XXX - finish
1213*/
1214static int
1215gss_refresh(struct rpc_task *task)
1216{
1217        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1218        int ret = 0;
1219
1220        if (gss_cred_is_negative_entry(cred))
1221                return -EKEYEXPIRED;
1222
1223        if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1224                        !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1225                ret = gss_renew_cred(task);
1226                if (ret < 0)
1227                        goto out;
1228                cred = task->tk_rqstp->rq_cred;
1229        }
1230
1231        if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1232                ret = gss_refresh_upcall(task);
1233out:
1234        return ret;
1235}
1236
/*
 * Dummy refresh routine: used only when destroying the context.
 * Installed as .crrefresh in gss_nullops; always fails with -EACCES.
 */
static int
gss_refresh_null(struct rpc_task *task)
{
        return -EACCES;
}
1243
/*
 * Validate the verifier of a reply.
 *
 * @p points at the verifier (flavor, length, body).  The body must be a
 * MIC over the request's sequence number in network byte order.
 *
 * Returns the position following the verifier on success, or NULL if the
 * flavor is not RPC_AUTH_GSS, the length exceeds RPC_MAX_AUTH_SIZE, or
 * gss_verify_mic() reports any error.
 */
static __be32 *
gss_validate(struct rpc_task *task, __be32 *p)
{
        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
        __be32          seq;
        struct kvec     iov;
        struct xdr_buf  verf_buf;
        struct xdr_netobj mic;
        u32             flav,len;
        u32             maj_stat;

        dprintk("RPC: %5u gss_validate\n", task->tk_pid);

        flav = ntohl(*p++);
        if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
                goto out_bad;
        if (flav != RPC_AUTH_GSS)
                goto out_bad;
        /* The data covered by the MIC is the sequence number we sent */
        seq = htonl(task->tk_rqstp->rq_seqno);
        iov.iov_base = &seq;
        iov.iov_len = sizeof(seq);
        xdr_buf_from_iov(&iov, &verf_buf);
        mic.data = (u8 *)p;
        mic.len = len;

        maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
        /* An expired context marks the cred stale and also fails below */
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        if (maj_stat) {
                dprintk("RPC: %5u gss_validate: gss_verify_mic returned "
                                "error 0x%08x\n", task->tk_pid, maj_stat);
                goto out_bad;
        }
        /* We leave it to unwrap to calculate au_rslack. For now we just
         * calculate the length of the verifier: */
        cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
        gss_put_ctx(ctx);
        dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n",
                        task->tk_pid);
        return p + XDR_QUADLEN(len);
out_bad:
        gss_put_ctx(ctx);
        dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid);
        return NULL;
}
1290
/*
 * Encode the call arguments @obj into the request's send buffer starting
 * at @p, using the caller-supplied @encode routine.
 */
static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
                                __be32 *p, void *obj)
{
        struct xdr_stream xdr;

        xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p);
        encode(rqstp, &xdr, obj);
}
1299
/*
 * Wrap a request for the integrity service.
 *
 * Layout written at @p: a length word, the sequence number, the encoded
 * arguments, then a MIC computed over everything from the sequence
 * number to the end of the arguments.
 *
 * Returns 0 on success (also when the context has expired, in which case
 * the cred is marked not up to date) and -EIO on any other failure.
 */
static inline int
gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
                   kxdreproc_t encode, struct rpc_rqst *rqstp,
                   __be32 *p, void *obj)
{
        struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
        struct xdr_buf  integ_buf;
        __be32          *integ_len = NULL;
        struct xdr_netobj mic;
        u32             offset;
        __be32          *q;
        struct kvec     *iov;
        u32             maj_stat = 0;
        int             status = -EIO;

        /* Reserve the length word; filled in after encoding */
        integ_len = p++;
        /* Byte offset, within the head, where the integrity data starts */
        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        *p++ = htonl(rqstp->rq_seqno);

        gss_wrap_req_encode(encode, rqstp, p, obj);

        /* integ_buf := everything from the sequence number to the end */
        if (xdr_buf_subsegment(snd_buf, &integ_buf,
                                offset, snd_buf->len - offset))
                return status;
        *integ_len = htonl(integ_buf.len);

        /* guess whether we're in the head or the tail: */
        if (snd_buf->page_len || snd_buf->tail[0].iov_len)
                iov = snd_buf->tail;
        else
                iov = snd_buf->head;
        /* The MIC goes at the end of that kvec, after its own length word */
        p = iov->iov_base + iov->iov_len;
        mic.data = (u8 *)(p + 1);

        maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
        status = -EIO; /* XXX? */
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        else if (maj_stat)
                return status;
        /* NULL source: called for the length word and the end position only */
        q = xdr_encode_opaque(p, NULL, mic.len);

        /* Account for the appended MIC in both the kvec and the buffer */
        offset = (u8 *)q - (u8 *)p;
        iov->iov_len += offset;
        snd_buf->len += offset;
        return 0;
}
1347
1348static void
1349priv_release_snd_buf(struct rpc_rqst *rqstp)
1350{
1351        int i;
1352
1353        for (i=0; i < rqstp->rq_enc_pages_num; i++)
1354                __free_page(rqstp->rq_enc_pages[i]);
1355        kfree(rqstp->rq_enc_pages);
1356}
1357
1358static int
1359alloc_enc_pages(struct rpc_rqst *rqstp)
1360{
1361        struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1362        int first, last, i;
1363
1364        if (snd_buf->page_len == 0) {
1365                rqstp->rq_enc_pages_num = 0;
1366                return 0;
1367        }
1368
1369        first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1370        last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1371        rqstp->rq_enc_pages_num = last - first + 1 + 1;
1372        rqstp->rq_enc_pages
1373                = kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1374                                GFP_NOFS);
1375        if (!rqstp->rq_enc_pages)
1376                goto out;
1377        for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1378                rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1379                if (rqstp->rq_enc_pages[i] == NULL)
1380                        goto out_free;
1381        }
1382        rqstp->rq_release_snd_buf = priv_release_snd_buf;
1383        return 0;
1384out_free:
1385        rqstp->rq_enc_pages_num = i;
1386        priv_release_snd_buf(rqstp);
1387out:
1388        return -EAGAIN;
1389}
1390
/*
 * Wrap a request for the privacy service.
 *
 * Encodes the arguments after a length word and the sequence number,
 * swaps the send buffer's pages for freshly allocated ones, lets
 * gss_wrap() process everything from the sequence number onwards, then
 * fills in the opaque length and pads the result to a 4-byte boundary.
 *
 * Returns 0 on success (also when the context has expired, in which case
 * the cred is marked not up to date), the alloc_enc_pages() error if the
 * page allocation fails, and -EIO if gss_wrap() fails otherwise.
 */
static inline int
gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
                  kxdreproc_t encode, struct rpc_rqst *rqstp,
                  __be32 *p, void *obj)
{
        struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
        u32             offset;
        u32             maj_stat;
        int             status;
        __be32          *opaque_len;
        struct page     **inpages;
        int             first;
        int             pad;
        struct kvec     *iov;
        char            *tmp;

        /* Reserve the length word; filled in after wrapping */
        opaque_len = p++;
        /* Byte offset, within the head, where the wrapped data starts */
        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        *p++ = htonl(rqstp->rq_seqno);

        gss_wrap_req_encode(encode, rqstp, p, obj);

        status = alloc_enc_pages(rqstp);
        if (status)
                return status;
        /*
         * Keep the original pages as gss_wrap()'s input and point the
         * send buffer at the newly allocated pages instead.
         */
        first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
        inpages = snd_buf->pages + first;
        snd_buf->pages = rqstp->rq_enc_pages;
        snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
        /*
         * Give the tail its own page, in case we need extra space in the
         * head when wrapping:
         *
         * call_allocate() allocates twice the slack space required
         * by the authentication flavor to rq_callsize.
         * For GSS, slack is GSS_CRED_SLACK.
         */
        if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
                tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
                memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
                snd_buf->tail[0].iov_base = tmp;
        }
        maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
        /* slack space should prevent this ever happening: */
        BUG_ON(snd_buf->len > snd_buf->buflen);
        status = -EIO;
        /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
         * done anyway, so it's safe to put the request on the wire: */
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        else if (maj_stat)
                return status;

        *opaque_len = htonl(snd_buf->len - offset);
        /* guess whether we're in the head or the tail: */
        if (snd_buf->page_len || snd_buf->tail[0].iov_len)
                iov = snd_buf->tail;
        else
                iov = snd_buf->head;
        p = iov->iov_base + iov->iov_len;
        /* Zero bytes needed to round the wrapped length up to 4 bytes */
        pad = 3 - ((snd_buf->len - offset - 1) & 3);
        memset(p, 0, pad);
        iov->iov_len += pad;
        snd_buf->len += pad;

        return 0;
}
1458
1459static int
1460gss_wrap_req(struct rpc_task *task,
1461             kxdreproc_t encode, void *rqstp, __be32 *p, void *obj)
1462{
1463        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1464        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1465                        gc_base);
1466        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1467        int             status = -EIO;
1468
1469        dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid);
1470        if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1471                /* The spec seems a little ambiguous here, but I think that not
1472                 * wrapping context destruction requests makes the most sense.
1473                 */
1474                gss_wrap_req_encode(encode, rqstp, p, obj);
1475                status = 0;
1476                goto out;
1477        }
1478        switch (gss_cred->gc_service) {
1479        case RPC_GSS_SVC_NONE:
1480                gss_wrap_req_encode(encode, rqstp, p, obj);
1481                status = 0;
1482                break;
1483        case RPC_GSS_SVC_INTEGRITY:
1484                status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj);
1485                break;
1486        case RPC_GSS_SVC_PRIVACY:
1487                status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj);
1488                break;
1489        }
1490out:
1491        gss_put_ctx(ctx);
1492        dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status);
1493        return status;
1494}
1495
1496static inline int
1497gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1498                struct rpc_rqst *rqstp, __be32 **p)
1499{
1500        struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1501        struct xdr_buf integ_buf;
1502        struct xdr_netobj mic;
1503        u32 data_offset, mic_offset;
1504        u32 integ_len;
1505        u32 maj_stat;
1506        int status = -EIO;
1507
1508        integ_len = ntohl(*(*p)++);
1509        if (integ_len & 3)
1510                return status;
1511        data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1512        mic_offset = integ_len + data_offset;
1513        if (mic_offset > rcv_buf->len)
1514                return status;
1515        if (ntohl(*(*p)++) != rqstp->rq_seqno)
1516                return status;
1517
1518        if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1519                                mic_offset - data_offset))
1520                return status;
1521
1522        if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1523                return status;
1524
1525        maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1526        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1527                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1528        if (maj_stat != GSS_S_COMPLETE)
1529                return status;
1530        return 0;
1531}
1532
1533static inline int
1534gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1535                struct rpc_rqst *rqstp, __be32 **p)
1536{
1537        struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1538        u32 offset;
1539        u32 opaque_len;
1540        u32 maj_stat;
1541        int status = -EIO;
1542
1543        opaque_len = ntohl(*(*p)++);
1544        offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1545        if (offset + opaque_len > rcv_buf->len)
1546                return status;
1547        /* remove padding: */
1548        rcv_buf->len = offset + opaque_len;
1549
1550        maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1551        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1552                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1553        if (maj_stat != GSS_S_COMPLETE)
1554                return status;
1555        if (ntohl(*(*p)++) != rqstp->rq_seqno)
1556                return status;
1557
1558        return 0;
1559}
1560
1561static int
1562gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
1563                      __be32 *p, void *obj)
1564{
1565        struct xdr_stream xdr;
1566
1567        xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p);
1568        return decode(rqstp, &xdr, obj);
1569}
1570
/*
 * Unwrap (if required) and decode the results of a reply.
 *
 * When the context is not in RPC_GSS_PROC_DATA state the results are
 * decoded directly.  Otherwise the integrity or privacy helper runs
 * first; if it fails, its status is returned without decoding.  After a
 * successful unwrap the authenticator's au_rslack is recomputed from the
 * verifier size plus the space the unwrapping consumed.
 *
 * Returns the decode routine's status, or the unwrap helper's error.
 */
static int
gss_unwrap_resp(struct rpc_task *task,
                kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj)
{
        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                        gc_base);
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
        /* Remember the starting position and head length for the slack math */
        __be32          *savedp = p;
        struct kvec     *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
        int             savedlen = head->iov_len;
        int             status = -EIO;

        if (ctx->gc_proc != RPC_GSS_PROC_DATA)
                goto out_decode;
        switch (gss_cred->gc_service) {
        case RPC_GSS_SVC_NONE:
                break;
        case RPC_GSS_SVC_INTEGRITY:
                status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
                if (status)
                        goto out;
                break;
        case RPC_GSS_SVC_PRIVACY:
                status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
                if (status)
                        goto out;
                break;
        }
        /* take into account extra slack for integrity and privacy cases: */
        cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
                                                + (savedlen - head->iov_len);
out_decode:
        status = gss_unwrap_req_decode(decode, rqstp, p, obj);
out:
        gss_put_ctx(ctx);
        dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
                        status);
        return status;
}
1611
/*
 * Authenticator operations for the RPC_AUTH_GSS flavor.  Registered with
 * rpcauth_register() at module init and installed as au_ops on every
 * authenticator built by gss_create().
 */
static const struct rpc_authops authgss_ops = {
        .owner          = THIS_MODULE,
        .au_flavor      = RPC_AUTH_GSS,
        .au_name        = "RPCSEC_GSS",
        .create         = gss_create,
        .destroy        = gss_destroy,
        .lookup_cred    = gss_lookup_cred,
        .crcreate       = gss_create_cred,
        .pipes_create   = gss_pipes_dentries_create,
        .pipes_destroy  = gss_pipes_dentries_destroy,
};
1623
/*
 * Credential operations for a live GSS credential; installed by
 * gss_create_cred() through rpcauth_init_cred().
 */
static const struct rpc_credops gss_credops = {
        .cr_name        = "AUTH_GSS",
        .crdestroy      = gss_destroy_cred,
        .cr_init        = gss_cred_init,
        .crbind         = rpcauth_generic_bind_cred,
        .crmatch        = gss_match,
        .crmarshal      = gss_marshal,
        .crrefresh      = gss_refresh,
        .crvalidate     = gss_validate,
        .crwrap_req     = gss_wrap_req,
        .crunwrap_resp  = gss_unwrap_resp,
};
1636
/*
 * Credential operations used while a context is being destroyed;
 * gss_destroying_context() switches the cred over to this table.
 * Compared with gss_credops it has no .cr_init, destroys through
 * gss_destroy_nullcred() and refuses to refresh (gss_refresh_null()).
 */
static const struct rpc_credops gss_nullops = {
        .cr_name        = "AUTH_GSS",
        .crdestroy      = gss_destroy_nullcred,
        .crbind         = rpcauth_generic_bind_cred,
        .crmatch        = gss_match,
        .crmarshal      = gss_marshal,
        .crrefresh      = gss_refresh_null,
        .crvalidate     = gss_validate,
        .crwrap_req     = gss_wrap_req,
        .crunwrap_resp  = gss_unwrap_resp,
};
1648
/*
 * Pipe operations for the old (v0) upcall pipe, pipe[0] in gss_create().
 * Identical to the v1 table except for the open handler.
 */
static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
        .upcall         = rpc_pipe_generic_upcall,
        .downcall       = gss_pipe_downcall,
        .destroy_msg    = gss_pipe_destroy_msg,
        .open_pipe      = gss_pipe_open_v0,
        .release_pipe   = gss_pipe_release,
};
1656
/*
 * Pipe operations for the new (v1) upcall pipe, pipe[1] in gss_create().
 * Identical to the v0 table except for the open handler.
 */
static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
        .upcall         = rpc_pipe_generic_upcall,
        .downcall       = gss_pipe_downcall,
        .destroy_msg    = gss_pipe_destroy_msg,
        .open_pipe      = gss_pipe_open_v1,
        .release_pipe   = gss_pipe_release,
};
1664
/* Per-network-namespace init hook: delegates to gss_svc_init_net(). */
static __net_init int rpcsec_gss_init_net(struct net *net)
{
        return gss_svc_init_net(net);
}
1669
/* Per-network-namespace exit hook: delegates to gss_svc_shutdown_net(). */
static __net_exit void rpcsec_gss_exit_net(struct net *net)
{
        gss_svc_shutdown_net(net);
}
1674
/*
 * Per-network-namespace hooks; registered in init_rpcsec_gss() and
 * unregistered in exit_rpcsec_gss().
 */
static struct pernet_operations rpcsec_gss_net_ops = {
        .init = rpcsec_gss_init_net,
        .exit = rpcsec_gss_exit_net,
};
1679
/*
 * Initialize RPCSEC_GSS module
 *
 * Registers the authenticator ops, initialises the server side, registers
 * the per-namespace hooks and sets up the pipe-version wait queue.  On a
 * failure the steps already completed are undone in reverse order.
 *
 * Returns 0 on success or the error from the step that failed.
 */
static int __init init_rpcsec_gss(void)
{
        int err = 0;

        err = rpcauth_register(&authgss_ops);
        if (err)
                goto out;
        err = gss_svc_init();
        if (err)
                goto out_unregister;
        err = register_pernet_subsys(&rpcsec_gss_net_ops);
        if (err)
                goto out_svc_exit;
        rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
        return 0;
        /* Unwind in reverse order of setup */
out_svc_exit:
        gss_svc_shutdown();
out_unregister:
        rpcauth_unregister(&authgss_ops);
out:
        return err;
}
1705
/*
 * Module exit: undo init_rpcsec_gss() in reverse order, then wait for
 * outstanding RCU callbacks queued with call_rcu() to complete.
 */
static void __exit exit_rpcsec_gss(void)
{
        unregister_pernet_subsys(&rpcsec_gss_net_ops);
        gss_svc_shutdown();
        rpcauth_unregister(&authgss_ops);
        rcu_barrier(); /* Wait for completion of call_rcu()'s */
}
1713
/* Module metadata, parameters and entry points */
MODULE_LICENSE("GPL");
/*
 * Expose gss_expired_cred_retry_delay as the module parameter
 * "expired_cred_retry_delay" (unsigned int, mode 0644).
 */
module_param_named(expired_cred_retry_delay,
                   gss_expired_cred_retry_delay,
                   uint, 0644);
MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until "
                "the RPC engine retries an expired credential");

module_init(init_rpcsec_gss)
module_exit(exit_rpcsec_gss)
1723