linux/net/sunrpc/auth.c
<<
>>
Prefs
   1/*
   2 * linux/net/sunrpc/auth.c
   3 *
   4 * Generic RPC client authentication API.
   5 *
   6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
   7 */
   8
   9#include <linux/types.h>
  10#include <linux/sched.h>
  11#include <linux/module.h>
  12#include <linux/slab.h>
  13#include <linux/errno.h>
  14#include <linux/hash.h>
  15#include <linux/sunrpc/clnt.h>
  16#include <linux/sunrpc/gss_api.h>
  17#include <linux/spinlock.h>
  18
#ifdef RPC_DEBUG
/* Debug facility for this file's dprintk() calls (presumably; see sunrpc debug.h). */
# define RPCDBG_FACILITY	RPCDBG_AUTH
#endif

/* Default log2 of the number of hash buckets in a new credential cache. */
#define RPC_CREDCACHE_DEFAULT_HASHBITS	(4)
/*
 * Per-auth credential cache: an array of (1 << hashbits) hlist chains of
 * struct rpc_cred.  Chains are modified under ->lock; rpcauth_lookup_credcache
 * additionally walks them under RCU.
 */
struct rpc_cred_cache {
	struct hlist_head	*hashtable;
	unsigned int		hashbits;
	spinlock_t		lock;
};

/*
 * log2 of the bucket count used by rpcauth_init_credcache(); settable via
 * the auth_hashtable_size module parameter.
 */
static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;

/*
 * Serializes updates to auth_flavors[] and the "read slot + try_module_get"
 * step performed before a registered flavor's ops are used.
 */
static DEFINE_SPINLOCK(rpc_authflavor_lock);
/*
 * Registered authentication flavors, indexed by flavor number.  AUTH_NULL
 * and AUTH_UNIX are built in; other slots are filled by rpcauth_register(),
 * and lookups try request_module("rpc-auth-%u") for empty slots.
 */
static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
	&authnull_ops,		/* AUTH_NULL */
	&authunix_ops,		/* AUTH_UNIX */
	NULL,			/* others can be loadable modules */
};

/*
 * Global LRU of hashed, up-to-date credentials whose reference count has
 * dropped to zero, oldest at the head; number_cred_unused counts its
 * entries.  Both are manipulated under rpc_credcache_lock.
 */
