linux/net/sunrpc/auth_gss/auth_gss.c
<<
>>
Prefs
   1/*
   2 * linux/net/sunrpc/auth_gss/auth_gss.c
   3 *
   4 * RPCSEC_GSS client authentication.
   5 *
   6 *  Copyright (c) 2000 The Regents of the University of Michigan.
   7 *  All rights reserved.
   8 *
   9 *  Dug Song       <dugsong@monkey.org>
  10 *  Andy Adamson   <andros@umich.edu>
  11 *
  12 *  Redistribution and use in source and binary forms, with or without
  13 *  modification, are permitted provided that the following conditions
  14 *  are met:
  15 *
  16 *  1. Redistributions of source code must retain the above copyright
  17 *     notice, this list of conditions and the following disclaimer.
  18 *  2. Redistributions in binary form must reproduce the above copyright
  19 *     notice, this list of conditions and the following disclaimer in the
  20 *     documentation and/or other materials provided with the distribution.
  21 *  3. Neither the name of the University nor the names of its
  22 *     contributors may be used to endorse or promote products derived
  23 *     from this software without specific prior written permission.
  24 *
  25 *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
  26 *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
  27 *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  28 *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
  29 *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  30 *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  31 *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
  32 *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
  33 *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  34 *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  35 *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  36 */
  37
  38
  39#include <linux/module.h>
  40#include <linux/init.h>
  41#include <linux/types.h>
  42#include <linux/slab.h>
  43#include <linux/sched.h>
  44#include <linux/pagemap.h>
  45#include <linux/sunrpc/clnt.h>
  46#include <linux/sunrpc/auth.h>
  47#include <linux/sunrpc/auth_gss.h>
  48#include <linux/sunrpc/svcauth_gss.h>
  49#include <linux/sunrpc/gss_err.h>
  50#include <linux/workqueue.h>
  51#include <linux/sunrpc/rpc_pipe_fs.h>
  52#include <linux/sunrpc/gss_api.h>
  53#include <asm/uaccess.h>
  54
  55static const struct rpc_authops authgss_ops;
  56
  57static const struct rpc_credops gss_credops;
  58static const struct rpc_credops gss_nullops;
  59
  60#define GSS_RETRY_EXPIRED 5
  61static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED;
  62
  63#ifdef RPC_DEBUG
  64# define RPCDBG_FACILITY        RPCDBG_AUTH
  65#endif
  66
  67#define GSS_CRED_SLACK          (RPC_MAX_AUTH_SIZE * 2)
  68/* length of a krb5 verifier (48), plus data added before arguments when
  69 * using integrity (two 4-byte integers): */
  70#define GSS_VERF_SLACK          100
  71
  72struct gss_auth {
  73        struct kref kref;
  74        struct rpc_auth rpc_auth;
  75        struct gss_api_mech *mech;
  76        enum rpc_gss_svc service;
  77        struct rpc_clnt *client;
  78        /*
  79         * There are two upcall pipes; dentry[1], named "gssd", is used
  80         * for the new text-based upcall; dentry[0] is named after the
  81         * mechanism (for example, "krb5") and exists for
  82         * backwards-compatibility with older gssd's.
  83         */
  84        struct dentry *dentry[2];
  85};
  86
  87/* pipe_version >= 0 if and only if someone has a pipe open. */
  88static int pipe_version = -1;
  89static atomic_t pipe_users = ATOMIC_INIT(0);
  90static DEFINE_SPINLOCK(pipe_version_lock);
  91static struct rpc_wait_queue pipe_version_rpc_waitqueue;
  92static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
  93
  94static void gss_free_ctx(struct gss_cl_ctx *);
  95static const struct rpc_pipe_ops gss_upcall_ops_v0;
  96static const struct rpc_pipe_ops gss_upcall_ops_v1;
  97
  98static inline struct gss_cl_ctx *
  99gss_get_ctx(struct gss_cl_ctx *ctx)
 100{
 101        atomic_inc(&ctx->count);
 102        return ctx;
 103}
 104
 105static inline void
 106gss_put_ctx(struct gss_cl_ctx *ctx)
 107{
 108        if (atomic_dec_and_test(&ctx->count))
 109                gss_free_ctx(ctx);
 110}
 111
 112/* gss_cred_set_ctx:
 113 * called by gss_upcall_callback and gss_create_upcall in order
 114 * to set the gss context. The actual exchange of an old context
 115 * and a new one is protected by the inode->i_lock.
 116 */
 117static void
 118gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx)
 119{
 120        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 121
 122        if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
 123                return;
 124        gss_get_ctx(ctx);
 125        rcu_assign_pointer(gss_cred->gc_ctx, ctx);
 126        set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 127        smp_mb__before_clear_bit();
 128        clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags);
 129}
 130
 131static const void *
 132simple_get_bytes(const void *p, const void *end, void *res, size_t len)
 133{
 134        const void *q = (const void *)((const char *)p + len);
 135        if (unlikely(q > end || q < p))
 136                return ERR_PTR(-EFAULT);
 137        memcpy(res, p, len);
 138        return q;
 139}
 140
 141static inline const void *
 142simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
 143{
 144        const void *q;
 145        unsigned int len;
 146
 147        p = simple_get_bytes(p, end, &len, sizeof(len));
 148        if (IS_ERR(p))
 149                return p;
 150        q = (const void *)((const char *)p + len);
 151        if (unlikely(q > end || q < p))
 152                return ERR_PTR(-EFAULT);
 153        dest->data = kmemdup(p, len, GFP_NOFS);
 154        if (unlikely(dest->data == NULL))
 155                return ERR_PTR(-ENOMEM);
 156        dest->len = len;
 157        return q;
 158}
 159
 160static struct gss_cl_ctx *
 161gss_cred_get_ctx(struct rpc_cred *cred)
 162{
 163        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 164        struct gss_cl_ctx *ctx = NULL;
 165
 166        rcu_read_lock();
 167        if (gss_cred->gc_ctx)
 168                ctx = gss_get_ctx(gss_cred->gc_ctx);
 169        rcu_read_unlock();
 170        return ctx;
 171}
 172
 173static struct gss_cl_ctx *
 174gss_alloc_context(void)
 175{
 176        struct gss_cl_ctx *ctx;
 177
 178        ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
 179        if (ctx != NULL) {
 180                ctx->gc_proc = RPC_GSS_PROC_DATA;
 181                ctx->gc_seq = 1;        /* NetApp 6.4R1 doesn't accept seq. no. 0 */
 182                spin_lock_init(&ctx->gc_seq_lock);
 183                atomic_set(&ctx->count,1);
 184        }
 185        return ctx;
 186}
 187
 188#define GSSD_MIN_TIMEOUT (60 * 60)
 189static const void *
 190gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm)
 191{
 192        const void *q;
 193        unsigned int seclen;
 194        unsigned int timeout;
 195        u32 window_size;
 196        int ret;
 197
 198        /* First unsigned int gives the lifetime (in seconds) of the cred */
 199        p = simple_get_bytes(p, end, &timeout, sizeof(timeout));
 200        if (IS_ERR(p))
 201                goto err;
 202        if (timeout == 0)
 203                timeout = GSSD_MIN_TIMEOUT;
 204        ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4;
 205        /* Sequence number window. Determines the maximum number of simultaneous requests */
 206        p = simple_get_bytes(p, end, &window_size, sizeof(window_size));
 207        if (IS_ERR(p))
 208                goto err;
 209        ctx->gc_win = window_size;
 210        /* gssd signals an error by passing ctx->gc_win = 0: */
 211        if (ctx->gc_win == 0) {
 212                /*
 213                 * in which case, p points to an error code. Anything other
 214                 * than -EKEYEXPIRED gets converted to -EACCES.
 215                 */
 216                p = simple_get_bytes(p, end, &ret, sizeof(ret));
 217                if (!IS_ERR(p))
 218                        p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) :
 219                                                    ERR_PTR(-EACCES);
 220                goto err;
 221        }
 222        /* copy the opaque wire context */
 223        p = simple_get_netobj(p, end, &ctx->gc_wire_ctx);
 224        if (IS_ERR(p))
 225                goto err;
 226        /* import the opaque security context */
 227        p  = simple_get_bytes(p, end, &seclen, sizeof(seclen));
 228        if (IS_ERR(p))
 229                goto err;
 230        q = (const void *)((const char *)p + seclen);
 231        if (unlikely(q > end || q < p)) {
 232                p = ERR_PTR(-EFAULT);
 233                goto err;
 234        }
 235        ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, GFP_NOFS);
 236        if (ret < 0) {
 237                p = ERR_PTR(ret);
 238                goto err;
 239        }
 240        return q;
 241err:
 242        dprintk("RPC:       gss_fill_context returning %ld\n", -PTR_ERR(p));
 243        return p;
 244}
 245
 246#define UPCALL_BUF_LEN 128
 247
 248struct gss_upcall_msg {
 249        atomic_t count;
 250        uid_t   uid;
 251        struct rpc_pipe_msg msg;
 252        struct list_head list;
 253        struct gss_auth *auth;
 254        struct rpc_inode *inode;
 255        struct rpc_wait_queue rpc_waitqueue;
 256        wait_queue_head_t waitqueue;
 257        struct gss_cl_ctx *ctx;
 258        char databuf[UPCALL_BUF_LEN];
 259};
 260
 261static int get_pipe_version(void)
 262{
 263        int ret;
 264
 265        spin_lock(&pipe_version_lock);
 266        if (pipe_version >= 0) {
 267                atomic_inc(&pipe_users);
 268                ret = pipe_version;
 269        } else
 270                ret = -EAGAIN;
 271        spin_unlock(&pipe_version_lock);
 272        return ret;
 273}
 274
 275static void put_pipe_version(void)
 276{
 277        if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
 278                pipe_version = -1;
 279                spin_unlock(&pipe_version_lock);
 280        }
 281}
 282
 283static void
 284gss_release_msg(struct gss_upcall_msg *gss_msg)
 285{
 286        if (!atomic_dec_and_test(&gss_msg->count))
 287                return;
 288        put_pipe_version();
 289        BUG_ON(!list_empty(&gss_msg->list));
 290        if (gss_msg->ctx != NULL)
 291                gss_put_ctx(gss_msg->ctx);
 292        rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue);
 293        kfree(gss_msg);
 294}
 295
 296static struct gss_upcall_msg *
 297__gss_find_upcall(struct rpc_inode *rpci, uid_t uid)
 298{
 299        struct gss_upcall_msg *pos;
 300        list_for_each_entry(pos, &rpci->in_downcall, list) {
 301                if (pos->uid != uid)
 302                        continue;
 303                atomic_inc(&pos->count);
 304                dprintk("RPC:       gss_find_upcall found msg %p\n", pos);
 305                return pos;
 306        }
 307        dprintk("RPC:       gss_find_upcall found nothing\n");
 308        return NULL;
 309}
 310
 311/* Try to add an upcall to the pipefs queue.
 312 * If an upcall owned by our uid already exists, then we return a reference
 313 * to that upcall instead of adding the new upcall.
 314 */
 315static inline struct gss_upcall_msg *
 316gss_add_msg(struct gss_upcall_msg *gss_msg)
 317{
 318        struct rpc_inode *rpci = gss_msg->inode;
 319        struct inode *inode = &rpci->vfs_inode;
 320        struct gss_upcall_msg *old;
 321
 322        spin_lock(&inode->i_lock);
 323        old = __gss_find_upcall(rpci, gss_msg->uid);
 324        if (old == NULL) {
 325                atomic_inc(&gss_msg->count);
 326                list_add(&gss_msg->list, &rpci->in_downcall);
 327        } else
 328                gss_msg = old;
 329        spin_unlock(&inode->i_lock);
 330        return gss_msg;
 331}
 332
 333static void
 334__gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 335{
 336        list_del_init(&gss_msg->list);
 337        rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
 338        wake_up_all(&gss_msg->waitqueue);
 339        atomic_dec(&gss_msg->count);
 340}
 341
 342static void
 343gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 344{
 345        struct inode *inode = &gss_msg->inode->vfs_inode;
 346
 347        if (list_empty(&gss_msg->list))
 348                return;
 349        spin_lock(&inode->i_lock);
 350        if (!list_empty(&gss_msg->list))
 351                __gss_unhash_msg(gss_msg);
 352        spin_unlock(&inode->i_lock);
 353}
 354
 355static void
 356gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg)
 357{
 358        switch (gss_msg->msg.errno) {
 359        case 0:
 360                if (gss_msg->ctx == NULL)
 361                        break;
 362                clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
 363                gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx);
 364                break;
 365        case -EKEYEXPIRED:
 366                set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags);
 367        }
 368        gss_cred->gc_upcall_timestamp = jiffies;
 369        gss_cred->gc_upcall = NULL;
 370        rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
 371}
 372
 373static void
 374gss_upcall_callback(struct rpc_task *task)
 375{
 376        struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred,
 377                        struct gss_cred, gc_base);
 378        struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
 379        struct inode *inode = &gss_msg->inode->vfs_inode;
 380
 381        spin_lock(&inode->i_lock);
 382        gss_handle_downcall_result(gss_cred, gss_msg);
 383        spin_unlock(&inode->i_lock);
 384        task->tk_status = gss_msg->msg.errno;
 385        gss_release_msg(gss_msg);
 386}
 387
 388static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg)
 389{
 390        gss_msg->msg.data = &gss_msg->uid;
 391        gss_msg->msg.len = sizeof(gss_msg->uid);
 392}
 393
 394static void gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
 395                                struct rpc_clnt *clnt, int machine_cred)
 396{
 397        struct gss_api_mech *mech = gss_msg->auth->mech;
 398        char *p = gss_msg->databuf;
 399        int len = 0;
 400
 401        gss_msg->msg.len = sprintf(gss_msg->databuf, "mech=%s uid=%d ",
 402                                   mech->gm_name,
 403                                   gss_msg->uid);
 404        p += gss_msg->msg.len;
 405        if (clnt->cl_principal) {
 406                len = sprintf(p, "target=%s ", clnt->cl_principal);
 407                p += len;
 408                gss_msg->msg.len += len;
 409        }
 410        if (machine_cred) {
 411                len = sprintf(p, "service=* ");
 412                p += len;
 413                gss_msg->msg.len += len;
 414        } else if (!strcmp(clnt->cl_program->name, "nfs4_cb")) {
 415                len = sprintf(p, "service=nfs ");
 416                p += len;
 417                gss_msg->msg.len += len;
 418        }
 419        if (mech->gm_upcall_enctypes) {
 420                len = sprintf(p, "enctypes=%s ", mech->gm_upcall_enctypes);
 421                p += len;
 422                gss_msg->msg.len += len;
 423        }
 424        len = sprintf(p, "\n");
 425        gss_msg->msg.len += len;
 426
 427        gss_msg->msg.data = gss_msg->databuf;
 428        BUG_ON(gss_msg->msg.len > UPCALL_BUF_LEN);
 429}
 430
 431static void gss_encode_msg(struct gss_upcall_msg *gss_msg,
 432                                struct rpc_clnt *clnt, int machine_cred)
 433{
 434        if (pipe_version == 0)
 435                gss_encode_v0_msg(gss_msg);
 436        else /* pipe_version == 1 */
 437                gss_encode_v1_msg(gss_msg, clnt, machine_cred);
 438}
 439
 440static inline struct gss_upcall_msg *
 441gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid, struct rpc_clnt *clnt,
 442                int machine_cred)
 443{
 444        struct gss_upcall_msg *gss_msg;
 445        int vers;
 446
 447        gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
 448        if (gss_msg == NULL)
 449                return ERR_PTR(-ENOMEM);
 450        vers = get_pipe_version();
 451        if (vers < 0) {
 452                kfree(gss_msg);
 453                return ERR_PTR(vers);
 454        }
 455        gss_msg->inode = RPC_I(gss_auth->dentry[vers]->d_inode);
 456        INIT_LIST_HEAD(&gss_msg->list);
 457        rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
 458        init_waitqueue_head(&gss_msg->waitqueue);
 459        atomic_set(&gss_msg->count, 1);
 460        gss_msg->uid = uid;
 461        gss_msg->auth = gss_auth;
 462        gss_encode_msg(gss_msg, clnt, machine_cred);
 463        return gss_msg;
 464}
 465
 466static struct gss_upcall_msg *
 467gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
 468{
 469        struct gss_cred *gss_cred = container_of(cred,
 470                        struct gss_cred, gc_base);
 471        struct gss_upcall_msg *gss_new, *gss_msg;
 472        uid_t uid = cred->cr_uid;
 473
 474        gss_new = gss_alloc_msg(gss_auth, uid, clnt, gss_cred->gc_machine_cred);
 475        if (IS_ERR(gss_new))
 476                return gss_new;
 477        gss_msg = gss_add_msg(gss_new);
 478        if (gss_msg == gss_new) {
 479                struct inode *inode = &gss_new->inode->vfs_inode;
 480                int res = rpc_queue_upcall(inode, &gss_new->msg);
 481                if (res) {
 482                        gss_unhash_msg(gss_new);
 483                        gss_msg = ERR_PTR(res);
 484                }
 485        } else
 486                gss_release_msg(gss_new);
 487        return gss_msg;
 488}
 489
 490static void warn_gssd(void)
 491{
 492        static unsigned long ratelimit;
 493        unsigned long now = jiffies;
 494
 495        if (time_after(now, ratelimit)) {
 496                printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
 497                                "Please check user daemon is running.\n");
 498                ratelimit = now + 15*HZ;
 499        }
 500}
 501
 502static inline int
 503gss_refresh_upcall(struct rpc_task *task)
 504{
 505        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 506        struct gss_auth *gss_auth = container_of(cred->cr_auth,
 507                        struct gss_auth, rpc_auth);
 508        struct gss_cred *gss_cred = container_of(cred,
 509                        struct gss_cred, gc_base);
 510        struct gss_upcall_msg *gss_msg;
 511        struct inode *inode;
 512        int err = 0;
 513
 514        dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
 515                                                                cred->cr_uid);
 516        gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
 517        if (PTR_ERR(gss_msg) == -EAGAIN) {
 518                /* XXX: warning on the first, under the assumption we
 519                 * shouldn't normally hit this case on a refresh. */
 520                warn_gssd();
 521                task->tk_timeout = 15*HZ;
 522                rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
 523                return -EAGAIN;
 524        }
 525        if (IS_ERR(gss_msg)) {
 526                err = PTR_ERR(gss_msg);
 527                goto out;
 528        }
 529        inode = &gss_msg->inode->vfs_inode;
 530        spin_lock(&inode->i_lock);
 531        if (gss_cred->gc_upcall != NULL)
 532                rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
 533        else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
 534                task->tk_timeout = 0;
 535                gss_cred->gc_upcall = gss_msg;
 536                /* gss_upcall_callback will release the reference to gss_upcall_msg */
 537                atomic_inc(&gss_msg->count);
 538                rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback);
 539        } else {
 540                gss_handle_downcall_result(gss_cred, gss_msg);
 541                err = gss_msg->msg.errno;
 542        }
 543        spin_unlock(&inode->i_lock);
 544        gss_release_msg(gss_msg);
 545out:
 546        dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n",
 547                        task->tk_pid, cred->cr_uid, err);
 548        return err;
 549}
 550
 551static inline int
 552gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
 553{
 554        struct inode *inode;
 555        struct rpc_cred *cred = &gss_cred->gc_base;
 556        struct gss_upcall_msg *gss_msg;
 557        DEFINE_WAIT(wait);
 558        int err = 0;
 559
 560        dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
 561retry:
 562        gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
 563        if (PTR_ERR(gss_msg) == -EAGAIN) {
 564                err = wait_event_interruptible_timeout(pipe_version_waitqueue,
 565                                pipe_version >= 0, 15*HZ);
 566                if (pipe_version < 0) {
 567                        warn_gssd();
 568                        err = -EACCES;
 569                }
 570                if (err)
 571                        goto out;
 572                goto retry;
 573        }
 574        if (IS_ERR(gss_msg)) {
 575                err = PTR_ERR(gss_msg);
 576                goto out;
 577        }
 578        inode = &gss_msg->inode->vfs_inode;
 579        for (;;) {
 580                prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
 581                spin_lock(&inode->i_lock);
 582                if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
 583                        break;
 584                }
 585                spin_unlock(&inode->i_lock);
 586                if (signalled()) {
 587                        err = -ERESTARTSYS;
 588                        goto out_intr;
 589                }
 590                schedule();
 591        }
 592        if (gss_msg->ctx)
 593                gss_cred_set_ctx(cred, gss_msg->ctx);
 594        else
 595                err = gss_msg->msg.errno;
 596        spin_unlock(&inode->i_lock);
 597out_intr:
 598        finish_wait(&gss_msg->waitqueue, &wait);
 599        gss_release_msg(gss_msg);
 600out:
 601        dprintk("RPC:       gss_create_upcall for uid %u result %d\n",
 602                        cred->cr_uid, err);
 603        return err;
 604}
 605
 606static ssize_t
 607gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
 608                char __user *dst, size_t buflen)
 609{
 610        char *data = (char *)msg->data + msg->copied;
 611        size_t mlen = min(msg->len, buflen);
 612        unsigned long left;
 613
 614        left = copy_to_user(dst, data, mlen);
 615        if (left == mlen) {
 616                msg->errno = -EFAULT;
 617                return -EFAULT;
 618        }
 619
 620        mlen -= left;
 621        msg->copied += mlen;
 622        msg->errno = 0;
 623        return mlen;
 624}
 625
 626#define MSG_BUF_MAXSIZE 1024
 627
 628static ssize_t
 629gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
 630{
 631        const void *p, *end;
 632        void *buf;
 633        struct gss_upcall_msg *gss_msg;
 634        struct inode *inode = filp->f_path.dentry->d_inode;
 635        struct gss_cl_ctx *ctx;
 636        uid_t uid;
 637        ssize_t err = -EFBIG;
 638
 639        if (mlen > MSG_BUF_MAXSIZE)
 640                goto out;
 641        err = -ENOMEM;
 642        buf = kmalloc(mlen, GFP_NOFS);
 643        if (!buf)
 644                goto out;
 645
 646        err = -EFAULT;
 647        if (copy_from_user(buf, src, mlen))
 648                goto err;
 649
 650        end = (const void *)((char *)buf + mlen);
 651        p = simple_get_bytes(buf, end, &uid, sizeof(uid));
 652        if (IS_ERR(p)) {
 653                err = PTR_ERR(p);
 654                goto err;
 655        }
 656
 657        err = -ENOMEM;
 658        ctx = gss_alloc_context();
 659        if (ctx == NULL)
 660                goto err;
 661
 662        err = -ENOENT;
 663        /* Find a matching upcall */
 664        spin_lock(&inode->i_lock);
 665        gss_msg = __gss_find_upcall(RPC_I(inode), uid);
 666        if (gss_msg == NULL) {
 667                spin_unlock(&inode->i_lock);
 668                goto err_put_ctx;
 669        }
 670        list_del_init(&gss_msg->list);
 671        spin_unlock(&inode->i_lock);
 672
 673        p = gss_fill_context(p, end, ctx, gss_msg->auth->mech);
 674        if (IS_ERR(p)) {
 675                err = PTR_ERR(p);
 676                switch (err) {
 677                case -EACCES:
 678                case -EKEYEXPIRED:
 679                        gss_msg->msg.errno = err;
 680                        err = mlen;
 681                        break;
 682                case -EFAULT:
 683                case -ENOMEM:
 684                case -EINVAL:
 685                case -ENOSYS:
 686                        gss_msg->msg.errno = -EAGAIN;
 687                        break;
 688                default:
 689                        printk(KERN_CRIT "%s: bad return from "
 690                                "gss_fill_context: %zd\n", __func__, err);
 691                        BUG();
 692                }
 693                goto err_release_msg;
 694        }
 695        gss_msg->ctx = gss_get_ctx(ctx);
 696        err = mlen;
 697
 698err_release_msg:
 699        spin_lock(&inode->i_lock);
 700        __gss_unhash_msg(gss_msg);
 701        spin_unlock(&inode->i_lock);
 702        gss_release_msg(gss_msg);
 703err_put_ctx:
 704        gss_put_ctx(ctx);
 705err:
 706        kfree(buf);
 707out:
 708        dprintk("RPC:       gss_pipe_downcall returning %Zd\n", err);
 709        return err;
 710}
 711
 712static int gss_pipe_open(struct inode *inode, int new_version)
 713{
 714        int ret = 0;
 715
 716        spin_lock(&pipe_version_lock);
 717        if (pipe_version < 0) {
 718                /* First open of any gss pipe determines the version: */
 719                pipe_version = new_version;
 720                rpc_wake_up(&pipe_version_rpc_waitqueue);
 721                wake_up(&pipe_version_waitqueue);
 722        } else if (pipe_version != new_version) {
 723                /* Trying to open a pipe of a different version */
 724                ret = -EBUSY;
 725                goto out;
 726        }
 727        atomic_inc(&pipe_users);
 728out:
 729        spin_unlock(&pipe_version_lock);
 730        return ret;
 731
 732}
 733
 734static int gss_pipe_open_v0(struct inode *inode)
 735{
 736        return gss_pipe_open(inode, 0);
 737}
 738
 739static int gss_pipe_open_v1(struct inode *inode)
 740{
 741        return gss_pipe_open(inode, 1);
 742}
 743
 744static void
 745gss_pipe_release(struct inode *inode)
 746{
 747        struct rpc_inode *rpci = RPC_I(inode);
 748        struct gss_upcall_msg *gss_msg;
 749
 750restart:
 751        spin_lock(&inode->i_lock);
 752        list_for_each_entry(gss_msg, &rpci->in_downcall, list) {
 753
 754                if (!list_empty(&gss_msg->msg.list))
 755                        continue;
 756                gss_msg->msg.errno = -EPIPE;
 757                atomic_inc(&gss_msg->count);
 758                __gss_unhash_msg(gss_msg);
 759                spin_unlock(&inode->i_lock);
 760                gss_release_msg(gss_msg);
 761                goto restart;
 762        }
 763        spin_unlock(&inode->i_lock);
 764
 765        put_pipe_version();
 766}
 767
 768static void
 769gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
 770{
 771        struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
 772
 773        if (msg->errno < 0) {
 774                dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
 775                                gss_msg);
 776                atomic_inc(&gss_msg->count);
 777                gss_unhash_msg(gss_msg);
 778                if (msg->errno == -ETIMEDOUT)
 779                        warn_gssd();
 780                gss_release_msg(gss_msg);
 781        }
 782}
 783
 784/*
 785 * NOTE: we have the opportunity to use different
 786 * parameters based on the input flavor (which must be a pseudoflavor)
 787 */
 788static struct rpc_auth *
 789gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
 790{
 791        struct gss_auth *gss_auth;
 792        struct rpc_auth * auth;
 793        int err = -ENOMEM; /* XXX? */
 794
 795        dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);
 796
 797        if (!try_module_get(THIS_MODULE))
 798                return ERR_PTR(err);
 799        if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
 800                goto out_dec;
 801        gss_auth->client = clnt;
 802        err = -EINVAL;
 803        gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
 804        if (!gss_auth->mech) {
 805                printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n",
 806                                __func__, flavor);
 807                goto err_free;
 808        }
 809        gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
 810        if (gss_auth->service == 0)
 811                goto err_put_mech;
 812        auth = &gss_auth->rpc_auth;
 813        auth->au_cslack = GSS_CRED_SLACK >> 2;
 814        auth->au_rslack = GSS_VERF_SLACK >> 2;
 815        auth->au_ops = &authgss_ops;
 816        auth->au_flavor = flavor;
 817        atomic_set(&auth->au_count, 1);
 818        kref_init(&gss_auth->kref);
 819
 820        /*
 821         * Note: if we created the old pipe first, then someone who
 822         * examined the directory at the right moment might conclude
 823         * that we supported only the old pipe.  So we instead create
 824         * the new pipe first.
 825         */
 826        gss_auth->dentry[1] = rpc_mkpipe(clnt->cl_path.dentry,
 827                                         "gssd",
 828                                         clnt, &gss_upcall_ops_v1,
 829                                         RPC_PIPE_WAIT_FOR_OPEN);
 830        if (IS_ERR(gss_auth->dentry[1])) {
 831                err = PTR_ERR(gss_auth->dentry[1]);
 832                goto err_put_mech;
 833        }
 834
 835        gss_auth->dentry[0] = rpc_mkpipe(clnt->cl_path.dentry,
 836                                         gss_auth->mech->gm_name,
 837                                         clnt, &gss_upcall_ops_v0,
 838                                         RPC_PIPE_WAIT_FOR_OPEN);
 839        if (IS_ERR(gss_auth->dentry[0])) {
 840                err = PTR_ERR(gss_auth->dentry[0]);
 841                goto err_unlink_pipe_1;
 842        }
 843        err = rpcauth_init_credcache(auth);
 844        if (err)
 845                goto err_unlink_pipe_0;
 846
 847        return auth;
 848err_unlink_pipe_0:
 849        rpc_unlink(gss_auth->dentry[0]);
 850err_unlink_pipe_1:
 851        rpc_unlink(gss_auth->dentry[1]);
 852err_put_mech:
 853        gss_mech_put(gss_auth->mech);
 854err_free:
 855        kfree(gss_auth);
 856out_dec:
 857        module_put(THIS_MODULE);
 858        return ERR_PTR(err);
 859}
 860
 861static void
 862gss_free(struct gss_auth *gss_auth)
 863{
 864        rpc_unlink(gss_auth->dentry[1]);
 865        rpc_unlink(gss_auth->dentry[0]);
 866        gss_mech_put(gss_auth->mech);
 867
 868        kfree(gss_auth);
 869        module_put(THIS_MODULE);
 870}
 871
 872static void
 873gss_free_callback(struct kref *kref)
 874{
 875        struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref);
 876
 877        gss_free(gss_auth);
 878}
 879
 880static void
 881gss_destroy(struct rpc_auth *auth)
 882{
 883        struct gss_auth *gss_auth;
 884
 885        dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
 886                        auth, auth->au_flavor);
 887
 888        rpcauth_destroy_credcache(auth);
 889
 890        gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 891        kref_put(&gss_auth->kref, gss_free_callback);
 892}
 893
 894/*
 895 * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call
 896 * to the server with the GSS control procedure field set to
 897 * RPC_GSS_PROC_DESTROY. This should normally cause the server to release
 898 * all RPCSEC_GSS state associated with that context.
 899 */
 900static int
 901gss_destroying_context(struct rpc_cred *cred)
 902{
 903        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 904        struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
 905        struct rpc_task *task;
 906
 907        if (gss_cred->gc_ctx == NULL ||
 908            test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
 909                return 0;
 910
 911        gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
 912        cred->cr_ops = &gss_nullops;
 913
 914        /* Take a reference to ensure the cred will be destroyed either
 915         * by the RPC call or by the put_rpccred() below */
 916        get_rpccred(cred);
 917
 918        task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT);
 919        if (!IS_ERR(task))
 920                rpc_put_task(task);
 921
 922        put_rpccred(cred);
 923        return 1;
 924}
 925
 926/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure
 927 * to create a new cred or context, so they check that things have been
 928 * allocated before freeing them. */
 929static void
 930gss_do_free_ctx(struct gss_cl_ctx *ctx)
 931{
 932        dprintk("RPC:       gss_free_ctx\n");
 933
 934        gss_delete_sec_context(&ctx->gc_gss_ctx);
 935        kfree(ctx->gc_wire_ctx.data);
 936        kfree(ctx);
 937}
 938
 939static void
 940gss_free_ctx_callback(struct rcu_head *head)
 941{
 942        struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu);
 943        gss_do_free_ctx(ctx);
 944}
 945
 946static void
 947gss_free_ctx(struct gss_cl_ctx *ctx)
 948{
 949        call_rcu(&ctx->gc_rcu, gss_free_ctx_callback);
 950}
 951
 /* Free the gss_cred allocation itself; its context is released separately. */
