linux/net/sunrpc/auth.c
<<
>>
Prefs
   1/*
   2 * linux/net/sunrpc/auth.c
   3 *
   4 * Generic RPC client authentication API.
   5 *
   6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
   7 */
   8
   9#include <linux/types.h>
  10#include <linux/sched.h>
  11#include <linux/cred.h>
  12#include <linux/module.h>
  13#include <linux/slab.h>
  14#include <linux/errno.h>
  15#include <linux/hash.h>
  16#include <linux/sunrpc/clnt.h>
  17#include <linux/sunrpc/gss_api.h>
  18#include <linux/spinlock.h>
  19
  20#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
  21# define RPCDBG_FACILITY        RPCDBG_AUTH
  22#endif
  23
  24#define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)
  25struct rpc_cred_cache {
  26        struct hlist_head       *hashtable;
  27        unsigned int            hashbits;
  28        spinlock_t              lock;
  29};
  30
  31static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
  32
  33static DEFINE_SPINLOCK(rpc_authflavor_lock);
  34static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
  35        &authnull_ops,          /* AUTH_NULL */
  36        &authunix_ops,          /* AUTH_UNIX */
  37        NULL,                   /* others can be loadable modules */
  38};
  39
  40static LIST_HEAD(cred_unused);
  41static unsigned long number_cred_unused;
  42
  43#define MAX_HASHTABLE_BITS (14)
  44static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  45{
  46        unsigned long num;
  47        unsigned int nbits;
  48        int ret;
  49
  50        if (!val)
  51                goto out_inval;
  52        ret = kstrtoul(val, 0, &num);
  53        if (ret == -EINVAL)
  54                goto out_inval;
  55        nbits = fls(num - 1);
  56        if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  57                goto out_inval;
  58        *(unsigned int *)kp->arg = nbits;
  59        return 0;
  60out_inval:
  61        return -EINVAL;
  62}
  63
  64static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  65{
  66        unsigned int nbits;
  67
  68        nbits = *(unsigned int *)kp->arg;
  69        return sprintf(buffer, "%u", 1U << nbits);
  70}
  71
  72#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  73
  74static const struct kernel_param_ops param_ops_hashtbl_sz = {
  75        .set = param_set_hashtbl_sz,
  76        .get = param_get_hashtbl_sz,
  77};
  78
  79module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  80MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  81
  82static unsigned long auth_max_cred_cachesize = ULONG_MAX;
  83module_param(auth_max_cred_cachesize, ulong, 0644);
  84MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
  85
  86static u32
  87pseudoflavor_to_flavor(u32 flavor) {
  88        if (flavor > RPC_AUTH_MAXFLAVOR)
  89                return RPC_AUTH_GSS;
  90        return flavor;
  91}
  92
  93int
  94rpcauth_register(const struct rpc_authops *ops)
  95{
  96        rpc_authflavor_t flavor;
  97        int ret = -EPERM;
  98
  99        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 100                return -EINVAL;
 101        spin_lock(&rpc_authflavor_lock);
 102        if (auth_flavors[flavor] == NULL) {
 103                auth_flavors[flavor] = ops;
 104                ret = 0;
 105        }
 106        spin_unlock(&rpc_authflavor_lock);
 107        return ret;
 108}
 109EXPORT_SYMBOL_GPL(rpcauth_register);
 110
 111int
 112rpcauth_unregister(const struct rpc_authops *ops)
 113{
 114        rpc_authflavor_t flavor;
 115        int ret = -EPERM;
 116
 117        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 118                return -EINVAL;
 119        spin_lock(&rpc_authflavor_lock);
 120        if (auth_flavors[flavor] == ops) {
 121                auth_flavors[flavor] = NULL;
 122                ret = 0;
 123        }
 124        spin_unlock(&rpc_authflavor_lock);
 125        return ret;
 126}
 127EXPORT_SYMBOL_GPL(rpcauth_unregister);
 128
 129/**
 130 * rpcauth_get_pseudoflavor - check if security flavor is supported
 131 * @flavor: a security flavor
 132 * @info: a GSS mech OID, quality of protection, and service value
 133 *
 134 * Verifies that an appropriate kernel module is available or already loaded.
 135 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 136 * not supported locally.
 137 */
 138rpc_authflavor_t
 139rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
 140{
 141        const struct rpc_authops *ops;
 142        rpc_authflavor_t pseudoflavor;
 143
 144        ops = auth_flavors[flavor];
 145        if (ops == NULL)
 146                request_module("rpc-auth-%u", flavor);
 147        spin_lock(&rpc_authflavor_lock);
 148        ops = auth_flavors[flavor];
 149        if (ops == NULL || !try_module_get(ops->owner)) {
 150                spin_unlock(&rpc_authflavor_lock);
 151                return RPC_AUTH_MAXFLAVOR;
 152        }
 153        spin_unlock(&rpc_authflavor_lock);
 154
 155        pseudoflavor = flavor;
 156        if (ops->info2flavor != NULL)
 157                pseudoflavor = ops->info2flavor(info);
 158
 159        module_put(ops->owner);
 160        return pseudoflavor;
 161}
 162EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
 163
 164/**
 165 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 166 * @pseudoflavor: GSS pseudoflavor to match
 167 * @info: rpcsec_gss_info structure to fill in
 168 *
 169 * Returns zero and fills in "info" if pseudoflavor matches a
 170 * supported mechanism.
 171 */
 172int
 173rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
 174{
 175        rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
 176        const struct rpc_authops *ops;
 177        int result;
 178
 179        if (flavor >= RPC_AUTH_MAXFLAVOR)
 180                return -EINVAL;
 181
 182        ops = auth_flavors[flavor];
 183        if (ops == NULL)
 184                request_module("rpc-auth-%u", flavor);
 185        spin_lock(&rpc_authflavor_lock);
 186        ops = auth_flavors[flavor];
 187        if (ops == NULL || !try_module_get(ops->owner)) {
 188                spin_unlock(&rpc_authflavor_lock);
 189                return -ENOENT;
 190        }
 191        spin_unlock(&rpc_authflavor_lock);
 192
 193        result = -ENOENT;
 194        if (ops->flavor2info != NULL)
 195                result = ops->flavor2info(pseudoflavor, info);
 196
 197        module_put(ops->owner);
 198        return result;
 199}
 200EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
 201
 202/**
 203 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 204 * @array: array to fill in
 205 * @size: size of "array"
 206 *
 207 * Returns the number of array items filled in, or a negative errno.
 208 *
 209 * The returned array is not sorted by any policy.  Callers should not
 210 * rely on the order of the items in the returned array.
 211 */
 212int
 213rpcauth_list_flavors(rpc_authflavor_t *array, int size)
 214{
 215        rpc_authflavor_t flavor;
 216        int result = 0;
 217
 218        spin_lock(&rpc_authflavor_lock);
 219        for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
 220                const struct rpc_authops *ops = auth_flavors[flavor];
 221                rpc_authflavor_t pseudos[4];
 222                int i, len;
 223
 224                if (result >= size) {
 225                        result = -ENOMEM;
 226                        break;
 227                }
 228
 229                if (ops == NULL)
 230                        continue;
 231                if (ops->list_pseudoflavors == NULL) {
 232                        array[result++] = ops->au_flavor;
 233                        continue;
 234                }
 235                len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
 236                if (len < 0) {
 237                        result = len;
 238                        break;
 239                }
 240                for (i = 0; i < len; i++) {
 241                        if (result >= size) {
 242                                result = -ENOMEM;
 243                                break;
 244                        }
 245                        array[result++] = pseudos[i];
 246                }
 247        }
 248        spin_unlock(&rpc_authflavor_lock);
 249
 250        dprintk("RPC:       %s returns %d\n", __func__, result);
 251        return result;
 252}
 253EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
 254
 255struct rpc_auth *
 256rpcauth_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
 257{
 258        struct rpc_auth         *auth;
 259        const struct rpc_authops *ops;
 260        u32                     flavor = pseudoflavor_to_flavor(args->pseudoflavor);
 261
 262        auth = ERR_PTR(-EINVAL);
 263        if (flavor >= RPC_AUTH_MAXFLAVOR)
 264                goto out;
 265
 266        if ((ops = auth_flavors[flavor]) == NULL)
 267                request_module("rpc-auth-%u", flavor);
 268        spin_lock(&rpc_authflavor_lock);
 269        ops = auth_flavors[flavor];
 270        if (ops == NULL || !try_module_get(ops->owner)) {
 271                spin_unlock(&rpc_authflavor_lock);
 272                goto out;
 273        }
 274        spin_unlock(&rpc_authflavor_lock);
 275        auth = ops->create(args, clnt);
 276        module_put(ops->owner);
 277        if (IS_ERR(auth))
 278                return auth;
 279        if (clnt->cl_auth)
 280                rpcauth_release(clnt->cl_auth);
 281        clnt->cl_auth = auth;
 282
 283out:
 284        return auth;
 285}
 286EXPORT_SYMBOL_GPL(rpcauth_create);
 287
 288void
 289rpcauth_release(struct rpc_auth *auth)
 290{
 291        if (!atomic_dec_and_test(&auth->au_count))
 292                return;
 293        auth->au_ops->destroy(auth);
 294}
 295
 296static DEFINE_SPINLOCK(rpc_credcache_lock);
 297
 298static void
 299rpcauth_unhash_cred_locked(struct rpc_cred *cred)
 300{
 301        hlist_del_rcu(&cred->cr_hash);
 302        smp_mb__before_atomic();
 303        clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
 304}
 305
 306static int
 307rpcauth_unhash_cred(struct rpc_cred *cred)
 308{
 309        spinlock_t *cache_lock;
 310        int ret;
 311
 312        cache_lock = &cred->cr_auth->au_credcache->lock;
 313        spin_lock(cache_lock);
 314        ret = atomic_read(&cred->cr_count) == 0;
 315        if (ret)
 316                rpcauth_unhash_cred_locked(cred);
 317        spin_unlock(cache_lock);
 318        return ret;
 319}
 320
 321/*
 322 * Initialize RPC credential cache
 323 */
 324int
 325rpcauth_init_credcache(struct rpc_auth *auth)
 326{
 327        struct rpc_cred_cache *new;
 328        unsigned int hashsize;
 329
 330        new = kmalloc(sizeof(*new), GFP_KERNEL);
 331        if (!new)
 332                goto out_nocache;
 333        new->hashbits = auth_hashbits;
 334        hashsize = 1U << new->hashbits;
 335        new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
 336        if (!new->hashtable)
 337                goto out_nohashtbl;
 338        spin_lock_init(&new->lock);
 339        auth->au_credcache = new;
 340        return 0;
 341out_nohashtbl:
 342        kfree(new);
 343out_nocache:
 344        return -ENOMEM;
 345}
 346EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 347
 348/*
 349 * Setup a credential key lifetime timeout notification
 350 */
 351int
 352rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
 353{
 354        if (!cred->cr_auth->au_ops->key_timeout)
 355                return 0;
 356        return cred->cr_auth->au_ops->key_timeout(auth, cred);
 357}
 358EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);
 359
 360bool
 361rpcauth_cred_key_to_expire(struct rpc_auth *auth, struct rpc_cred *cred)
 362{
 363        if (auth->au_flags & RPCAUTH_AUTH_NO_CRKEY_TIMEOUT)
 364                return false;
 365        if (!cred->cr_ops->crkey_to_expire)
 366                return false;
 367        return cred->cr_ops->crkey_to_expire(cred);
 368}
 369EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);
 370
 371char *
 372rpcauth_stringify_acceptor(struct rpc_cred *cred)
 373{
 374        if (!cred->cr_ops->crstringify_acceptor)
 375                return NULL;
 376        return cred->cr_ops->crstringify_acceptor(cred);
 377}
 378EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
 379
 380/*
 381 * Destroy a list of credentials
 382 */
 383static inline
 384void rpcauth_destroy_credlist(struct list_head *head)
 385{
 386        struct rpc_cred *cred;
 387
 388        while (!list_empty(head)) {
 389                cred = list_entry(head->next, struct rpc_cred, cr_lru);
 390                list_del_init(&cred->cr_lru);
 391                put_rpccred(cred);
 392        }
 393}
 394
 395/*
 396 * Clear the RPC credential cache, and delete those credentials
 397 * that are not referenced.
 398 */
 399void
 400rpcauth_clear_credcache(struct rpc_cred_cache *cache)
 401{
 402        LIST_HEAD(free);
 403        struct hlist_head *head;
 404        struct rpc_cred *cred;
 405        unsigned int hashsize = 1U << cache->hashbits;
 406        int             i;
 407
 408        spin_lock(&rpc_credcache_lock);
 409        spin_lock(&cache->lock);
 410        for (i = 0; i < hashsize; i++) {
 411                head = &cache->hashtable[i];
 412                while (!hlist_empty(head)) {
 413                        cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
 414                        get_rpccred(cred);
 415                        if (!list_empty(&cred->cr_lru)) {
 416                                list_del(&cred->cr_lru);
 417                                number_cred_unused--;
 418                        }
 419                        list_add_tail(&cred->cr_lru, &free);
 420                        rpcauth_unhash_cred_locked(cred);
 421                }
 422        }
 423        spin_unlock(&cache->lock);
 424        spin_unlock(&rpc_credcache_lock);
 425        rpcauth_destroy_credlist(&free);
 426}
 427
 428/*
 429 * Destroy the RPC credential cache
 430 */
 431void
 432rpcauth_destroy_credcache(struct rpc_auth *auth)
 433{
 434        struct rpc_cred_cache *cache = auth->au_credcache;
 435
 436        if (cache) {
 437                auth->au_credcache = NULL;
 438                rpcauth_clear_credcache(cache);
 439                kfree(cache->hashtable);
 440                kfree(cache);
 441        }
 442}
 443EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
 444
 445
 446#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
 447
 448/*
 449 * Remove stale credentials. Avoid sleeping inside the loop.
 450 */
 451static long
 452rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
 453{
 454        spinlock_t *cache_lock;
 455        struct rpc_cred *cred, *next;
 456        unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
 457        long freed = 0;
 458
 459        list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
 460
 461                if (nr_to_scan-- == 0)
 462                        break;
 463                /*
 464                 * Enforce a 60 second garbage collection moratorium
 465                 * Note that the cred_unused list must be time-ordered.
 466                 */
 467                if (time_in_range(cred->cr_expire, expired, jiffies) &&
 468                    test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) {
 469                        freed = SHRINK_STOP;
 470                        break;
 471                }
 472
 473                list_del_init(&cred->cr_lru);
 474                number_cred_unused--;
 475                freed++;
 476                if (atomic_read(&cred->cr_count) != 0)
 477                        continue;
 478
 479                cache_lock = &cred->cr_auth->au_credcache->lock;
 480                spin_lock(cache_lock);
 481                if (atomic_read(&cred->cr_count) == 0) {
 482                        get_rpccred(cred);
 483                        list_add_tail(&cred->cr_lru, free);
 484                        rpcauth_unhash_cred_locked(cred);
 485                }
 486                spin_unlock(cache_lock);
 487        }
 488        return freed;
 489}
 490
 491static unsigned long
 492rpcauth_cache_do_shrink(int nr_to_scan)
 493{
 494        LIST_HEAD(free);
 495        unsigned long freed;
 496
 497        spin_lock(&rpc_credcache_lock);
 498        freed = rpcauth_prune_expired(&free, nr_to_scan);
 499        spin_unlock(&rpc_credcache_lock);
 500        rpcauth_destroy_credlist(&free);
 501
 502        return freed;
 503}
 504
 505/*
 506 * Run memory cache shrinker.
 507 */
 508static unsigned long
 509rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
 510
 511{
 512        if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
 513                return SHRINK_STOP;
 514
 515        /* nothing left, don't come back */
 516        if (list_empty(&cred_unused))
 517                return SHRINK_STOP;
 518
 519        return rpcauth_cache_do_shrink(sc->nr_to_scan);
 520}
 521
 522static unsigned long
 523rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
 524
 525{
 526        return number_cred_unused * sysctl_vfs_cache_pressure / 100;
 527}
 528
 529static void
 530rpcauth_cache_enforce_limit(void)
 531{
 532        unsigned long diff;
 533        unsigned int nr_to_scan;
 534
 535        if (number_cred_unused <= auth_max_cred_cachesize)
 536                return;
 537        diff = number_cred_unused - auth_max_cred_cachesize;
 538        nr_to_scan = 100;
 539        if (diff < nr_to_scan)
 540                nr_to_scan = diff;
 541        rpcauth_cache_do_shrink(nr_to_scan);
 542}
 543
 544/*
 545 * Look up a process' credentials in the authentication cache
 546 */
 547struct rpc_cred *
 548rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
 549                int flags, gfp_t gfp)
 550{
 551        LIST_HEAD(free);
 552        struct rpc_cred_cache *cache = auth->au_credcache;
 553        struct rpc_cred *cred = NULL,
 554                        *entry, *new;
 555        unsigned int nr;
 556
 557        nr = auth->au_ops->hash_cred(acred, cache->hashbits);
 558
 559        rcu_read_lock();
 560        hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
 561                if (!entry->cr_ops->crmatch(acred, entry, flags))
 562                        continue;
 563                if (flags & RPCAUTH_LOOKUP_RCU) {
 564                        if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) &&
 565                            !test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags))
 566                                cred = entry;
 567                        break;
 568                }
 569                spin_lock(&cache->lock);
 570                if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) {
 571                        spin_unlock(&cache->lock);
 572                        continue;
 573                }
 574                cred = get_rpccred(entry);
 575                spin_unlock(&cache->lock);
 576                break;
 577        }
 578        rcu_read_unlock();
 579
 580        if (cred != NULL)
 581                goto found;
 582
 583        if (flags & RPCAUTH_LOOKUP_RCU)
 584                return ERR_PTR(-ECHILD);
 585
 586        new = auth->au_ops->crcreate(auth, acred, flags, gfp);
 587        if (IS_ERR(new)) {
 588                cred = new;
 589                goto out;
 590        }
 591
 592        spin_lock(&cache->lock);
 593        hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
 594                if (!entry->cr_ops->crmatch(acred, entry, flags))
 595                        continue;
 596                cred = get_rpccred(entry);
 597                break;
 598        }
 599        if (cred == NULL) {
 600                cred = new;
 601                set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
 602                hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
 603        } else
 604                list_add_tail(&new->cr_lru, &free);
 605        spin_unlock(&cache->lock);
 606        rpcauth_cache_enforce_limit();
 607found:
 608        if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
 609            cred->cr_ops->cr_init != NULL &&
 610            !(flags & RPCAUTH_LOOKUP_NEW)) {
 611                int res = cred->cr_ops->cr_init(auth, cred);
 612                if (res < 0) {
 613                        put_rpccred(cred);
 614                        cred = ERR_PTR(res);
 615                }
 616        }
 617        rpcauth_destroy_credlist(&free);
 618out:
 619        return cred;
 620}
 621EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
 622
 623struct rpc_cred *
 624rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 625{
 626        struct auth_cred acred;
 627        struct rpc_cred *ret;
 628        const struct cred *cred = current_cred();
 629
 630        dprintk("RPC:       looking up %s cred\n",
 631                auth->au_ops->au_name);
 632
 633        memset(&acred, 0, sizeof(acred));
 634        acred.uid = cred->fsuid;
 635        acred.gid = cred->fsgid;
 636        acred.group_info = cred->group_info;
 637        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
 638        return ret;
 639}
 640EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
 641
 642void
 643rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 644                  struct rpc_auth *auth, const struct rpc_credops *ops)
 645{
 646        INIT_HLIST_NODE(&cred->cr_hash);
 647        INIT_LIST_HEAD(&cred->cr_lru);
 648        atomic_set(&cred->cr_count, 1);
 649        cred->cr_auth = auth;
 650        cred->cr_ops = ops;
 651        cred->cr_expire = jiffies;
 652        cred->cr_uid = acred->uid;
 653}
 654EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 655
 656struct rpc_cred *
 657rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
 658{
 659        dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
 660                        cred->cr_auth->au_ops->au_name, cred);
 661        return get_rpccred(cred);
 662}
 663EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
 664
 665static struct rpc_cred *
 666rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 667{
 668        struct rpc_auth *auth = task->tk_client->cl_auth;
 669        struct auth_cred acred = {
 670                .uid = GLOBAL_ROOT_UID,
 671                .gid = GLOBAL_ROOT_GID,
 672        };
 673
 674        dprintk("RPC: %5u looking up %s cred\n",
 675                task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
 676        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 677}
 678
 679static struct rpc_cred *
 680rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 681{
 682        struct rpc_auth *auth = task->tk_client->cl_auth;
 683
 684        dprintk("RPC: %5u looking up %s cred\n",
 685                task->tk_pid, auth->au_ops->au_name);
 686        return rpcauth_lookupcred(auth, lookupflags);
 687}
 688
 689static int
 690rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
 691{
 692        struct rpc_rqst *req = task->tk_rqstp;
 693        struct rpc_cred *new;
 694        int lookupflags = 0;
 695
 696        if (flags & RPC_TASK_ASYNC)
 697                lookupflags |= RPCAUTH_LOOKUP_NEW;
 698        if (cred != NULL)
 699                new = cred->cr_ops->crbind(task, cred, lookupflags);
 700        else if (flags & RPC_TASK_ROOTCREDS)
 701                new = rpcauth_bind_root_cred(task, lookupflags);
 702        else
 703                new = rpcauth_bind_new_cred(task, lookupflags);
 704        if (IS_ERR(new))
 705                return PTR_ERR(new);
 706        put_rpccred(req->rq_cred);
 707        req->rq_cred = new;
 708        return 0;
 709}
 710
 711void
 712put_rpccred(struct rpc_cred *cred)
 713{
 714        if (cred == NULL)
 715                return;
 716        /* Fast path for unhashed credentials */
 717        if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) {
 718                if (atomic_dec_and_test(&cred->cr_count))
 719                        cred->cr_ops->crdestroy(cred);
 720                return;
 721        }
 722
 723        if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
 724                return;
 725        if (!list_empty(&cred->cr_lru)) {
 726                number_cred_unused--;
 727                list_del_init(&cred->cr_lru);
 728        }
 729        if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) {
 730                if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
 731                        cred->cr_expire = jiffies;
 732                        list_add_tail(&cred->cr_lru, &cred_unused);
 733                        number_cred_unused++;
 734                        goto out_nodestroy;
 735                }
 736                if (!rpcauth_unhash_cred(cred)) {
 737                        /* We were hashed and someone looked us up... */
 738                        goto out_nodestroy;
 739                }
 740        }
 741        spin_unlock(&rpc_credcache_lock);
 742        cred->cr_ops->crdestroy(cred);
 743        return;
 744out_nodestroy:
 745        spin_unlock(&rpc_credcache_lock);
 746}
 747EXPORT_SYMBOL_GPL(put_rpccred);
 748
 749__be32 *
 750rpcauth_marshcred(struct rpc_task *task, __be32 *p)
 751{
 752        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 753
 754        dprintk("RPC: %5u marshaling %s cred %p\n",
 755                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 756
 757        return cred->cr_ops->crmarshal(task, p);
 758}
 759
 760__be32 *
 761rpcauth_checkverf(struct rpc_task *task, __be32 *p)
 762{
 763        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 764
 765        dprintk("RPC: %5u validating %s cred %p\n",
 766                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 767
 768        return cred->cr_ops->crvalidate(task, p);
 769}
 770
 771static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
 772                                   __be32 *data, void *obj)
 773{
 774        struct xdr_stream xdr;
 775
 776        xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
 777        encode(rqstp, &xdr, obj);
 778}
 779
 780int
 781rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
 782                __be32 *data, void *obj)
 783{
 784        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 785
 786        dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
 787                        task->tk_pid, cred->cr_ops->cr_name, cred);
 788        if (cred->cr_ops->crwrap_req)
 789                return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
 790        /* By default, we encode the arguments normally. */
 791        rpcauth_wrap_req_encode(encode, rqstp, data, obj);
 792        return 0;
 793}
 794
 795static int
 796rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
 797                          __be32 *data, void *obj)
 798{
 799        struct xdr_stream xdr;
 800
 801        xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
 802        return decode(rqstp, &xdr, obj);
 803}
 804
 805int
 806rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
 807                __be32 *data, void *obj)
 808{
 809        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 810
 811        dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
 812                        task->tk_pid, cred->cr_ops->cr_name, cred);
 813        if (cred->cr_ops->crunwrap_resp)
 814                return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
 815                                                   data, obj);
 816        /* By default, we decode the arguments normally. */
 817        return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
 818}
 819
 820int
 821rpcauth_refreshcred(struct rpc_task *task)
 822{
 823        struct rpc_cred *cred;
 824        int err;
 825
 826        cred = task->tk_rqstp->rq_cred;
 827        if (cred == NULL) {
 828                err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
 829                if (err < 0)
 830                        goto out;
 831                cred = task->tk_rqstp->rq_cred;
 832        }
 833        dprintk("RPC: %5u refreshing %s cred %p\n",
 834                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 835
 836        err = cred->cr_ops->crrefresh(task);
 837out:
 838        if (err < 0)
 839                task->tk_status = err;
 840        return err;
 841}
 842
 843void
 844rpcauth_invalcred(struct rpc_task *task)
 845{
 846        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 847
 848        dprintk("RPC: %5u invalidating %s cred %p\n",
 849                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 850        if (cred)
 851                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 852}
 853
 854int
 855rpcauth_uptodatecred(struct rpc_task *task)
 856{
 857        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 858
 859        return cred == NULL ||
 860                test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
 861}
 862
 863static struct shrinker rpc_cred_shrinker = {
 864        .count_objects = rpcauth_cache_shrink_count,
 865        .scan_objects = rpcauth_cache_shrink_scan,
 866        .seeks = DEFAULT_SEEKS,
 867};
 868
 869int __init rpcauth_init_module(void)
 870{
 871        int err;
 872
 873        err = rpc_init_authunix();
 874        if (err < 0)
 875                goto out1;
 876        err = rpc_init_generic_auth();
 877        if (err < 0)
 878                goto out2;
 879        err = register_shrinker(&rpc_cred_shrinker);
 880        if (err < 0)
 881                goto out3;
 882        return 0;
 883out3:
 884        rpc_destroy_generic_auth();
 885out2:
 886        rpc_destroy_authunix();
 887out1:
 888        return err;
 889}
 890
 891void rpcauth_remove_module(void)
 892{
 893        rpc_destroy_authunix();
 894        rpc_destroy_generic_auth();
 895        unregister_shrinker(&rpc_cred_shrinker);
 896}
 897