static LIST_HEAD(cred_unused);
static unsigned long number_cred_unused;
  41
  42#define MAX_HASHTABLE_BITS (14)
  43static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  44{
  45        unsigned long num;
  46        unsigned int nbits;
  47        int ret;
  48
  49        if (!val)
  50                goto out_inval;
  51        ret = strict_strtoul(val, 0, &num);
  52        if (ret == -EINVAL)
  53                goto out_inval;
  54        nbits = fls(num);
  55        if (num > (1U << nbits))
  56                nbits++;
  57        if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  58                goto out_inval;
  59        *(unsigned int *)kp->arg = nbits;
  60        return 0;
  61out_inval:
  62        return -EINVAL;
  63}
  64
  65static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  66{
  67        unsigned int nbits;
  68
  69        nbits = *(unsigned int *)kp->arg;
  70        return sprintf(buffer, "%u", 1U << nbits);
  71}
  72
  73#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  74
  75static struct kernel_param_ops param_ops_hashtbl_sz = {
  76        .set = param_set_hashtbl_sz,
  77        .get = param_get_hashtbl_sz,
  78};
  79
  80module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  81MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  82
  83static u32
  84pseudoflavor_to_flavor(u32 flavor) {
  85        if (flavor > RPC_AUTH_MAXFLAVOR)
  86                return RPC_AUTH_GSS;
  87        return flavor;
  88}
  89
  90int
  91rpcauth_register(const struct rpc_authops *ops)
  92{
  93        rpc_authflavor_t flavor;
  94        int ret = -EPERM;
  95
  96        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
  97                return -EINVAL;
  98        spin_lock(&rpc_authflavor_lock);
  99        if (auth_flavors[flavor] == NULL) {
 100                auth_flavors[flavor] = ops;
 101                ret = 0;
 102        }
 103        spin_unlock(&rpc_authflavor_lock);
 104        return ret;
 105}
 106EXPORT_SYMBOL_GPL(rpcauth_register);
 107
 108int
 109rpcauth_unregister(const struct rpc_authops *ops)
 110{
 111        rpc_authflavor_t flavor;
 112        int ret = -EPERM;
 113
 114        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 115                return -EINVAL;
 116        spin_lock(&rpc_authflavor_lock);
 117        if (auth_flavors[flavor] == ops) {
 118                auth_flavors[flavor] = NULL;
 119                ret = 0;
 120        }
 121        spin_unlock(&rpc_authflavor_lock);
 122        return ret;
 123}
 124EXPORT_SYMBOL_GPL(rpcauth_unregister);
 125
 126/**
 127 * rpcauth_get_pseudoflavor - check if security flavor is supported
 128 * @flavor: a security flavor
 129 * @info: a GSS mech OID, quality of protection, and service value
 130 *
 131 * Verifies that an appropriate kernel module is available or already loaded.
 132 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 133 * not supported locally.
 134 */
 135rpc_authflavor_t
 136rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
 137{
 138        const struct rpc_authops *ops;
 139        rpc_authflavor_t pseudoflavor;
 140
 141        ops = auth_flavors[flavor];
 142        if (ops == NULL)
 143                request_module("rpc-auth-%u", flavor);
 144        spin_lock(&rpc_authflavor_lock);
 145        ops = auth_flavors[flavor];
 146        if (ops == NULL || !try_module_get(ops->owner)) {
 147                spin_unlock(&rpc_authflavor_lock);
 148                return RPC_AUTH_MAXFLAVOR;
 149        }
 150        spin_unlock(&rpc_authflavor_lock);
 151
 152        pseudoflavor = flavor;
 153        if (ops->info2flavor != NULL)
 154                pseudoflavor = ops->info2flavor(info);
 155
 156        module_put(ops->owner);
 157        return pseudoflavor;
 158}
 159EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
 160
 161/**
 162 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 163 * @pseudoflavor: GSS pseudoflavor to match
 164 * @info: rpcsec_gss_info structure to fill in
 165 *
 166 * Returns zero and fills in "info" if pseudoflavor matches a
 167 * supported mechanism.
 168 */
 169int
 170rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
 171{
 172        rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
 173        const struct rpc_authops *ops;
 174        int result;
 175
 176        if (flavor >= RPC_AUTH_MAXFLAVOR)
 177                return -EINVAL;
 178
 179        ops = auth_flavors[flavor];
 180        if (ops == NULL)
 181                request_module("rpc-auth-%u", flavor);
 182        spin_lock(&rpc_authflavor_lock);
 183        ops = auth_flavors[flavor];
 184        if (ops == NULL || !try_module_get(ops->owner)) {
 185                spin_unlock(&rpc_authflavor_lock);
 186                return -ENOENT;
 187        }
 188        spin_unlock(&rpc_authflavor_lock);
 189
 190        result = -ENOENT;
 191        if (ops->flavor2info != NULL)
 192                result = ops->flavor2info(pseudoflavor, info);
 193
 194        module_put(ops->owner);
 195        return result;
 196}
 197EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
 198
 199/**
 200 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 201 * @array: array to fill in
 202 * @size: size of "array"
 203 *
 204 * Returns the number of array items filled in, or a negative errno.
 205 *
 206 * The returned array is not sorted by any policy.  Callers should not
 207 * rely on the order of the items in the returned array.
 208 */
 209int
 210rpcauth_list_flavors(rpc_authflavor_t *array, int size)
 211{
 212        rpc_authflavor_t flavor;
 213        int result = 0;
 214
 215        spin_lock(&rpc_authflavor_lock);
 216        for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
 217                const struct rpc_authops *ops = auth_flavors[flavor];
 218                rpc_authflavor_t pseudos[4];
 219                int i, len;
 220
 221                if (result >= size) {
 222                        result = -ENOMEM;
 223                        break;
 224                }
 225
 226                if (ops == NULL)
 227                        continue;
 228                if (ops->list_pseudoflavors == NULL) {
 229                        array[result++] = ops->au_flavor;
 230                        continue;
 231                }
 232                len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
 233                if (len < 0) {
 234                        result = len;
 235                        break;
 236                }
 237                for (i = 0; i < len; i++) {
 238                        if (result >= size) {
 239                                result = -ENOMEM;
 240                                break;
 241                        }
 242                        array[result++] = pseudos[i];
 243                }
 244        }
 245        spin_unlock(&rpc_authflavor_lock);
 246
 247        dprintk("RPC:       %s returns %d\n", __func__, result);
 248        return result;
 249}
 250EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
 251
 252struct rpc_auth *
 253rpcauth_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
 254{
 255        struct rpc_auth         *auth;
 256        const struct rpc_authops *ops;
 257        u32                     flavor = pseudoflavor_to_flavor(args->pseudoflavor);
 258
 259        auth = ERR_PTR(-EINVAL);
 260        if (flavor >= RPC_AUTH_MAXFLAVOR)
 261                goto out;
 262
 263        if ((ops = auth_flavors[flavor]) == NULL)
 264                request_module("rpc-auth-%u", flavor);
 265        spin_lock(&rpc_authflavor_lock);
 266        ops = auth_flavors[flavor];
 267        if (ops == NULL || !try_module_get(ops->owner)) {
 268                spin_unlock(&rpc_authflavor_lock);
 269                goto out;
 270        }
 271        spin_unlock(&rpc_authflavor_lock);
 272        auth = ops->create(args, clnt);
 273        module_put(ops->owner);
 274        if (IS_ERR(auth))
 275                return auth;
 276        if (clnt->cl_auth)
 277                rpcauth_release(clnt->cl_auth);
 278        clnt->cl_auth = auth;
 279
 280out:
 281        return auth;
 282}
 283EXPORT_SYMBOL_GPL(rpcauth_create);
 284
 285void
 286rpcauth_release(struct rpc_auth *auth)
 287{
 288        if (!atomic_dec_and_test(&auth->au_count))
 289                return;
 290        auth->au_ops->destroy(auth);
 291}
 292