static void
gss_free_cred(struct gss_cred *gss_cred)
{
	dprintk("RPC:       gss_free_cred %p\n", gss_cred);

	kfree(gss_cred);
}
 958
 959static void
 960gss_free_cred_callback(struct rcu_head *head)
 961{
 962        struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu);
 963        gss_free_cred(gss_cred);
 964}
 965
 966static void
 967gss_destroy_nullcred(struct rpc_cred *cred)
 968{
 969        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
 970        struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
 971        struct gss_cl_ctx *ctx = gss_cred->gc_ctx;
 972
 973        rcu_assign_pointer(gss_cred->gc_ctx, NULL);
 974        call_rcu(&cred->cr_rcu, gss_free_cred_callback);
 975        if (ctx)
 976                gss_put_ctx(ctx);
 977        kref_put(&gss_auth->kref, gss_free_callback);
 978}
 979
 /*
 * crdestroy hook for live creds. If gss_destroying_context() started a
 * DESTROY call, the cred now uses gss_nullops and is released through
 * gss_destroy_nullcred() later. Otherwise it is released right away.
 */
static void
gss_destroy_cred(struct rpc_cred *cred)
{
	if (!gss_destroying_context(cred))
		gss_destroy_nullcred(cred);
}
 988
 /*
 * Lookup RPCSEC_GSS cred for the current process.
 *
 * Thin wrapper around the generic credential cache. Presumably the cache
 * calls back into gss_create_cred() (crcreate) when a new entry is needed.
 */
