linux/net/sunrpc/auth.c
<<
>>
Prefs
   1/*
   2 * linux/net/sunrpc/auth.c
   3 *
   4 * Generic RPC client authentication API.
   5 *
   6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
   7 */
   8
   9#include <linux/types.h>
  10#include <linux/sched.h>
  11#include <linux/module.h>
  12#include <linux/slab.h>
  13#include <linux/errno.h>
  14#include <linux/hash.h>
  15#include <linux/sunrpc/clnt.h>
  16#include <linux/sunrpc/gss_api.h>
  17#include <linux/spinlock.h>
  18
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
/* Debug facility used by the dprintk() calls in this file. */
# define RPCDBG_FACILITY        RPCDBG_AUTH
#endif

/* Default credential cache size: 1 << 4 = 16 hash buckets. */
#define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)

/*
 * Per-rpc_auth credential cache: a hash table of (1 << hashbits) chains.
 * Entries are inserted and unhashed with @lock held; lookups may walk the
 * chains under rcu_read_lock().
 */
struct rpc_cred_cache {
	struct hlist_head	*hashtable;
	unsigned int		hashbits;
	spinlock_t		lock;
};

/*
 * log2 of the bucket count used for newly created caches; settable through
 * the auth_hashtable_size module parameter.
 */
static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
  31
  32static DEFINE_SPINLOCK(rpc_authflavor_lock);
  33static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
  34        &authnull_ops,          /* AUTH_NULL */
  35        &authunix_ops,          /* AUTH_UNIX */
  36        NULL,                   /* others can be loadable modules */
  37};
  38
/*
 * LRU of hashed, up-to-date credentials whose reference count has dropped
 * to zero (appended at the tail by put_rpccred()), and its length.  Both
 * are modified with rpc_credcache_lock held.
 */
static LIST_HEAD(cred_unused);
static unsigned long number_cred_unused;
  41
  42#define MAX_HASHTABLE_BITS (14)
  43static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  44{
  45        unsigned long num;
  46        unsigned int nbits;
  47        int ret;
  48
  49        if (!val)
  50                goto out_inval;
  51        ret = kstrtoul(val, 0, &num);
  52        if (ret == -EINVAL)
  53                goto out_inval;
  54        nbits = fls(num);
  55        if (num > (1U << nbits))
  56                nbits++;
  57        if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  58                goto out_inval;
  59        *(unsigned int *)kp->arg = nbits;
  60        return 0;
  61out_inval:
  62        return -EINVAL;
  63}
  64
  65static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  66{
  67        unsigned int nbits;
  68
  69        nbits = *(unsigned int *)kp->arg;
  70        return sprintf(buffer, "%u", 1U << nbits);
  71}
  72
  73#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  74
  75static struct kernel_param_ops param_ops_hashtbl_sz = {
  76        .set = param_set_hashtbl_sz,
  77        .get = param_get_hashtbl_sz,
  78};
  79
  80module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  81MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  82
/*
 * Upper bound on number_cred_unused; rpcauth_cache_enforce_limit() prunes
 * the excess.  The default of ULONG_MAX effectively means "no limit".
 */