/*
 * Guards the global cred_unused LRU and number_cred_unused.  When both
 * locks are needed it is taken before a cache's own ->lock.
 */
static DEFINE_SPINLOCK(rpc_credcache_lock);

/*
 * Unlink @cred from its credcache hash chain and clear RPCAUTH_CRED_HASHED.
 * Every caller in this file holds the owning cache's ->lock.
 *
 * hlist_del_rcu() leaves the entry reachable by concurrent RCU walkers,
 * which is why rpcauth_lookup_credcache re-tests RPCAUTH_CRED_HASHED under
 * the cache lock before taking a reference.
 */
static void
rpcauth_unhash_cred_locked(struct rpc_cred *cred)
{
	hlist_del_rcu(&cred->cr_hash);
	/* Order the unlink before the flag clear becomes visible. */
	smp_mb__before_clear_bit();
	clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
}
 302
 303static int
 304rpcauth_unhash_cred(struct rpc_cred *cred)
 305{
 306        spinlock_t *cache_lock;
 307        int ret;
 308
 309        cache_lock = &cred->cr_auth->au_credcache->lock;
 310        spin_lock(cache_lock);
 311        ret = atomic_read(&cred->cr_count) == 0;
 312        if (ret)
 313                rpcauth_unhash_cred_locked(cred);
 314        spin_unlock(cache_lock);
 315        return ret;
 316}
 317
 318/*
 319 * Initialize RPC credential cache
 320 */
 321int
 322rpcauth_init_credcache(struct rpc_auth *auth)
 323{
 324        struct rpc_cred_cache *new;
 325        unsigned int hashsize;
 326
 327        new = kmalloc(sizeof(*new), GFP_KERNEL);
 328        if (!new)
 329                goto out_nocache;
 330        new->hashbits = auth_hashbits;
 331        hashsize = 1U << new->hashbits;
 332        new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
 333        if (!new->hashtable)
 334                goto out_nohashtbl;
 335        spin_lock_init(&new->lock);
 336        auth->au_credcache = new;
 337        return 0;
 338out_nohashtbl:
 339        kfree(new);
 340out_nocache:
 341        return -ENOMEM;
 342}
 343EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 344
 345/*
 346 * Setup a credential key lifetime timeout notification
 347 */
 348int
 349rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
 350{
 351        if (!cred->cr_auth->au_ops->key_timeout)
 352                return 0;
 353        return cred->cr_auth->au_ops->key_timeout(auth, cred);
 354}
 355EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);
 356
 357bool
 358rpcauth_cred_key_to_expire(struct rpc_cred *cred)
 359{
 360        if (!cred->cr_ops->crkey_to_expire)
 361                return false;
 362        return cred->cr_ops->crkey_to_expire(cred);
 363}
 364EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);
 365
 366/*
 367 * Destroy a list of credentials
 368 */
 369static inline
 370void rpcauth_destroy_credlist(struct list_head *head)
 371{
 372        struct rpc_cred *cred;
 373
 374        while (!list_empty(head)) {
 375                cred = list_entry(head->next, struct rpc_cred, cr_lru);
 376                list_del_init(&cred->cr_lru);
 377                put_rpccred(cred);
 378        }
 379}
 380