static struct rpc_cred *
gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
{
	struct rpc_cred *found;

	found = rpcauth_lookup_credcache(auth, acred, flags);
	return found;
}
 997
 998static struct rpc_cred *
 999gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
1000{
1001        struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1002        struct gss_cred *cred = NULL;
1003        int err = -ENOMEM;
1004
1005        dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
1006                acred->uid, auth->au_flavor);
1007
1008        if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
1009                goto out_err;
1010
1011        rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
1012        /*
1013         * Note: in order to force a call to call_refresh(), we deliberately
1014         * fail to flag the credential as RPCAUTH_CRED_UPTODATE.
1015         */
1016        cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW;
1017        cred->gc_service = gss_auth->service;
1018        cred->gc_machine_cred = acred->machine_cred;
1019        kref_get(&gss_auth->kref);
1020        return &cred->gc_base;
1021
1022out_err:
1023        dprintk("RPC:       gss_create_cred failed with error %d\n", err);
1024        return ERR_PTR(err);
1025}
1026
1027static int
1028gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred)
1029{
1030        struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
1031        struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base);
1032        int err;
1033
1034        do {
1035                err = gss_create_upcall(gss_auth, gss_cred);
1036        } while (err == -EAGAIN);
1037        return err;
1038}
1039
1040static int
1041gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags)
1042{
1043        struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base);
1044
1045        if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags))
1046                goto out;
1047        /* Don't match with creds that have expired. */
1048        if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry))
1049                return 0;
1050        if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags))
1051                return 0;
1052out:
1053        if (acred->machine_cred != gss_cred->gc_machine_cred)
1054                return 0;
1055        return rc->cr_uid == acred->uid;
1056}
1057
1058/*
 * Marshal credentials.
 *
 * Encodes the RPCSEC_GSS credential for @task at @p: flavor and length,
 * then version, gc_proc, a newly assigned request sequence number, the
 * service and the wire context handle. The verifier follows: flavor plus
 * an opaque holding the gss_get_mic() checksum over the header bytes from
 * the XID through the end of the credential.
 *
 * Returns a pointer just past the verifier, or NULL if the MIC could not
 * be computed. GSS_S_CONTEXT_EXPIRED is not fatal: the cred loses
 * RPCAUTH_CRED_UPTODATE (so gss_refresh() will renew it) and encoding
 * continues.
 *
 * Maybe we should keep a cached credential for performance reasons.
 */