static unsigned long auth_max_cred_cachesize = ULONG_MAX;
module_param(auth_max_cred_cachesize, ulong, 0644);
MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
  86
  87static u32
  88pseudoflavor_to_flavor(u32 flavor) {
  89        if (flavor > RPC_AUTH_MAXFLAVOR)
  90                return RPC_AUTH_GSS;
  91        return flavor;
  92}
  93
  94int
  95rpcauth_register(const struct rpc_authops *ops)
  96{
  97        rpc_authflavor_t flavor;
  98        int ret = -EPERM;
  99
 100        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 101                return -EINVAL;
 102        spin_lock(&rpc_authflavor_lock);
 103        if (auth_flavors[flavor] == NULL) {
 104                auth_flavors[flavor] = ops;
 105                ret = 0;
 106        }
 107        spin_unlock(&rpc_authflavor_lock);
 108        return ret;
 109}
 110EXPORT_SYMBOL_GPL(rpcauth_register);
 111
 112int
 113rpcauth_unregister(const struct rpc_authops *ops)
 114{
 115        rpc_authflavor_t flavor;
 116        int ret = -EPERM;
 117
 118        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 119                return -EINVAL;
 120        spin_lock(&rpc_authflavor_lock);
 121        if (auth_flavors[flavor] == ops) {
 122                auth_flavors[flavor] = NULL;
 123                ret = 0;
 124        }
 125        spin_unlock(&rpc_authflavor_lock);
 126        return ret;
 127}
 128EXPORT_SYMBOL_GPL(rpcauth_unregister);
 129
 130/**
 131 * rpcauth_get_pseudoflavor - check if security flavor is supported
 132 * @flavor: a security flavor
 133 * @info: a GSS mech OID, quality of protection, and service value
 134 *
 135 * Verifies that an appropriate kernel module is available or already loaded.
 136 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 137 * not supported locally.
 138 */
 139rpc_authflavor_t
 140rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
 141{
 142        const struct rpc_authops *ops;
 143        rpc_authflavor_t pseudoflavor;
 144
 145        ops = auth_flavors[flavor];
 146        if (ops == NULL)
 147                request_module("rpc-auth-%u", flavor);
 148        spin_lock(&rpc_authflavor_lock);
 149        ops = auth_flavors[flavor];
 150        if (ops == NULL || !try_module_get(ops->owner)) {
 151                spin_unlock(&rpc_authflavor_lock);
 152                return RPC_AUTH_MAXFLAVOR;
 153        }
 154        spin_unlock(&rpc_authflavor_lock);
 155
 156        pseudoflavor = flavor;
 157        if (ops->info2flavor != NULL)
 158                pseudoflavor = ops->info2flavor(info);
 159
 160        module_put(ops->owner);
 161        return pseudoflavor;
 162}
 163EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
 164
 165/**
 166 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 167 * @pseudoflavor: GSS pseudoflavor to match
 168 * @info: rpcsec_gss_info structure to fill in
 169 *
 170 * Returns zero and fills in "info" if pseudoflavor matches a
 171 * supported mechanism.
 172 */
 173int
 174rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
 175{
 176        rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
 177        const struct rpc_authops *ops;
 178        int result;
 179
 180        if (flavor >= RPC_AUTH_MAXFLAVOR)
 181                return -EINVAL;
 182
 183        ops = auth_flavors[flavor];
 184        if (ops == NULL)
 185                request_module("rpc-auth-%u", flavor);
 186        spin_lock(&rpc_authflavor_lock);
 187        ops = auth_flavors[flavor];
 188        if (ops == NULL || !try_module_get(ops->owner)) {
 189                spin_unlock(&rpc_authflavor_lock);
 190                return -ENOENT;
 191        }
 192        spin_unlock(&rpc_authflavor_lock);
 193
 194        result = -ENOENT;
 195        if (ops->flavor2info != NULL)
 196                result = ops->flavor2info(pseudoflavor, info);
 197
 198        module_put(ops->owner);
 199        return result;
 200}
 201EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
 202
 203/**
 204 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 205 * @array: array to fill in
 206 * @size: size of "array"
 207 *
 208 * Returns the number of array items filled in, or a negative errno.
 209 *
 210 * The returned array is not sorted by any policy.  Callers should not
 211 * rely on the order of the items in the returned array.
 212 */
 213int
 214rpcauth_list_flavors(rpc_authflavor_t *array, int size)
 215{
 216        rpc_authflavor_t flavor;
 217        int result = 0;
 218
 219        spin_lock(&rpc_authflavor_lock);
 220        for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
 221                const struct rpc_authops *ops = auth_flavors[flavor];
 222                rpc_authflavor_t pseudos[4];
 223                int i, len;
 224
 225                if (result >= size) {
 226                        result = -ENOMEM;
 227                        break;
 228                }
 229
 230                if (ops == NULL)
 231                        continue;
 232                if (ops->list_pseudoflavors == NULL) {
 233                        array[result++] = ops->au_flavor;
 234                        continue;
 235                }
 236                len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
 237                if (len < 0) {
 238                        result = len;
 239                        break;
 240                }
 241                for (i = 0; i < len; i++) {
 242                        if (result >= size) {
 243                                result = -ENOMEM;
 244                                break;
 245                        }
 246                        array[result++] = pseudos[i];
 247                }
 248        }
 249        spin_unlock(&rpc_authflavor_lock);
 250
 251        dprintk("RPC:       %s returns %d\n", __func__, result);
 252        return result;
 253}
 254EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
 255
 256struct rpc_auth *
 257rpcauth_create(struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
 258{
 259        struct rpc_auth         *auth;
 260        const struct rpc_authops *ops;
 261        u32                     flavor = pseudoflavor_to_flavor(args->pseudoflavor);
 262
 263        auth = ERR_PTR(-EINVAL);
 264        if (flavor >= RPC_AUTH_MAXFLAVOR)
 265                goto out;
 266
 267        if ((ops = auth_flavors[flavor]) == NULL)
 268                request_module("rpc-auth-%u", flavor);
 269        spin_lock(&rpc_authflavor_lock);
 270        ops = auth_flavors[flavor];
 271        if (ops == NULL || !try_module_get(ops->owner)) {
 272                spin_unlock(&rpc_authflavor_lock);
 273                goto out;
 274        }
 275        spin_unlock(&rpc_authflavor_lock);
 276        auth = ops->create(args, clnt);
 277        module_put(ops->owner);
 278        if (IS_ERR(auth))
 279                return auth;
 280        if (clnt->cl_auth)
 281                rpcauth_release(clnt->cl_auth);
 282        clnt->cl_auth = auth;
 283
 284out:
 285        return auth;
 286}
 287EXPORT_SYMBOL_GPL(rpcauth_create);
 288
 289void
 290rpcauth_release(struct rpc_auth *auth)
 291{
 292        if (!atomic_dec_and_test(&auth->au_count))
 293                return;
 294        auth->au_ops->destroy(auth);
 295}
 296
/* Protects the cred_unused LRU and number_cred_unused. */
static DEFINE_SPINLOCK(rpc_credcache_lock);

/*
 * Remove @cred from its cache hash chain and clear RPCAUTH_CRED_HASHED.
 * Every caller in this file holds the owning rpc_cred_cache's lock.  The
 * chain entry is removed with hlist_del_rcu() because lookups may be
 * walking the chain under rcu_read_lock().
 */
static void
rpcauth_unhash_cred_locked(struct rpc_cred *cred)
{
	hlist_del_rcu(&cred->cr_hash);
	/* Order the unlink before the HASHED bit is cleared. */
	smp_mb__before_atomic();
	clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
}
 306
 307static int
 308rpcauth_unhash_cred(struct rpc_cred *cred)
 309{
 310        spinlock_t *cache_lock;
 311        int ret;
 312
 313        cache_lock = &cred->cr_auth->au_credcache->lock;
 314        spin_lock(cache_lock);
 315        ret = atomic_read(&cred->cr_count) == 0;
 316        if (ret)
 317                rpcauth_unhash_cred_locked(cred);
 318        spin_unlock(cache_lock);
 319        return ret;
 320}
 321
 322/*
 323 * Initialize RPC credential cache
 324 */
 325int
 326rpcauth_init_credcache(struct rpc_auth *auth)
 327{
 328        struct rpc_cred_cache *new;
 329        unsigned int hashsize;
 330
 331        new = kmalloc(sizeof(*new), GFP_KERNEL);
 332        if (!new)
 333                goto out_nocache;
 334        new->hashbits = auth_hashbits;
 335        hashsize = 1U << new->hashbits;
 336        new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
 337        if (!new->hashtable)
 338                goto out_nohashtbl;
 339        spin_lock_init(&new->lock);
 340        auth->au_credcache = new;
 341        return 0;
 342out_nohashtbl:
 343        kfree(new);
 344out_nocache:
 345        return -ENOMEM;
 346}
 347EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 348
 349/*
 350 * Setup a credential key lifetime timeout notification
 351 */
 352int
 353rpcauth_key_timeout_notify(struct rpc_auth *auth, struct rpc_cred *cred)
 354{
 355        if (!cred->cr_auth->au_ops->key_timeout)
 356                return 0;
 357        return cred->cr_auth->au_ops->key_timeout(auth, cred);
 358}
 359EXPORT_SYMBOL_GPL(rpcauth_key_timeout_notify);
 360
 361bool
 362rpcauth_cred_key_to_expire(struct rpc_cred *cred)
 363{
 364        if (!cred->cr_ops->crkey_to_expire)
 365                return false;
 366        return cred->cr_ops->crkey_to_expire(cred);
 367}
 368EXPORT_SYMBOL_GPL(rpcauth_cred_key_to_expire);
 369
 370char *
 371rpcauth_stringify_acceptor(struct rpc_cred *cred)
 372{
 373        if (!cred->cr_ops->crstringify_acceptor)
 374                return NULL;
 375        return cred->cr_ops->crstringify_acceptor(cred);
 376}
 377EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
 378
 379/*
 380 * Destroy a list of credentials
 381 */
 382static inline
 383void rpcauth_destroy_credlist(struct list_head *head)
 384{
 385        struct rpc_cred *cred;
 386
 387        while (!list_empty(head)) {
 388                cred = list_entry(head->next, struct rpc_cred, cr_lru);
 389                list_del_init(&cred->cr_lru);
 390                put_rpccred(cred);
 391        }
 392}
 393
 394/*
 395 * Clear the RPC credential cache, and delete those credentials
 396 * that are not referenced.
 397 */
 398void
 399rpcauth_clear_credcache(struct rpc_cred_cache *cache)
 400{
 401        LIST_HEAD(free);
 402        struct hlist_head *head;
 403        struct rpc_cred *cred;
 404        unsigned int hashsize = 1U << cache->hashbits;
 405        int             i;
 406
 407        spin_lock(&rpc_credcache_lock);
 408        spin_lock(&cache->lock);
 409        for (i = 0; i < hashsize; i++) {
 410                head = &cache->hashtable[i];
 411                while (!hlist_empty(head)) {
 412                        cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
 413                        get_rpccred(cred);
 414                        if (!list_empty(&cred->cr_lru)) {
 415                                list_del(&cred->cr_lru);
 416                                number_cred_unused--;
 417                        }
 418                        list_add_tail(&cred->cr_lru, &free);
 419                        rpcauth_unhash_cred_locked(cred);
 420                }
 421        }
 422        spin_unlock(&cache->lock);
 423        spin_unlock(&rpc_credcache_lock);
 424        rpcauth_destroy_credlist(&free);
 425}
 426
 427/*
 428 * Destroy the RPC credential cache
 429 */
 430void
 431rpcauth_destroy_credcache(struct rpc_auth *auth)
 432{
 433        struct rpc_cred_cache *cache = auth->au_credcache;
 434
 435        if (cache) {
 436                auth->au_credcache = NULL;
 437                rpcauth_clear_credcache(cache);
 438                kfree(cache->hashtable);
 439                kfree(cache);
 440        }
 441}
 442EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
 443
 444
/* Minimum time (60 s) an unused, hashed credential stays on the LRU. */
#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)

/*
 * Remove stale credentials. Avoid sleeping inside the loop.
 *
 * Called with rpc_credcache_lock held (see rpcauth_cache_do_shrink()).
 * Examines at most @nr_to_scan entries from the head of cred_unused and
 * takes each one off the LRU.  Entries that are still unreferenced are
 * unhashed, given an extra reference and queued on @free so the caller
 * can release them after dropping the lock.  Returns the number of entries
 * removed from the LRU.
 */
static long
rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
{
	spinlock_t *cache_lock;
	struct rpc_cred *cred, *next;
	unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
	long freed = 0;

	list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {

		if (nr_to_scan-- == 0)
			break;
		/*
		 * Enforce a 60 second garbage collection moratorium
		 * Note that the cred_unused list must be time-ordered.
		 * Stop at the first hashed entry released within the window;
		 * everything after it on the list is newer still.
		 */
		if (time_in_range(cred->cr_expire, expired, jiffies) &&
		    test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0)
			break;

		list_del_init(&cred->cr_lru);
		number_cred_unused--;
		freed++;
		/* Still in use: dropping it from the LRU is all we do. */
		if (atomic_read(&cred->cr_count) != 0)
			continue;

		cache_lock = &cred->cr_auth->au_credcache->lock;
		spin_lock(cache_lock);
		/*
		 * Re-check under the cache lock: a lookup takes its reference
		 * under this lock and may have done so since the test above.
		 */
		if (atomic_read(&cred->cr_count) == 0) {
			/* This reference is dropped by rpcauth_destroy_credlist(). */
			get_rpccred(cred);
			list_add_tail(&cred->cr_lru, free);
			rpcauth_unhash_cred_locked(cred);
		}
		spin_unlock(cache_lock);
	}
	return freed;
}
 487
 488static unsigned long
 489rpcauth_cache_do_shrink(int nr_to_scan)
 490{
 491        LIST_HEAD(free);
 492        unsigned long freed;
 493
 494        spin_lock(&rpc_credcache_lock);
 495        freed = rpcauth_prune_expired(&free, nr_to_scan);
 496        spin_unlock(&rpc_credcache_lock);
 497        rpcauth_destroy_credlist(&free);
 498
 499        return freed;
 500}
 501
 502/*
 503 * Run memory cache shrinker.
 504 */
 505static unsigned long
 506rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
 507
 508{
 509        if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
 510                return SHRINK_STOP;
 511
 512        /* nothing left, don't come back */
 513        if (list_empty(&cred_unused))
 514                return SHRINK_STOP;
 515
 516        return rpcauth_cache_do_shrink(sc->nr_to_scan);
 517}
 518
 519static unsigned long
 520rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
 521
 522{
 523        return (number_cred_unused / 100) * sysctl_vfs_cache_pressure;
 524}
 525
 526static void
 527rpcauth_cache_enforce_limit(void)
 528{
 529        unsigned long diff;
 530        unsigned int nr_to_scan;
 531
 532        if (number_cred_unused <= auth_max_cred_cachesize)
 533                return;
 534        diff = number_cred_unused - auth_max_cred_cachesize;
 535        nr_to_scan = 100;
 536        if (diff < nr_to_scan)
 537                nr_to_scan = diff;
 538        rpcauth_cache_do_shrink(nr_to_scan);
 539}
 540
/*
 * Look up a process' credentials in the authentication cache
 *
 * Searches @auth's cache for an entry that @acred matches, according to
 * each entry's crmatch method.
 *
 * With RPCAUTH_LOOKUP_RCU set, the first match is returned without taking
 * a reference, provided it is hashed and not flagged RPCAUTH_CRED_NEW.
 * Otherwise ERR_PTR(-ECHILD) is returned; nothing is ever created in this
 * mode.
 *
 * Without it, a miss creates a credential via au_ops->crcreate and hashes
 * it, unless a racing thread inserted a match first.  A credential still
 * flagged RPCAUTH_CRED_NEW is set up via cr_init unless RPCAUTH_LOOKUP_NEW
 * is passed.  Returns the credential, or an ERR_PTR on failure.
 */
struct rpc_cred *
rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
		int flags)
{
	LIST_HEAD(free);
	struct rpc_cred_cache *cache = auth->au_credcache;
	struct rpc_cred *cred = NULL,
			*entry, *new;
	unsigned int nr;

	/* Bucket is chosen from the uid as seen in init_user_ns. */
	nr = hash_long(from_kuid(&init_user_ns, acred->uid), cache->hashbits);

	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		if (flags & RPCAUTH_LOOKUP_RCU) {
			/* Lockless mode: no lock, no reference; first match only. */
			if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) &&
			    !test_bit(RPCAUTH_CRED_NEW, &entry->cr_flags))
				cred = entry;
			break;
		}
		/*
		 * Unhashing happens under cache->lock, so re-check HASHED
		 * under it before taking a reference.
		 */
		spin_lock(&cache->lock);
		if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) {
			spin_unlock(&cache->lock);
			continue;
		}
		cred = get_rpccred(entry);
		spin_unlock(&cache->lock);
		break;
	}
	rcu_read_unlock();

	if (cred != NULL)
		goto found;

	/* Lockless lookups never create credentials. */
	if (flags & RPCAUTH_LOOKUP_RCU)
		return ERR_PTR(-ECHILD);

	new = auth->au_ops->crcreate(auth, acred, flags);
	if (IS_ERR(new)) {
		cred = new;
		goto out;
	}

	/*
	 * Search again under the lock: another thread may have inserted a
	 * matching credential while crcreate() ran.  If so, use theirs and
	 * queue ours on @free for release after the lock is dropped;
	 * otherwise hash ours.
	 */
	spin_lock(&cache->lock);
	hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		cred = get_rpccred(entry);
		break;
	}
	if (cred == NULL) {
		cred = new;
		set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
		hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
	} else
		list_add_tail(&new->cr_lru, &free);
	spin_unlock(&cache->lock);
	/* Keep the unused LRU within auth_max_cred_cachesize. */
	rpcauth_cache_enforce_limit();