/*
 * Clear the RPC credential cache, and delete those credentials
 * that are not referenced.
 *
 * Every credential in @cache is unhashed and taken off the global unused
 * LRU.  Each one is pinned with get_rpccred() and collected on a local
 * list.  The matching put_rpccred() calls happen in
 * rpcauth_destroy_credlist() after both spinlocks are dropped, so
 * credentials that still have other users survive until their last
 * reference goes away.
 */
void
rpcauth_clear_credcache(struct rpc_cred_cache *cache)
{
	LIST_HEAD(free);
	struct hlist_head *head;
	struct rpc_cred *cred;
	unsigned int hashsize = 1U << cache->hashbits;
	int		i;

	/* Lock order: rpc_credcache_lock, then cache->lock. */
	spin_lock(&rpc_credcache_lock);
	spin_lock(&cache->lock);
	for (i = 0; i < hashsize; i++) {
		head = &cache->hashtable[i];
		/* Unhashing removes the entry, so always take the chain's first one. */
		while (!hlist_empty(head)) {
			cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
			get_rpccred(cred);
			/* Take it off the global unused LRU if it is parked there. */
			if (!list_empty(&cred->cr_lru)) {
				list_del(&cred->cr_lru);
				number_cred_unused--;
			}
			/* cr_lru is reused to link the cred onto the local free list. */
			list_add_tail(&cred->cr_lru, &free);
			rpcauth_unhash_cred_locked(cred);
		}
	}
	spin_unlock(&cache->lock);
	spin_unlock(&rpc_credcache_lock);
	rpcauth_destroy_credlist(&free);
}
 413
 414/*
 415 * Destroy the RPC credential cache
 416 */
 417void
 418rpcauth_destroy_credcache(struct rpc_auth *auth)
 419{
 420        struct rpc_cred_cache *cache = auth->au_credcache;
 421
 422        if (cache) {
 423                auth->au_credcache = NULL;
 424                rpcauth_clear_credcache(cache);
 425                kfree(cache->hashtable);
 426                kfree(cache);
 427        }
 428}
 429EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
 430
 431
/*
 * Hashed credentials parked on cred_unused within this window (60 seconds,
 * expressed in jiffies) are not pruned.
 */
#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)

/*
 * Remove stale credentials. Avoid sleeping inside the loop.
 *
 * Walks the global cred_unused LRU from its head (oldest entry), visiting
 * at most @nr_to_scan entries.  The walk stops early at the first
 * credential that is still hashed and was parked within the last
 * RPC_AUTH_EXPIRY_MORATORIUM jiffies.  Every other visited credential is
 * taken off the LRU.  Those with a zero reference count are also pinned
 * with get_rpccred(), unhashed and moved onto @free, so the caller can
 * release them after dropping its locks.
 *
 * The caller in this file (rpcauth_cache_shrink_scan) holds
 * rpc_credcache_lock across this call.
 *
 * Returns the number of credentials removed from the LRU.  This includes
 * ones that were still referenced and therefore not queued on @free.
 */