1062static __be32 *
1063gss_marshal(struct rpc_task *task, __be32 *p)
1064{
1065        struct rpc_rqst *req = task->tk_rqstp;
1066        struct rpc_cred *cred = req->rq_cred;
1067        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1068                                                 gc_base);
        /* Balanced by gss_put_ctx() on both exit paths below. */
1069        struct gss_cl_ctx       *ctx = gss_cred_get_ctx(cred);
1070        __be32          *cred_len;
1071        u32             maj_stat = 0;
1072        struct xdr_netobj mic;
1073        struct kvec     iov;
1074        struct xdr_buf  verf_buf;
1075
1076        dprintk("RPC: %5u gss_marshal\n", task->tk_pid);
1077
1078        *p++ = htonl(RPC_AUTH_GSS);
        /* Reserve the credential length word; it is back-filled once the
         * credential body has been encoded. */
1079        cred_len = p++;
1080
        /* Take the next sequence number under gc_seq_lock so concurrent
         * requests on the same context never share one. */
1081        spin_lock(&ctx->gc_seq_lock);
1082        req->rq_seqno = ctx->gc_seq++;
1083        spin_unlock(&ctx->gc_seq_lock);
1084
1085        *p++ = htonl((u32) RPC_GSS_VERSION);
1086        *p++ = htonl((u32) ctx->gc_proc);
1087        *p++ = htonl((u32) req->rq_seqno);
1088        *p++ = htonl((u32) gss_cred->gc_service);
1089        p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
        /* Length in bytes of everything encoded after the length word. */
