linux/net/sunrpc/auth.c
<<
>>
Prefs
   1/*
   2 * linux/net/sunrpc/auth.c
   3 *
   4 * Generic RPC client authentication API.
   5 *
   6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
   7 */
   8
   9#include <linux/types.h>
  10#include <linux/sched.h>
  11#include <linux/module.h>
  12#include <linux/slab.h>
  13#include <linux/errno.h>
  14#include <linux/hash.h>
  15#include <linux/sunrpc/clnt.h>
  16#include <linux/sunrpc/gss_api.h>
  17#include <linux/spinlock.h>
  18
  19#ifdef RPC_DEBUG
  20# define RPCDBG_FACILITY        RPCDBG_AUTH
  21#endif
  22
  23#define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)
  24struct rpc_cred_cache {
  25        struct hlist_head       *hashtable;
  26        unsigned int            hashbits;
  27        spinlock_t              lock;
  28};
  29
  30static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
  31
  32static DEFINE_SPINLOCK(rpc_authflavor_lock);
  33static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
  34        &authnull_ops,          /* AUTH_NULL */
  35        &authunix_ops,          /* AUTH_UNIX */
  36        NULL,                   /* others can be loadable modules */
  37};
  38
  39static LIST_HEAD(cred_unused);
  40static unsigned long number_cred_unused;
  41
  42#define MAX_HASHTABLE_BITS (14)
  43static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  44{
  45        unsigned long num;
  46        unsigned int nbits;
  47        int ret;
  48
  49        if (!val)
  50                goto out_inval;
  51        ret = strict_strtoul(val, 0, &num);
  52        if (ret == -EINVAL)
  53                goto out_inval;
  54        nbits = fls(num);
  55        if (num > (1U << nbits))
  56                nbits++;
  57        if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  58                goto out_inval;
  59        *(unsigned int *)kp->arg = nbits;
  60        return 0;
  61out_inval:
  62        return -EINVAL;
  63}
  64
  65static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  66{
  67        unsigned int nbits;
  68
  69        nbits = *(unsigned int *)kp->arg;
  70        return sprintf(buffer, "%u", 1U << nbits);
  71}
  72
  73#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  74
  75static struct kernel_param_ops param_ops_hashtbl_sz = {
  76        .set = param_set_hashtbl_sz,
  77        .get = param_get_hashtbl_sz,
  78};
  79
  80module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  81MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  82
  83static u32
  84pseudoflavor_to_flavor(u32 flavor) {
  85        if (flavor > RPC_AUTH_MAXFLAVOR)
  86                return RPC_AUTH_GSS;
  87        return flavor;
  88}
  89
  90int
  91rpcauth_register(const struct rpc_authops *ops)
  92{
  93        rpc_authflavor_t flavor;
  94        int ret = -EPERM;
  95
  96        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
  97                return -EINVAL;
  98        spin_lock(&rpc_authflavor_lock);
  99        if (auth_flavors[flavor] == NULL) {
 100                auth_flavors[flavor] = ops;
 101                ret = 0;
 102        }
 103        spin_unlock(&rpc_authflavor_lock);
 104        return ret;
 105}
 106EXPORT_SYMBOL_GPL(rpcauth_register);
 107
 108int
 109rpcauth_unregister(const struct rpc_authops *ops)
 110{
 111        rpc_authflavor_t flavor;
 112        int ret = -EPERM;
 113
 114        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 115                return -EINVAL;
 116        spin_lock(&rpc_authflavor_lock);
 117        if (auth_flavors[flavor] == ops) {
 118                auth_flavors[flavor] = NULL;
 119                ret = 0;
 120        }
 121        spin_unlock(&rpc_authflavor_lock);
 122        return ret;
 123}
 124EXPORT_SYMBOL_GPL(rpcauth_unregister);
 125
 126/**
 127 * rpcauth_get_pseudoflavor - check if security flavor is supported
 128 * @flavor: a security flavor
 129 * @info: a GSS mech OID, quality of protection, and service value
 130 *
 131 * Verifies that an appropriate kernel module is available or already loaded.
 132 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 133 * not supported locally.
 134 */
 135rpc_authflavor_t
 136rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
 137{
 138        const struct rpc_authops *ops;
 139        rpc_authflavor_t pseudoflavor;
 140
 141        ops = auth_flavors[flavor];
 142        if (ops == NULL)
 143                request_module("rpc-auth-%u", flavor);
 144        spin_lock(&rpc_authflavor_lock);
 145        ops = auth_flavors[flavor];
 146        if (ops == NULL || !try_module_get(ops->owner)) {
 147                spin_unlock(&rpc_authflavor_lock);
 148                return RPC_AUTH_MAXFLAVOR;
 149        }
 150        spin_unlock(&rpc_authflavor_lock);
 151
 152        pseudoflavor = flavor;
 153        if (ops->info2flavor != NULL)
 154                pseudoflavor = ops->info2flavor(info);
 155
 156        module_put(ops->owner);
 157        return pseudoflavor;
 158}
 159EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
 160
 161/**
 162 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 163 * @pseudoflavor: GSS pseudoflavor to match
 164 * @info: rpcsec_gss_info structure to fill in
 165 *
 166 * Returns zero and fills in "info" if pseudoflavor matches a
 167 * supported mechanism.
 168 */
 169int
 170rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
 171{
 172        rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
 173        const struct rpc_authops *ops;
 174        int result;
 175
 176        if (flavor >= RPC_AUTH_MAXFLAVOR)
 177                return -EINVAL;
 178
 179        ops = auth_flavors[flavor];
 180        if (ops == NULL)
 181                request_module("rpc-auth-%u", flavor);
 182        spin_lock(&rpc_authflavor_lock);
 183        ops = auth_flavors[flavor];
 184        if (ops == NULL || !try_module_get(ops->owner)) {
 185                spin_unlock(&rpc_authflavor_lock);
 186                return -ENOENT;
 187        }
 188        spin_unlock(&rpc_authflavor_lock);
 189
 190        result = -ENOENT;
 191        if (ops->flavor2info != NULL)
 192                result = ops->flavor2info(pseudoflavor, info);
 193
 194        module_put(ops->owner);
 195        return result;
 196}
 197EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
 198
 199/**
 200 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 201 * @array: array to fill in
 202 * @size: size of "array"
 203 *
 204 * Returns the number of array items filled in, or a negative errno.
 205 *
 206 * The returned array is not sorted by any policy.  Callers should not
 207 * rely on the order of the items in the returned array.
 208 */
 209int
 210rpcauth_list_flavors(rpc_authflavor_t *array, int size)
 211{
 212        rpc_authflavor_t flavor;
 213        int result = 0;
 214
 215        spin_lock(&rpc_authflavor_lock);
 216        for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
 217                const struct rpc_authops *ops = auth_flavors[flavor];
 218                rpc_authflavor_t pseudos[4];
 219                int i, len;
 220
 221                if (result >= size) {
 222                        result = -ENOMEM;
 223                        break;
 224                }
 225
 226                if (ops == NULL)
 227                        continue;
 228                if (ops->list_pseudoflavors == NULL) {
 229                        array[result++] = ops->au_flavor;
 230                        continue;
 231                }
 232                len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
 233                if (len < 0) {
 234                        result = len;
 235                        break;
 236                }
 237                for (i = 0; i < len; i++) {
 238                        if (result >= size) {
 239                                result = -ENOMEM;
 240                                break;
 241                        }
 242                        array[result++] = pseudos[i];
 243                }
 244        }
 245        spin_unlock(&rpc_authflavor_lock);
 246
 247        dprintk("RPC:       %s returns %d\n", __func__, result);
 248        return result;
 249}
 250EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
 251
 252struct rpc_auth *
 253rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt)
 254{
 255        struct rpc_auth         *auth;
 256        const struct rpc_authops *ops;
 257        u32                     flavor = pseudoflavor_to_flavor(pseudoflavor);
 258
 259        auth = ERR_PTR(-EINVAL);
 260        if (flavor >= RPC_AUTH_MAXFLAVOR)
 261                goto out;
 262
 263        if ((ops = auth_flavors[flavor]) == NULL)
 264                request_module("rpc-auth-%u", flavor);
 265        spin_lock(&rpc_authflavor_lock);
 266        ops = auth_flavors[flavor];
 267        if (ops == NULL || !try_module_get(ops->owner)) {
 268                spin_unlock(&rpc_authflavor_lock);
 269                goto out;
 270        }
 271        spin_unlock(&rpc_authflavor_lock);
 272        auth = ops->create(clnt, pseudoflavor);
 273        module_put(ops->owner);
 274        if (IS_ERR(auth))
 275                return auth;
 276        if (clnt->cl_auth)
 277                rpcauth_release(clnt->cl_auth);
 278        clnt->cl_auth = auth;
 279
 280out:
 281        return auth;
 282}
 283EXPORT_SYMBOL_GPL(rpcauth_create);
 284
 285void
 286rpcauth_release(struct rpc_auth *auth)
 287{
 288        if (!atomic_dec_and_test(&auth->au_count))
 289                return;
 290        auth->au_ops->destroy(auth);
 291}
 292
 293static DEFINE_SPINLOCK(rpc_credcache_lock);
 294
 295static void
 296rpcauth_unhash_cred_locked(struct rpc_cred *cred)
 297{
 298        hlist_del_rcu(&cred->cr_hash);
 299        smp_mb__before_clear_bit();
 300        clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
 301}
 302
 303static int
 304rpcauth_unhash_cred(struct rpc_cred *cred)
 305{
 306        spinlock_t *cache_lock;
 307        int ret;
 308
 309        cache_lock = &cred->cr_auth->au_credcache->lock;
 310        spin_lock(cache_lock);
 311        ret = atomic_read(&cred->cr_count) == 0;
 312        if (ret)
 313                rpcauth_unhash_cred_locked(cred);
 314        spin_unlock(cache_lock);
 315        return ret;
 316}
 317
 318/*
 319 * Initialize RPC credential cache
 320 */
 321int
 322rpcauth_init_credcache(struct rpc_auth *auth)
 323{
 324        struct rpc_cred_cache *new;
 325        unsigned int hashsize;
 326
 327        new = kmalloc(sizeof(*new), GFP_KERNEL);
 328        if (!new)
 329                goto out_nocache;
 330        new->hashbits = auth_hashbits;
 331        hashsize = 1U << new->hashbits;
 332        new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
 333        if (!new->hashtable)
 334                goto out_nohashtbl;
 335        spin_lock_init(&new->lock);
 336        auth->au_credcache = new;
 337        return 0;
 338out_nohashtbl:
 339        kfree(new);
 340out_nocache:
 341        return -ENOMEM;
 342}
 343EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 344
 345/*
 346 * Destroy a list of credentials
 347 */
 348static inline
 349void rpcauth_destroy_credlist(struct list_head *head)
 350{
 351        struct rpc_cred *cred;
 352
 353        while (!list_empty(head)) {
 354                cred = list_entry(head->next, struct rpc_cred, cr_lru);
 355                list_del_init(&cred->cr_lru);
 356                put_rpccred(cred);
 357        }
 358}
 359
 360/*
 361 * Clear the RPC credential cache, and delete those credentials
 362 * that are not referenced.
 363 */
 364void
 365rpcauth_clear_credcache(struct rpc_cred_cache *cache)
 366{
 367        LIST_HEAD(free);
 368        struct hlist_head *head;
 369        struct rpc_cred *cred;
 370        unsigned int hashsize = 1U << cache->hashbits;
 371        int             i;
 372
 373        spin_lock(&rpc_credcache_lock);
 374        spin_lock(&cache->lock);
 375        for (i = 0; i < hashsize; i++) {
 376                head = &cache->hashtable[i];
 377                while (!hlist_empty(head)) {
 378                        cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
 379                        get_rpccred(cred);
 380                        if (!list_empty(&cred->cr_lru)) {
 381                                list_del(&cred->cr_lru);
 382                                number_cred_unused--;
 383                        }
 384                        list_add_tail(&cred->cr_lru, &free);
 385                        rpcauth_unhash_cred_locked(cred);
 386                }
 387        }
 388        spin_unlock(&cache->lock);
 389        spin_unlock(&rpc_credcache_lock);
 390        rpcauth_destroy_credlist(&free);
 391}
 392
 393/*
 394 * Destroy the RPC credential cache
 395 */
 396void
 397rpcauth_destroy_credcache(struct rpc_auth *auth)
 398{
 399        struct rpc_cred_cache *cache = auth->au_credcache;
 400
 401        if (cache) {
 402                auth->au_credcache = NULL;
 403                rpcauth_clear_credcache(cache);
 404                kfree(cache->hashtable);
 405                kfree(cache);
 406        }
 407}
 408EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
 409
 410
 411#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
 412
 413/*
 414 * Remove stale credentials. Avoid sleeping inside the loop.
 415 */
 416static int
 417rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
 418{
 419        spinlock_t *cache_lock;
 420        struct rpc_cred *cred, *next;
 421        unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
 422
 423        list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
 424
 425                if (nr_to_scan-- == 0)
 426                        break;
 427                /*
 428                 * Enforce a 60 second garbage collection moratorium
 429                 * Note that the cred_unused list must be time-ordered.
 430                 */
 431                if (time_in_range(cred->cr_expire, expired, jiffies) &&
 432                    test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0)
 433                        return 0;
 434
 435                list_del_init(&cred->cr_lru);
 436                number_cred_unused--;
 437                if (atomic_read(&cred->cr_count) != 0)
 438                        continue;
 439
 440                cache_lock = &cred->cr_auth->au_credcache->lock;
 441                spin_lock(cache_lock);
 442                if (atomic_read(&cred->cr_count) == 0) {
 443                        get_rpccred(cred);
 444                        list_add_tail(&cred->cr_lru, free);
 445                        rpcauth_unhash_cred_locked(cred);
 446                }
 447                spin_unlock(cache_lock);
 448        }
 449        return (number_cred_unused / 100) * sysctl_vfs_cache_pressure;
 450}
 451
 452/*
 453 * Run memory cache shrinker.
 454 */
 455static int
 456rpcauth_cache_shrinker(struct shrinker *shrink, struct shrink_control *sc)
 457{
 458        LIST_HEAD(free);
 459        int res;
 460        int nr_to_scan = sc->nr_to_scan;
 461        gfp_t gfp_mask = sc->gfp_mask;
 462
 463        if ((gfp_mask & GFP_KERNEL) != GFP_KERNEL)
 464                return (nr_to_scan == 0) ? 0 : -1;
 465        if (list_empty(&cred_unused))
 466                return 0;
 467        spin_lock(&rpc_credcache_lock);
 468        res = rpcauth_prune_expired(&free, nr_to_scan);
 469        spin_unlock(&rpc_credcache_lock);
 470        rpcauth_destroy_credlist(&free);
 471        return res;
 472}
 473
 474/*
 475 * Look up a process' credentials in the authentication cache
 476 */
 477struct rpc_cred *
 478rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
 479                int flags)
 480{
 481        LIST_HEAD(free);
 482        struct rpc_cred_cache *cache = auth->au_credcache;
 483        struct rpc_cred *cred = NULL,
 484                        *entry, *new;
 485        unsigned int nr;
 486
 487        nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits);
 488
 489        rcu_read_lock();
 490        hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
 491                if (!entry->cr_ops->crmatch(acred, entry, flags))
 492                        continue;
 493                spin_lock(&cache->lock);
 494                if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) {
 495                        spin_unlock(&cache->lock);
 496                        continue;
 497                }
 498                cred = get_rpccred(entry);
 499                spin_unlock(&cache->lock);
 500                break;
 501        }
 502        rcu_read_unlock();
 503
 504        if (cred != NULL)
 505                goto found;
 506
 507        new = auth->au_ops->crcreate(auth, acred, flags);
 508        if (IS_ERR(new)) {
 509                cred = new;
 510                goto out;
 511        }
 512
 513        spin_lock(&cache->lock);
 514        hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
 515                if (!entry->cr_ops->crmatch(acred, entry, flags))
 516                        continue;
 517                cred = get_rpccred(entry);
 518                break;
 519        }
 520        if (cred == NULL) {
 521                cred = new;
 522                set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
 523                hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
 524        } else
 525                list_add_tail(&new->cr_lru, &free);
 526        spin_unlock(&cache->lock);
 527found:
 528        if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
 529            cred->cr_ops->cr_init != NULL &&
 530            !(flags & RPCAUTH_LOOKUP_NEW)) {
 531                int res = cred->cr_ops->cr_init(auth, cred);
 532                if (res < 0) {
 533                        put_rpccred(cred);
 534                        cred = ERR_PTR(res);
 535                }
 536        }
 537        rpcauth_destroy_credlist(&free);
 538out:
 539        return cred;
 540}
 541EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
 542
 543struct rpc_cred *
 544rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 545{
 546        struct auth_cred acred;
 547        struct rpc_cred *ret;
 548        const struct cred *cred = current_cred();
 549
 550        dprintk("RPC:       looking up %s cred\n",
 551                auth->au_ops->au_name);
 552
 553        memset(&acred, 0, sizeof(acred));
 554        acred.uid = cred->fsuid;
 555        acred.gid = cred->fsgid;
 556        acred.group_info = get_group_info(((struct cred *)cred)->group_info);
 557
 558        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
 559        put_group_info(acred.group_info);
 560        return ret;
 561}
 562
 563void
 564rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 565                  struct rpc_auth *auth, const struct rpc_credops *ops)
 566{
 567        INIT_HLIST_NODE(&cred->cr_hash);
 568        INIT_LIST_HEAD(&cred->cr_lru);
 569        atomic_set(&cred->cr_count, 1);
 570        cred->cr_auth = auth;
 571        cred->cr_ops = ops;
 572        cred->cr_expire = jiffies;
 573#ifdef RPC_DEBUG
 574        cred->cr_magic = RPCAUTH_CRED_MAGIC;
 575#endif
 576        cred->cr_uid = acred->uid;
 577}
 578EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 579
 580struct rpc_cred *
 581rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
 582{
 583        dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
 584                        cred->cr_auth->au_ops->au_name, cred);
 585        return get_rpccred(cred);
 586}
 587EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
 588
 589static struct rpc_cred *
 590rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 591{
 592        struct rpc_auth *auth = task->tk_client->cl_auth;
 593        struct auth_cred acred = {
 594                .uid = GLOBAL_ROOT_UID,
 595                .gid = GLOBAL_ROOT_GID,
 596        };
 597
 598        dprintk("RPC: %5u looking up %s cred\n",
 599                task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
 600        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 601}
 602
 603static struct rpc_cred *
 604rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 605{
 606        struct rpc_auth *auth = task->tk_client->cl_auth;
 607
 608        dprintk("RPC: %5u looking up %s cred\n",
 609                task->tk_pid, auth->au_ops->au_name);
 610        return rpcauth_lookupcred(auth, lookupflags);
 611}
 612
 613static int
 614rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
 615{
 616        struct rpc_rqst *req = task->tk_rqstp;
 617        struct rpc_cred *new;
 618        int lookupflags = 0;
 619
 620        if (flags & RPC_TASK_ASYNC)
 621                lookupflags |= RPCAUTH_LOOKUP_NEW;
 622        if (cred != NULL)
 623                new = cred->cr_ops->crbind(task, cred, lookupflags);
 624        else if (flags & RPC_TASK_ROOTCREDS)
 625                new = rpcauth_bind_root_cred(task, lookupflags);
 626        else
 627                new = rpcauth_bind_new_cred(task, lookupflags);
 628        if (IS_ERR(new))
 629                return PTR_ERR(new);
 630        if (req->rq_cred != NULL)
 631                put_rpccred(req->rq_cred);
 632        req->rq_cred = new;
 633        return 0;
 634}
 635
 636void
 637put_rpccred(struct rpc_cred *cred)
 638{
 639        /* Fast path for unhashed credentials */
 640        if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) {
 641                if (atomic_dec_and_test(&cred->cr_count))
 642                        cred->cr_ops->crdestroy(cred);
 643                return;
 644        }
 645
 646        if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
 647                return;
 648        if (!list_empty(&cred->cr_lru)) {
 649                number_cred_unused--;
 650                list_del_init(&cred->cr_lru);
 651        }
 652        if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) {
 653                if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
 654                        cred->cr_expire = jiffies;
 655                        list_add_tail(&cred->cr_lru, &cred_unused);
 656                        number_cred_unused++;
 657                        goto out_nodestroy;
 658                }
 659                if (!rpcauth_unhash_cred(cred)) {
 660                        /* We were hashed and someone looked us up... */
 661                        goto out_nodestroy;
 662                }
 663        }
 664        spin_unlock(&rpc_credcache_lock);
 665        cred->cr_ops->crdestroy(cred);
 666        return;
 667out_nodestroy:
 668        spin_unlock(&rpc_credcache_lock);
 669}
 670EXPORT_SYMBOL_GPL(put_rpccred);
 671
 672__be32 *
 673rpcauth_marshcred(struct rpc_task *task, __be32 *p)
 674{
 675        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 676
 677        dprintk("RPC: %5u marshaling %s cred %p\n",
 678                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 679
 680        return cred->cr_ops->crmarshal(task, p);
 681}
 682
 683__be32 *
 684rpcauth_checkverf(struct rpc_task *task, __be32 *p)
 685{
 686        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 687
 688        dprintk("RPC: %5u validating %s cred %p\n",
 689                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 690
 691        return cred->cr_ops->crvalidate(task, p);
 692}
 693
 694static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
 695                                   __be32 *data, void *obj)
 696{
 697        struct xdr_stream xdr;
 698
 699        xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
 700        encode(rqstp, &xdr, obj);
 701}
 702
 703int
 704rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
 705                __be32 *data, void *obj)
 706{
 707        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 708
 709        dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
 710                        task->tk_pid, cred->cr_ops->cr_name, cred);
 711        if (cred->cr_ops->crwrap_req)
 712                return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
 713        /* By default, we encode the arguments normally. */
 714        rpcauth_wrap_req_encode(encode, rqstp, data, obj);
 715        return 0;
 716}
 717
 718static int
 719rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
 720                          __be32 *data, void *obj)
 721{
 722        struct xdr_stream xdr;
 723
 724        xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
 725        return decode(rqstp, &xdr, obj);
 726}
 727
 728int
 729rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
 730                __be32 *data, void *obj)
 731{
 732        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 733
 734        dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
 735                        task->tk_pid, cred->cr_ops->cr_name, cred);
 736        if (cred->cr_ops->crunwrap_resp)
 737                return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
 738                                                   data, obj);
 739        /* By default, we decode the arguments normally. */
 740        return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
 741}
 742
 743int
 744rpcauth_refreshcred(struct rpc_task *task)
 745{
 746        struct rpc_cred *cred;
 747        int err;
 748
 749        cred = task->tk_rqstp->rq_cred;
 750        if (cred == NULL) {
 751                err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
 752                if (err < 0)
 753                        goto out;
 754                cred = task->tk_rqstp->rq_cred;
 755        }
 756        dprintk("RPC: %5u refreshing %s cred %p\n",
 757                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 758
 759        err = cred->cr_ops->crrefresh(task);
 760out:
 761        if (err < 0)
 762                task->tk_status = err;
 763        return err;
 764}
 765
 766void
 767rpcauth_invalcred(struct rpc_task *task)
 768{
 769        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 770
 771        dprintk("RPC: %5u invalidating %s cred %p\n",
 772                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 773        if (cred)
 774                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 775}
 776
 777int
 778rpcauth_uptodatecred(struct rpc_task *task)
 779{
 780        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 781
 782        return cred == NULL ||
 783                test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
 784}
 785
 786static struct shrinker rpc_cred_shrinker = {
 787        .shrink = rpcauth_cache_shrinker,
 788        .seeks = DEFAULT_SEEKS,
 789};
 790
 791int __init rpcauth_init_module(void)
 792{
 793        int err;
 794
 795        err = rpc_init_authunix();
 796        if (err < 0)
 797                goto out1;
 798        err = rpc_init_generic_auth();
 799        if (err < 0)
 800                goto out2;
 801        register_shrinker(&rpc_cred_shrinker);
 802        return 0;
 803out2:
 804        rpc_destroy_authunix();
 805out1:
 806        return err;
 807}
 808
 809void rpcauth_remove_module(void)
 810{
 811        rpc_destroy_authunix();
 812        rpc_destroy_generic_auth();
 813        unregister_shrinker(&rpc_cred_shrinker);
 814}
 815