static long
rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
{
	spinlock_t *cache_lock;
	struct rpc_cred *cred, *next;
	unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
	long freed = 0;

	list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {

		if (nr_to_scan-- == 0)
			break;
		/*
		 * Enforce a 60 second garbage collection moratorium
		 * Note that the cred_unused list must be time-ordered.
		 */
		if (time_in_range(cred->cr_expire, expired, jiffies) &&
		    test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0)
			break;

		list_del_init(&cred->cr_lru);
		number_cred_unused--;
		freed++;
		/* Still in use: it only leaves the LRU. */
		if (atomic_read(&cred->cr_count) != 0)
			continue;

		/*
		 * Recheck under the cache lock: rpcauth_lookup_credcache()
		 * takes new references while holding it.
		 */
		cache_lock = &cred->cr_auth->au_credcache->lock;
		spin_lock(cache_lock);
		if (atomic_read(&cred->cr_count) == 0) {
			get_rpccred(cred);
			list_add_tail(&cred->cr_lru, free);
			rpcauth_unhash_cred_locked(cred);
		}
		spin_unlock(cache_lock);
	}
	return freed;
}
 474
 475/*
 476 * Run memory cache shrinker.
 477 */
 478static unsigned long
 479rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
 480
 481{
 482        LIST_HEAD(free);
 483        unsigned long freed;
 484
 485        if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
 486                return SHRINK_STOP;
 487
 488        /* nothing left, don't come back */
 489        if (list_empty(&cred_unused))
 490                return SHRINK_STOP;
 491
 492        spin_lock(&rpc_credcache_lock);
 493        freed = rpcauth_prune_expired(&free, sc->nr_to_scan);
 494        spin_unlock(&rpc_credcache_lock);
 495        rpcauth_destroy_credlist(&free);
 496
 497        return freed;
 498}
 499
 500static unsigned long
 501rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
 502
 503{
 504        return (number_cred_unused / 100) * sysctl_vfs_cache_pressure;
 505}
 506
/*
 * Look up a process' credentials in the authentication cache
 *
 * Returns a referenced credential that the cached entry's crmatch op
 * accepts for @acred.  If none exists, a new one is built with
 * au_ops->crcreate and hashed.  On failure an ERR_PTR is returned (from
 * crcreate or from the credential's cr_init op).
 */
struct rpc_cred *
rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
		int flags)
{
	LIST_HEAD(free);
	struct rpc_cred_cache *cache = auth->au_credcache;
	struct rpc_cred	*cred = NULL,
			*entry, *new;
	unsigned int nr;

	/* The bucket index depends on the uid only. */
	nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits);

	/*
	 * Lockless first pass.  An entry seen under RCU may have been unhashed
	 * concurrently (hlist_del_rcu keeps it reachable), so confirm
	 * RPCAUTH_CRED_HASHED under cache->lock before taking a reference.
	 */
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		spin_lock(&cache->lock);
		if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) {
			spin_unlock(&cache->lock);
			continue;
		}
		cred = get_rpccred(entry);
		spin_unlock(&cache->lock);
		break;
	}
	rcu_read_unlock();

	if (cred != NULL)
		goto found;

	/* Miss: build a candidate without holding the cache lock. */
	new = auth->au_ops->crcreate(auth, acred, flags);
	if (IS_ERR(new)) {
		cred = new;
		goto out;
	}

	/*
	 * Re-scan under the lock in case another task hashed a matching
	 * credential while ours was being created.  If so, use that one and
	 * queue ours on the local list so it is released below.
	 */
	spin_lock(&cache->lock);
	hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		cred = get_rpccred(entry);
		break;
	}
	if (cred == NULL) {
		cred = new;
		set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
		hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
	} else
		list_add_tail(&new->cr_lru, &free);
	spin_unlock(&cache->lock);
found:
	/*
	 * Run cr_init on credentials still flagged RPCAUTH_CRED_NEW, unless the
	 * caller passed RPCAUTH_LOOKUP_NEW.  rpcauth_bindcred() sets that flag
	 * for RPC_TASK_ASYNC tasks; presumably initialisation is then done
	 * later (e.g. at refresh time), but that needs confirming.
	 */
	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
	    cred->cr_ops->cr_init != NULL &&
	    !(flags & RPCAUTH_LOOKUP_NEW)) {
		int res = cred->cr_ops->cr_init(auth, cred);
		if (res < 0) {
			put_rpccred(cred);
			cred = ERR_PTR(res);
		}
	}
	rpcauth_destroy_credlist(&free);