1090        *cred_len = htonl((p - (cred_len + 1)) << 2);
1091
1092        /* We compute the checksum for the verifier over the xdr-encoded bytes
1093         * starting with the xid and ending at the end of the credential: */
1094        iov.iov_base = xprt_skip_transport_header(task->tk_xprt,
1095                                        req->rq_snd_buf.head[0].iov_base);
1096        iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
1097        xdr_buf_from_iov(&iov, &verf_buf);
1098
1099        /* set verifier flavor*/
1100        *p++ = htonl(RPC_AUTH_GSS);
1101
        /* The MIC is written in place, just after its (not yet encoded)
         * length word. */
1102        mic.data = (u8 *)(p + 1);
1103        maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1104        if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
1105                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1106        } else if (maj_stat != 0) {
1107                printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
1108                goto out_put_ctx;
1109        }
        /* mic.data already holds the MIC bytes, so only the length word
         * (and presumably any XDR padding) is emitted here. */
1110        p = xdr_encode_opaque(p, NULL, mic.len);
1111        gss_put_ctx(ctx);
1112        return p;
1113out_put_ctx:
1114        gss_put_ctx(ctx);
1115        return NULL;
1116}
1117
1118static int gss_renew_cred(struct rpc_task *task)
1119{
1120        struct rpc_cred *oldcred = task->tk_rqstp->rq_cred;
1121        struct gss_cred *gss_cred = container_of(oldcred,
1122                                                 struct gss_cred,
1123                                                 gc_base);
1124        struct rpc_auth *auth = oldcred->cr_auth;
1125        struct auth_cred acred = {
1126                .uid = oldcred->cr_uid,
1127                .machine_cred = gss_cred->gc_machine_cred,
1128        };
1129        struct rpc_cred *new;
1130
1131        new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW);
1132        if (IS_ERR(new))
1133                return PTR_ERR(new);
1134        task->tk_rqstp->rq_cred = new;
1135        put_rpccred(oldcred);
1136        return 0;
1137}
1138
1139static int gss_cred_is_negative_entry(struct rpc_cred *cred)
1140{
1141        if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) {
1142                unsigned long now = jiffies;
1143                unsigned long begin, expire;
1144                struct gss_cred *gss_cred; 
1145
1146                gss_cred = container_of(cred, struct gss_cred, gc_base);
1147                begin = gss_cred->gc_upcall_timestamp;
1148                expire = begin + gss_expired_cred_retry_delay * HZ;
1149
1150                if (time_in_range_open(now, begin, expire))
1151                        return 1;
1152        }
1153        return 0;
1154}
1155
1156/*
1157* Refresh credentials. XXX - finish
1158*/
1159static int
1160gss_refresh(struct rpc_task *task)
1161{
1162        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1163        int ret = 0;
1164
1165        if (gss_cred_is_negative_entry(cred))
1166                return -EKEYEXPIRED;
1167
1168        if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
1169                        !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) {
1170                ret = gss_renew_cred(task);
1171                if (ret < 0)
1172                        goto out;
1173                cred = task->tk_rqstp->rq_cred;
1174        }
1175
1176        if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags))
1177                ret = gss_refresh_upcall(task);
1178out:
1179        return ret;
1180}
1181
1182/* Dummy refresh routine: used only when destroying the context */
1183static int
1184gss_refresh_null(struct rpc_task *task)
1185{
1186        return -EACCES;
1187}
1188
1189static __be32 *
1190gss_validate(struct rpc_task *task, __be32 *p)
1191{
1192        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1193        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1194        __be32          seq;
1195        struct kvec     iov;
1196        struct xdr_buf  verf_buf;
1197        struct xdr_netobj mic;
1198        u32             flav,len;
1199        u32             maj_stat;
1200
1201        dprintk("RPC: %5u gss_validate\n", task->tk_pid);
1202
1203        flav = ntohl(*p++);
1204        if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
1205                goto out_bad;
1206        if (flav != RPC_AUTH_GSS)
1207                goto out_bad;
1208        seq = htonl(task->tk_rqstp->rq_seqno);
1209        iov.iov_base = &seq;
1210        iov.iov_len = sizeof(seq);
1211        xdr_buf_from_iov(&iov, &verf_buf);
1212        mic.data = (u8 *)p;
1213        mic.len = len;
1214
1215        maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
1216        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1217                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1218        if (maj_stat) {
1219                dprintk("RPC: %5u gss_validate: gss_verify_mic returned "
1220                                "error 0x%08x\n", task->tk_pid, maj_stat);
1221                goto out_bad;
1222        }
1223        /* We leave it to unwrap to calculate au_rslack. For now we just
1224         * calculate the length of the verifier: */
1225        cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
1226        gss_put_ctx(ctx);
1227        dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n",
1228                        task->tk_pid);
1229        return p + XDR_QUADLEN(len);
1230out_bad:
1231        gss_put_ctx(ctx);
1232        dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid);
1233        return NULL;
1234}
1235
1236static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
1237                                __be32 *p, void *obj)
1238{
1239        struct xdr_stream xdr;
1240
1241        xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p);
1242        encode(rqstp, &xdr, obj);
1243}
1244
/*
 * Wrap a request for RPC_GSS_SVC_INTEGRITY.
 *
 * Layout produced from @p onwards:
 *  - a length word;
 *  - the integrity-protected body: the request sequence number followed by
 *    the caller-encoded arguments;
 *  - a MIC over that body, as an XDR opaque appended at the current end of
 *    the send buffer.
 *
 * Returns 0 on success. It also returns 0 when the context has expired; the
 * cred is then marked not-uptodate. Returns -EIO if the body cannot be
 * described as a sub-buffer or the MIC cannot be computed.
 */