found:
	/* Finish setting up a still-NEW credential unless the caller opted out. */
	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
	    cred->cr_ops->cr_init != NULL &&
	    !(flags & RPCAUTH_LOOKUP_NEW)) {
		int res = cred->cr_ops->cr_init(auth, cred);
		if (res < 0) {
			put_rpccred(cred);
			cred = ERR_PTR(res);
		}
	}
	rpcauth_destroy_credlist(&free);
out:
	return cred;
}
EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
 619
 620struct rpc_cred *
 621rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 622{
 623        struct auth_cred acred;
 624        struct rpc_cred *ret;
 625        const struct cred *cred = current_cred();
 626
 627        dprintk("RPC:       looking up %s cred\n",
 628                auth->au_ops->au_name);
 629
 630        memset(&acred, 0, sizeof(acred));
 631        acred.uid = cred->fsuid;
 632        acred.gid = cred->fsgid;
 633        acred.group_info = cred->group_info;
 634        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
 635        return ret;
 636}
 637EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
 638
 639void
 640rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 641                  struct rpc_auth *auth, const struct rpc_credops *ops)
 642{
 643        INIT_HLIST_NODE(&cred->cr_hash);
 644        INIT_LIST_HEAD(&cred->cr_lru);
 645        atomic_set(&cred->cr_count, 1);
 646        cred->cr_auth = auth;
 647        cred->cr_ops = ops;
 648        cred->cr_expire = jiffies;
 649#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
 650        cred->cr_magic = RPCAUTH_CRED_MAGIC;
 651#endif
 652        cred->cr_uid = acred->uid;
 653}
 654EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 655
 656struct rpc_cred *
 657rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred, int lookupflags)
 658{
 659        dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid,
 660                        cred->cr_auth->au_ops->au_name, cred);
 661        return get_rpccred(cred);
 662}
 663EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred);
 664
 665static struct rpc_cred *
 666rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 667{
 668        struct rpc_auth *auth = task->tk_client->cl_auth;
 669        struct auth_cred acred = {
 670                .uid = GLOBAL_ROOT_UID,
 671                .gid = GLOBAL_ROOT_GID,
 672        };
 673
 674        dprintk("RPC: %5u looking up %s cred\n",
 675                task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
 676        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 677}
 678
 679static struct rpc_cred *
 680rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 681{
 682        struct rpc_auth *auth = task->tk_client->cl_auth;
 683
 684        dprintk("RPC: %5u looking up %s cred\n",
 685                task->tk_pid, auth->au_ops->au_name);
 686        return rpcauth_lookupcred(auth, lookupflags);
 687}
 688
 689static int
 690rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags)
 691{
 692        struct rpc_rqst *req = task->tk_rqstp;
 693        struct rpc_cred *new;
 694        int lookupflags = 0;
 695
 696        if (flags & RPC_TASK_ASYNC)
 697                lookupflags |= RPCAUTH_LOOKUP_NEW;
 698        if (cred != NULL)
 699                new = cred->cr_ops->crbind(task, cred, lookupflags);
 700        else if (flags & RPC_TASK_ROOTCREDS)
 701                new = rpcauth_bind_root_cred(task, lookupflags);
 702        else
 703                new = rpcauth_bind_new_cred(task, lookupflags);
 704        if (IS_ERR(new))
 705                return PTR_ERR(new);
 706        if (req->rq_cred != NULL)
 707                put_rpccred(req->rq_cred);
 708        req->rq_cred = new;
 709        return 0;
 710}
 711