out:
	return cred;
}
EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
 575
 576struct rpc_cred *
 577rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 578{
 579        struct auth_cred acred;
 580        struct rpc_cred *ret;
 581        const struct cred *cred = current_cred();
 582
 583        dprintk("RPC:       looking up %s cred\n",
 584                auth->au_ops->au_name);
 585
 586        memset(&acred, 0, sizeof(acred));
 587        acred.uid = cred->fsuid;
 588        acred.gid = cred->fsgid;
 589        acred.group_info = get_group_info(((struct cred *)cred)->group_info);
 590
 591        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
 592        put_group_info(acred.group_info);
 593        return ret;
 594}
 595
 596void
 597rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 598                  struct rpc_auth *auth, const struct rpc_credops *ops)
 599{
 600        INIT_HLIST_NODE(&cred->cr_hash);
 601        INIT_LIST_HEAD(&cred->cr_lru);
 602        atomic_set(&cred->cr_count, 1);
 603        cred->cr_auth = auth;
 604        cred->cr_ops = ops;
 605        cred->cr_expire = jiffies;
 606#ifdef RPC_DEBUG
 607        cred->cr_magic = RPCAUTH_CRED_MAGIC;
 608#endif
 609        cred->cr_uid = acred->uid;
 610}
 611EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 612
 613struct rpc_cred *
 614rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
 615{
 616        dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
 617                        cred->cr_auth->au_ops->au_name, cred);
 618        return get_rpccred(cred);
 619}
 620EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
 621
 622static struct rpc_cred *
 623rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 624{
 625        struct rpc_auth *auth = task->tk_client->cl_auth;
 626        struct auth_cred acred = {
 627                .uid = GLOBAL_ROOT_UID,
 628                .gid = GLOBAL_ROOT_GID,
 629        };
 630
 631        dprintk("RPC: %5u looking up %s cred\n",
 632                task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
 633        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 634}
 635
 636static struct rpc_cred *
 637rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 638{
 639        struct rpc_auth *auth = task->tk_client->cl_auth;
 640
 641        dprintk("RPC: %5u looking up %s cred\n",
 642                task->tk_pid, auth->au_ops->au_name);
 643        return rpcauth_lookupcred(auth, lookupflags);
 644}
 645
 646static int
 647rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
 648{
 649        struct rpc_rqst *req = task->tk_rqstp;
 650        struct rpc_cred *new;
 651        int lookupflags = 0;
 652
 653        if (flags & RPC_TASK_ASYNC)
 654                lookupflags |= RPCAUTH_LOOKUP_NEW;
 655        if (cred != NULL)
 656                new = cred->cr_ops->crbind(task, cred, lookupflags);
 657        else if (flags & RPC_TASK_ROOTCREDS)
 658                new = rpcauth_bind_root_cred(task, lookupflags);
 659        else
 660                new = rpcauth_bind_new_cred(task, lookupflags);
 661        if (IS_ERR(new))
 662                return PTR_ERR(new);
 663        if (req->rq_cred != NULL)
 664                put_rpccred(req->rq_cred);
 665        req->rq_cred = new;
 666        return 0;
 667}
 668
/*
 * Drop a reference to an RPC credential.
 *
 * Unhashed credentials are destroyed as soon as their count reaches zero.
 * For hashed ones, the final put runs under rpc_credcache_lock:
 *  - an up-to-date credential stays hashed and is parked at the tail of
 *    the global cred_unused LRU, stamped with the current time;
 *  - any other hashed credential is unhashed and destroyed, unless a
 *    concurrent cache lookup has taken a new reference in the meantime.
 */
void
put_rpccred(struct rpc_cred *cred)
{
	/* Fast path for unhashed credentials */
	if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) {
		if (atomic_dec_and_test(&cred->cr_count))
			cred->cr_ops->crdestroy(cred);
		return;
	}

	/* Continue only if the count hit zero; rpc_credcache_lock is then held. */
	if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
		return;
	/*
	 * A credential revived by a lookup can still sit on the LRU from an
	 * earlier put; unlink it before deciding where it goes now.
	 */
	if (!list_empty(&cred->cr_lru)) {
		number_cred_unused--;
		list_del_init(&cred->cr_lru);
	}
	/* HASHED may have been cleared since the unlocked test above. */
	if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) {
		if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
			/* Tail insertion with a fresh stamp keeps cred_unused time-ordered. */
			cred->cr_expire = jiffies;
			list_add_tail(&cred->cr_lru, &cred_unused);
			number_cred_unused++;
			goto out_nodestroy;
		}
		if (!rpcauth_unhash_cred(cred)) {
			/* We were hashed and someone looked us up... */
			goto out_nodestroy;
		}
	}
	spin_unlock(&rpc_credcache_lock);
	cred->cr_ops->crdestroy(cred);
	return;