1245static inline int
1246gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1247                   kxdreproc_t encode, struct rpc_rqst *rqstp,
1248                   __be32 *p, void *obj)
1249{
1250        struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
1251        struct xdr_buf  integ_buf;
1252        __be32          *integ_len = NULL;
1253        struct xdr_netobj mic;
1254        u32             offset;
1255        __be32          *q;
1256        struct kvec     *iov;
1257        u32             maj_stat = 0;
1258        int             status = -EIO;
1259
        /* Reserve the length word; the protected body starts right after
         * it, at byte 'offset' of the head. */
1260        integ_len = p++;
1261        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1262        *p++ = htonl(rqstp->rq_seqno);
1263
1264        gss_wrap_req_encode(encode, rqstp, p, obj);
1265
        /* Describe seqno + arguments (offset .. end of buffer) as integ_buf. */
1266        if (xdr_buf_subsegment(snd_buf, &integ_buf,
1267                                offset, snd_buf->len - offset))
1268                return status;
1269        *integ_len = htonl(integ_buf.len);
1270
1271        /* guess whether we're in the head or the tail: */
1272        if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1273                iov = snd_buf->tail;
1274        else
1275                iov = snd_buf->head;
        /* The MIC is appended past iov_len of the chosen kvec. This relies
         * on slack space beyond the current data (see the call_allocate()
         * note in gss_wrap_req_priv); nothing here checks the room. */
1276        p = iov->iov_base + iov->iov_len;
1277        mic.data = (u8 *)(p + 1);
1278
1279        maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1280        status = -EIO; /* XXX? */
1281        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1282                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1283        else if (maj_stat)
1284                return status;
        /* MIC bytes already sit at p + 1: encode just the length (and
         * presumably padding), then grow the kvec and the buffer by the
         * number of bytes added. */
1285        q = xdr_encode_opaque(p, NULL, mic.len);
1286
1287        offset = (u8 *)q - (u8 *)p;
1288        iov->iov_len += offset;
1289        snd_buf->len += offset;
1290        return 0;
1291}
1292
1293static void
1294priv_release_snd_buf(struct rpc_rqst *rqstp)
1295{
1296        int i;
1297
1298        for (i=0; i < rqstp->rq_enc_pages_num; i++)
1299                __free_page(rqstp->rq_enc_pages[i]);
1300        kfree(rqstp->rq_enc_pages);
1301}
1302
1303static int
1304alloc_enc_pages(struct rpc_rqst *rqstp)
1305{
1306        struct xdr_buf *snd_buf = &rqstp->rq_snd_buf;
1307        int first, last, i;
1308
1309        if (snd_buf->page_len == 0) {
1310                rqstp->rq_enc_pages_num = 0;
1311                return 0;
1312        }
1313
1314        first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1315        last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT;
1316        rqstp->rq_enc_pages_num = last - first + 1 + 1;
1317        rqstp->rq_enc_pages
1318                = kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *),
1319                                GFP_NOFS);
1320        if (!rqstp->rq_enc_pages)
1321                goto out;
1322        for (i=0; i < rqstp->rq_enc_pages_num; i++) {
1323                rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS);
1324                if (rqstp->rq_enc_pages[i] == NULL)
1325                        goto out_free;
1326        }
1327        rqstp->rq_release_snd_buf = priv_release_snd_buf;
1328        return 0;
1329out_free:
1330        rqstp->rq_enc_pages_num = i;
1331        priv_release_snd_buf(rqstp);
1332out:
1333        return -EAGAIN;
1334}
1335
/*
 * Wrap a request for RPC_GSS_SVC_PRIVACY.
 *
 * Steps:
 *  - write a length word at @p, then the request sequence number;
 *  - let the caller encode the arguments;
 *  - switch snd_buf->pages to freshly allocated rq_enc_pages;
 *  - call gss_wrap() on the data from the sequence number onwards, passing
 *    the original page data as @inpages;
 *  - fill in the length word and zero-pad the result to a 4-byte boundary.
 *
 * Returns 0 on success. It also returns 0 when the context has expired; the
 * cred is then marked not-uptodate. Returns -EAGAIN if the encryption pages
 * cannot be allocated, or -EIO if gss_wrap() fails.
 */