/*
 * Drop a reference to @cred.
 *
 * An unhashed credential is destroyed when its last reference goes away.
 * For a hashed credential, the final put runs with rpc_credcache_lock held.
 * If the credential is still RPCAUTH_CRED_UPTODATE, it stays cached and is
 * moved to the tail of cred_unused with a fresh cr_expire stamp.
 * Otherwise it is unhashed and destroyed, unless a concurrent lookup has
 * taken a new reference in the meantime.
 */
void
put_rpccred(struct rpc_cred *cred)
{
	/* Fast path for unhashed credentials */
	if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) {
		if (atomic_dec_and_test(&cred->cr_count))
			cred->cr_ops->crdestroy(cred);
		return;
	}

	/* Only the put that reaches zero continues, holding the lock. */
	if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
		return;
	/* Take it off the LRU; it is re-added below if it stays cached. */
	if (!list_empty(&cred->cr_lru)) {
		number_cred_unused--;
		list_del_init(&cred->cr_lru);
	}
	/* Re-test: the credential may have been unhashed since the first check. */
	if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) {
		if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
			/* Keep it cached as the newest unused entry. */
			cred->cr_expire = jiffies;
			list_add_tail(&cred->cr_lru, &cred_unused);
			number_cred_unused++;
			goto out_nodestroy;
		}
		if (!rpcauth_unhash_cred(cred)) {
			/* We were hashed and someone looked us up... */
			goto out_nodestroy;
		}
	}
	spin_unlock(&rpc_credcache_lock);
	cred->cr_ops->crdestroy(cred);
	return;