out_nodestroy:
	spin_unlock(&rpc_credcache_lock);
}
EXPORT_SYMBOL_GPL(put_rpccred);
 704
 705__be32 *
 706rpcauth_marshcred(struct rpc_task *task, __be32 *p)
 707{
 708        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 709
 710        dprintk("RPC: %5u marshaling %s cred %p\n",
 711                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 712
 713        return cred->cr_ops->crmarshal(task, p);
 714}
 715
 716__be32 *
 717rpcauth_checkverf(struct rpc_task *task, __be32 *p)
 718{
 719        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 720
 721        dprintk("RPC: %5u validating %s cred %p\n",
 722                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 723
 724        return cred->cr_ops->crvalidate(task, p);
 725}
 726
 727static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
 728                                   __be32 *data, void *obj)
 729{
 730        struct xdr_stream xdr;
 731
 732        xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
 733        encode(rqstp, &xdr, obj);
 734}
 735
 736int
 737rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
 738                __be32 *data, void *obj)
 739{
 740        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 741
 742        dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
 743                        task->tk_pid, cred->cr_ops->cr_name, cred);
 744        if (cred->cr_ops->crwrap_req)
 745                return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
 746        /* By default, we encode the arguments normally. */
 747        rpcauth_wrap_req_encode(encode, rqstp, data, obj);
 748        return 0;
 749}
 750
 751static int
 752rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
 753                          __be32 *data, void *obj)
 754{
 755        struct xdr_stream xdr;
 756
 757        xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
 758        return decode(rqstp, &xdr, obj);
 759}
 760
 761int
 762rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
 763                __be32 *data, void *obj)
 764{
 765        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 766
 767        dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
 768                        task->tk_pid, cred->cr_ops->cr_name, cred);
 769        if (cred->cr_ops->crunwrap_resp)
 770                return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
 771                                                   data, obj);
 772        /* By default, we decode the arguments normally. */
 773        return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
 774}
 775
 776int
 777rpcauth_refreshcred(struct rpc_task *task)
 778{
 779        struct rpc_cred *cred;
 780        int err;
 781
 782        cred = task->tk_rqstp->rq_cred;
 783        if (cred == NULL) {
 784                err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
 785                if (err < 0)
 786                        goto out;
 787                cred = task->tk_rqstp->rq_cred;
 788        }
 789        dprintk("RPC: %5u refreshing %s cred %p\n",
 790                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 791
 792        err = cred->cr_ops->crrefresh(task);
 793out:
 794        if (err < 0)
 795                task->tk_status = err;
 796        return err;
 797}
 798
 799void
 800rpcauth_invalcred(struct rpc_task *task)
 801{
 802        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 803
 804        dprintk("RPC: %5u invalidating %s cred %p\n",
 805                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 806        if (cred)
 807                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 808}
 809
 810int
 811rpcauth_uptodatecred(struct rpc_task *task)
 812{
 813        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 814
 815        return cred == NULL ||
 816                test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
 817}
 818
 819static struct shrinker rpc_cred_shrinker = {
 820        .count_objects = rpcauth_cache_shrink_count,
 821        .scan_objects = rpcauth_cache_shrink_scan,
 822        .seeks = DEFAULT_SEEKS,
 823};
 824
 825int __init rpcauth_init_module(void)
 826{
 827        int err;
 828
 829        err = rpc_init_authunix();
 830        if (err < 0)
 831                goto out1;
 832        err = rpc_init_generic_auth();
 833        if (err < 0)
 834                goto out2;
 835        register_shrinker(&rpc_cred_shrinker);
 836        return 0;
 837out2:
 838        rpc_destroy_authunix();
 839out1:
 840        return err;
 841}
 842
/*
 * Module teardown counterpart of rpcauth_init_module().  Calls the AUTH_UNIX
 * and generic-auth teardown helpers, then unregisters the credential cache
 * shrinker.
 */
void rpcauth_remove_module(void)
{
	rpc_destroy_authunix();
	rpc_destroy_generic_auth();
	unregister_shrinker(&rpc_cred_shrinker);
}
 849