1336static inline int
1337gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1338                  kxdreproc_t encode, struct rpc_rqst *rqstp,
1339                  __be32 *p, void *obj)
1340{
1341        struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
1342        u32             offset;
1343        u32             maj_stat;
1344        int             status;
1345        __be32          *opaque_len;
1346        struct page     **inpages;
1347        int             first;
1348        int             pad;
1349        struct kvec     *iov;
1350        char            *tmp;
1351
        /* Reserve the length word; wrapping starts at byte 'offset' of the
         * head, i.e. at the sequence number. */
1352        opaque_len = p++;
1353        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
1354        *p++ = htonl(rqstp->rq_seqno);
1355
1356        gss_wrap_req_encode(encode, rqstp, p, obj);
1357
1358        status = alloc_enc_pages(rqstp);
1359        if (status)
1360                return status;
        /* Keep the original pages (from the first one in use) as the source
         * for gss_wrap(), and point snd_buf at the private encryption pages.
         * page_base becomes an offset within the first page. */
1361        first = snd_buf->page_base >> PAGE_CACHE_SHIFT;
1362        inpages = snd_buf->pages + first;
1363        snd_buf->pages = rqstp->rq_enc_pages;
1364        snd_buf->page_base -= first << PAGE_CACHE_SHIFT;
1365        /*
1366         * Give the tail its own page, in case we need extra space in the
1367         * head when wrapping:
1368         *
1369         * call_allocate() allocates twice the slack space required
1370         * by the authentication flavor to rq_callsize.
1371         * For GSS, slack is GSS_CRED_SLACK.
1372         */
        /* NOTE(review): if page_len == 0 but the tail is non-empty,
         * alloc_enc_pages() allocated nothing (rq_enc_pages_num == 0), so
         * the index below would be -1. This presumably never happens in
         * practice; confirm callers never build such a buffer. */
1373        if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
1374                tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
1375                memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
1376                snd_buf->tail[0].iov_base = tmp;
1377        }
1378        maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
1379        /* slack space should prevent this ever happening: */
1380        BUG_ON(snd_buf->len > snd_buf->buflen);
1381        status = -EIO;
1382        /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
1383         * done anyway, so it's safe to put the request on the wire: */
1384        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1385                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1386        else if (maj_stat)
1387                return status;
1388
1389        *opaque_len = htonl(snd_buf->len - offset);
1390        /* guess whether we're in the head or the tail: */
1391        if (snd_buf->page_len || snd_buf->tail[0].iov_len)
1392                iov = snd_buf->tail;
1393        else
1394                iov = snd_buf->head;
1395        p = iov->iov_base + iov->iov_len;
        /* Bytes needed to round the wrapped length (len - offset) up to a
         * multiple of 4; 0 when already aligned. */