out_nodestroy:
	spin_unlock(&rpc_credcache_lock);
}
EXPORT_SYMBOL_GPL(put_rpccred);
 747
 748__be32 *
 749rpcauth_marshcred(struct rpc_task *task, __be32 *p)
 750{
 751        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 752
 753        dprintk("RPC: %5u marshaling %s cred %p\n",
 754                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 755
 756        return cred->cr_ops->crmarshal(task, p);
 757}
 758
 759__be32 *
 760rpcauth_checkverf(struct rpc_task *task, __be32 *p)
 761{
 762        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 763
 764        dprintk("RPC: %5u validating %s cred %p\n",
 765                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 766
 767        return cred->cr_ops->crvalidate(task, p);
 768}
 769
 770static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
 771                                   __be32 *data, void *obj)
 772{
 773        struct xdr_stream xdr;
 774
 775        xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
 776        encode(rqstp, &xdr, obj);
 777}
 778
 779int
 780rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
 781                __be32 *data, void *obj)
 782{
 783        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 784
 785        dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
 786                        task->tk_pid, cred->cr_ops->cr_name, cred);
 787        if (cred->cr_ops->crwrap_req)
 788                return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
 789        /* By default, we encode the arguments normally. */
 790        rpcauth_wrap_req_encode(encode, rqstp, data, obj);
 791        return 0;
 792}
 793
 794static int
 795rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
 796                          __be32 *data, void *obj)
 797{
 798        struct xdr_stream xdr;
 799
 800        xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
 801        return decode(rqstp, &xdr, obj);
 802}
 803
 804int
 805rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
 806                __be32 *data, void *obj)
 807{
 808        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 809
 810        dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
 811                        task->tk_pid, cred->cr_ops->cr_name, cred);
 812        if (cred->cr_ops->crunwrap_resp)
 813                return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
 814                                                   data, obj);
 815        /* By default, we decode the arguments normally. */
 816        return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
 817}
 818
 819int
 820rpcauth_refreshcred(struct rpc_task *task)
 821{
 822        struct rpc_cred *cred;
 823        int err;
 824
 825        cred = task->tk_rqstp->rq_cred;
 826        if (cred == NULL) {
 827                err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
 828                if (err < 0)
 829                        goto out;
 830                cred = task->tk_rqstp->rq_cred;
 831        }
 832        dprintk("RPC: %5u refreshing %s cred %p\n",
 833                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 834
 835        err = cred->cr_ops->crrefresh(task);
 836out:
 837        if (err < 0)
 838                task->tk_status = err;
 839        return err;
 840}
 841
 842void
 843rpcauth_invalcred(struct rpc_task *task)
 844{
 845        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 846
 847        dprintk("RPC: %5u invalidating %s cred %p\n",
 848                task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
 849        if (cred)
 850                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 851}
 852
 853int
 854rpcauth_uptodatecred(struct rpc_task *task)
 855{
 856        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 857
 858        return cred == NULL ||
 859                test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
 860}
 861
/*
 * Memory-pressure hook for the unused-credential LRU: count_objects reports
 * how many entries could be reclaimed, scan_objects prunes them.
 */
static struct shrinker rpc_cred_shrinker = {
	.count_objects = rpcauth_cache_shrink_count,
	.scan_objects = rpcauth_cache_shrink_scan,
	.seeks = DEFAULT_SEEKS,
};
 867
 868int __init rpcauth_init_module(void)
 869{
 870        int err;
 871
 872        err = rpc_init_authunix();
 873        if (err < 0)
 874                goto out1;
 875        err = rpc_init_generic_auth();
 876        if (err < 0)
 877                goto out2;
 878        register_shrinker(&rpc_cred_shrinker);
 879        return 0;
 880out2:
 881        rpc_destroy_authunix();
 882out1:
 883        return err;
 884}
 885
 886void rpcauth_remove_module(void)
 887{
 888        rpc_destroy_authunix();
 889        rpc_destroy_generic_auth();
 890        unregister_shrinker(&rpc_cred_shrinker);
 891}
 892