1396        pad = 3 - ((snd_buf->len - offset - 1) & 3);
1397        memset(p, 0, pad);
1398        iov->iov_len += pad;
1399        snd_buf->len += pad;
1400
1401        return 0;
1402}
1403
1404static int
1405gss_wrap_req(struct rpc_task *task,
1406             kxdreproc_t encode, void *rqstp, __be32 *p, void *obj)
1407{
1408        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1409        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1410                        gc_base);
1411        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
1412        int             status = -EIO;
1413
1414        dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid);
1415        if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
1416                /* The spec seems a little ambiguous here, but I think that not
1417                 * wrapping context destruction requests makes the most sense.
1418                 */
1419                gss_wrap_req_encode(encode, rqstp, p, obj);
1420                status = 0;
1421                goto out;
1422        }
1423        switch (gss_cred->gc_service) {
1424                case RPC_GSS_SVC_NONE:
1425                        gss_wrap_req_encode(encode, rqstp, p, obj);
1426                        status = 0;
1427                        break;
1428                case RPC_GSS_SVC_INTEGRITY:
1429                        status = gss_wrap_req_integ(cred, ctx, encode,
1430                                                                rqstp, p, obj);
1431                        break;
1432                case RPC_GSS_SVC_PRIVACY:
1433                        status = gss_wrap_req_priv(cred, ctx, encode,
1434                                        rqstp, p, obj);
1435                        break;
1436        }
1437out:
1438        gss_put_ctx(ctx);
1439        dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status);
1440        return status;
1441}
1442
1443static inline int
1444gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1445                struct rpc_rqst *rqstp, __be32 **p)
1446{
1447        struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1448        struct xdr_buf integ_buf;
1449        struct xdr_netobj mic;
1450        u32 data_offset, mic_offset;
1451        u32 integ_len;
1452        u32 maj_stat;
1453        int status = -EIO;
1454
1455        integ_len = ntohl(*(*p)++);
1456        if (integ_len & 3)
1457                return status;
1458        data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1459        mic_offset = integ_len + data_offset;
1460        if (mic_offset > rcv_buf->len)
1461                return status;
1462        if (ntohl(*(*p)++) != rqstp->rq_seqno)
1463                return status;
1464
1465        if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
1466                                mic_offset - data_offset))
1467                return status;
1468
1469        if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
1470                return status;
1471
1472        maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
1473        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1474                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1475        if (maj_stat != GSS_S_COMPLETE)
1476                return status;
1477        return 0;
1478}
1479
1480static inline int
1481gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
1482                struct rpc_rqst *rqstp, __be32 **p)
1483{
1484        struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
1485        u32 offset;
1486        u32 opaque_len;
1487        u32 maj_stat;
1488        int status = -EIO;
1489
1490        opaque_len = ntohl(*(*p)++);
1491        offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
1492        if (offset + opaque_len > rcv_buf->len)
1493                return status;
1494        /* remove padding: */
1495        rcv_buf->len = offset + opaque_len;
1496
1497        maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
1498        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
1499                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
1500        if (maj_stat != GSS_S_COMPLETE)
1501                return status;
1502        if (ntohl(*(*p)++) != rqstp->rq_seqno)
1503                return status;
1504
1505        return 0;
1506}
1507
1508static int
1509gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
1510                      __be32 *p, void *obj)
1511{
1512        struct xdr_stream xdr;
1513
1514        xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p);
1515        return decode(rqstp, &xdr, obj);
1516}
1517
/*
 * crunwrap_resp hook: undo the cred's GSS service protection on a reply,
 * update the auth's au_rslack, then run the caller's decoder.
 *
 * Per service: nothing for RPC_GSS_SVC_NONE, a MIC check for integrity,
 * decryption for privacy. Replies to non-DATA (control) requests are
 * decoded directly, without unwrapping and without touching au_rslack.
 *
 * Returns the decoder's result, or the unwrap error (-EIO) if unwrapping
 * failed.
 */
1518static int
1519gss_unwrap_resp(struct rpc_task *task,
1520                kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj)
1521{
1522        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
1523        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
1524                        gc_base);
1525        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
        /* Snapshot position and head length so the space consumed by the
         * unwrap step can be measured afterwards. */
1526        __be32          *savedp = p;
1527        struct kvec     *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
1528        int             savedlen = head->iov_len;
1529        int             status = -EIO;
1530
1531        if (ctx->gc_proc != RPC_GSS_PROC_DATA)
1532                goto out_decode;
1533        switch (gss_cred->gc_service) {
1534                case RPC_GSS_SVC_NONE:
1535                        break;
1536                case RPC_GSS_SVC_INTEGRITY:
1537                        status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
1538                        if (status)
1539                                goto out;
1540                        break;
1541                case RPC_GSS_SVC_PRIVACY:
1542                        status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
1543                        if (status)
1544                                goto out;
1545                        break;
1546        }
1547        /* take into account extra slack for integrity and privacy cases: */
        /* NOTE(review): au_verfsize and (p - savedp) count 32-bit words,
         * while (savedlen - head->iov_len) is a byte count. The sum looks
         * like an overestimate (harmless for slack); confirm the intended
         * units. */
1548        cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
1549                                                + (savedlen - head->iov_len);
1550out_decode:
1551        status = gss_unwrap_req_decode(decode, rqstp, p, obj);
1552out:
1553        gss_put_ctx(ctx);
1554        dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
1555                        status);
1556        return status;
1557}
1558
1559static const struct rpc_authops authgss_ops = {
1560        .owner          = THIS_MODULE,
1561        .au_flavor      = RPC_AUTH_GSS,
1562        .au_name        = "RPCSEC_GSS",
1563        .create         = gss_create,
1564        .destroy        = gss_destroy,
1565        .lookup_cred    = gss_lookup_cred,
1566        .crcreate       = gss_create_cred
1567};
1568
1569static const struct rpc_credops gss_credops = {
1570        .cr_name        = "AUTH_GSS",
1571        .crdestroy      = gss_destroy_cred,
1572        .cr_init        = gss_cred_init,
1573        .crbind         = rpcauth_generic_bind_cred,
1574        .crmatch        = gss_match,
1575        .crmarshal      = gss_marshal,
1576        .crrefresh      = gss_refresh,
1577        .crvalidate     = gss_validate,
1578        .crwrap_req     = gss_wrap_req,
1579        .crunwrap_resp  = gss_unwrap_resp,
1580};
1581
1582static const struct rpc_credops gss_nullops = {
1583        .cr_name        = "AUTH_GSS",
1584        .crdestroy      = gss_destroy_nullcred,
1585        .crbind         = rpcauth_generic_bind_cred,
1586        .crmatch        = gss_match,
1587        .crmarshal      = gss_marshal,
1588        .crrefresh      = gss_refresh_null,
1589        .crvalidate     = gss_validate,
1590        .crwrap_req     = gss_wrap_req,
1591        .crunwrap_resp  = gss_unwrap_resp,
1592};
1593
1594static const struct rpc_pipe_ops gss_upcall_ops_v0 = {
1595        .upcall         = gss_pipe_upcall,
1596        .downcall       = gss_pipe_downcall,
1597        .destroy_msg    = gss_pipe_destroy_msg,
1598        .open_pipe      = gss_pipe_open_v0,
1599        .release_pipe   = gss_pipe_release,
1600};
1601
1602static const struct rpc_pipe_ops gss_upcall_ops_v1 = {
1603        .upcall         = gss_pipe_upcall,
1604        .downcall       = gss_pipe_downcall,
1605        .destroy_msg    = gss_pipe_destroy_msg,
1606        .open_pipe      = gss_pipe_open_v1,
1607        .release_pipe   = gss_pipe_release,
1608};
1609
1610/*
1611 * Initialize RPCSEC_GSS module
1612 */
1613static int __init init_rpcsec_gss(void)
1614{
1615        int err = 0;
1616
1617        err = rpcauth_register(&authgss_ops);
1618        if (err)
1619                goto out;
1620        err = gss_svc_init();
1621        if (err)
1622                goto out_unregister;
1623        rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
1624        return 0;
1625out_unregister:
1626        rpcauth_unregister(&authgss_ops);
1627out:
1628        return err;
1629}
1630
1631static void __exit exit_rpcsec_gss(void)
1632{
1633        gss_svc_shutdown();
1634        rpcauth_unregister(&authgss_ops);
1635        rcu_barrier(); /* Wait for completion of call_rcu()'s */
1636}
1637
1638MODULE_LICENSE("GPL");
1639module_param_named(expired_cred_retry_delay,
1640                   gss_expired_cred_retry_delay,
1641                   uint, 0644);
1642MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until "
1643                "the RPC engine retries an expired credential");
1644
1645module_init(init_rpcsec_gss)
1646module_exit(exit_rpcsec_gss)
1647