linux/net/sunrpc/auth.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * linux/net/sunrpc/auth.c
   4 *
   5 * Generic RPC client authentication API.
   6 *
   7 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
   8 */
   9
  10#include <linux/types.h>
  11#include <linux/sched.h>
  12#include <linux/cred.h>
  13#include <linux/module.h>
  14#include <linux/slab.h>
  15#include <linux/errno.h>
  16#include <linux/hash.h>
  17#include <linux/sunrpc/clnt.h>
  18#include <linux/sunrpc/gss_api.h>
  19#include <linux/spinlock.h>
  20
  21#include <trace/events/sunrpc.h>
  22
/* Default number of hash bits for a credential cache (1 << 4 buckets). */
#define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)

/*
 * Per-rpc_auth credential cache: hash chains of rpc_creds linked through
 * cr_hash.  Insertions and removals are done under @lock; lookups may
 * walk the chains under RCU.
 */
struct rpc_cred_cache {
	struct hlist_head	*hashtable;	/* 1U << hashbits chains */
	unsigned int		hashbits;
	spinlock_t		lock;
};

/* Hash bits used for newly created caches; set via "auth_hashtable_size". */
static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;

/*
 * Registered authentication flavors, indexed by flavor number.  Slots are
 * claimed/cleared with cmpxchg() and read with rcu_dereference().
 */
static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
	[RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
	[RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
	NULL,			/* others can be loadable modules */
};

/*
 * Global LRU of hashed creds whose only remaining reference is the hash
 * table's (see put_rpccred()).  Protected by rpc_credcache_lock; entries
 * are appended at the tail.
 */
static LIST_HEAD(cred_unused);
static unsigned long number_cred_unused;	/* number of entries on cred_unused */

/*
 * Sentinel credential handed out by rpc_machine_cred(); rpcauth_bindcred()
 * recognises it by address and binds the client's machine credential.
 */
static struct cred machine_cred = {
	.usage = ATOMIC_INIT(1),
#ifdef CONFIG_DEBUG_CREDENTIALS
	.magic = CRED_MAGIC,
#endif
};
  47
  48/*
  49 * Return the machine_cred pointer to be used whenever
  50 * the a generic machine credential is needed.
  51 */
  52const struct cred *rpc_machine_cred(void)
  53{
  54        return &machine_cred;
  55}
  56EXPORT_SYMBOL_GPL(rpc_machine_cred);
  57
  58#define MAX_HASHTABLE_BITS (14)
  59static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  60{
  61        unsigned long num;
  62        unsigned int nbits;
  63        int ret;
  64
  65        if (!val)
  66                goto out_inval;
  67        ret = kstrtoul(val, 0, &num);
  68        if (ret)
  69                goto out_inval;
  70        nbits = fls(num - 1);
  71        if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  72                goto out_inval;
  73        *(unsigned int *)kp->arg = nbits;
  74        return 0;
  75out_inval:
  76        return -EINVAL;
  77}
  78
  79static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  80{
  81        unsigned int nbits;
  82
  83        nbits = *(unsigned int *)kp->arg;
  84        return sprintf(buffer, "%u", 1U << nbits);
  85}
  86
  87#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  88
  89static const struct kernel_param_ops param_ops_hashtbl_sz = {
  90        .set = param_set_hashtbl_sz,
  91        .get = param_get_hashtbl_sz,
  92};
  93
  94module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  95MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  96
  97static unsigned long auth_max_cred_cachesize = ULONG_MAX;
  98module_param(auth_max_cred_cachesize, ulong, 0644);
  99MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
 100
 101static u32
 102pseudoflavor_to_flavor(u32 flavor) {
 103        if (flavor > RPC_AUTH_MAXFLAVOR)
 104                return RPC_AUTH_GSS;
 105        return flavor;
 106}
 107
 108int
 109rpcauth_register(const struct rpc_authops *ops)
 110{
 111        const struct rpc_authops *old;
 112        rpc_authflavor_t flavor;
 113
 114        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 115                return -EINVAL;
 116        old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
 117        if (old == NULL || old == ops)
 118                return 0;
 119        return -EPERM;
 120}
 121EXPORT_SYMBOL_GPL(rpcauth_register);
 122
 123int
 124rpcauth_unregister(const struct rpc_authops *ops)
 125{
 126        const struct rpc_authops *old;
 127        rpc_authflavor_t flavor;
 128
 129        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 130                return -EINVAL;
 131
 132        old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
 133        if (old == ops || old == NULL)
 134                return 0;
 135        return -EPERM;
 136}
 137EXPORT_SYMBOL_GPL(rpcauth_unregister);
 138
 139static const struct rpc_authops *
 140rpcauth_get_authops(rpc_authflavor_t flavor)
 141{
 142        const struct rpc_authops *ops;
 143
 144        if (flavor >= RPC_AUTH_MAXFLAVOR)
 145                return NULL;
 146
 147        rcu_read_lock();
 148        ops = rcu_dereference(auth_flavors[flavor]);
 149        if (ops == NULL) {
 150                rcu_read_unlock();
 151                request_module("rpc-auth-%u", flavor);
 152                rcu_read_lock();
 153                ops = rcu_dereference(auth_flavors[flavor]);
 154                if (ops == NULL)
 155                        goto out;
 156        }
 157        if (!try_module_get(ops->owner))
 158                ops = NULL;
 159out:
 160        rcu_read_unlock();
 161        return ops;
 162}
 163
 164static void
 165rpcauth_put_authops(const struct rpc_authops *ops)
 166{
 167        module_put(ops->owner);
 168}
 169
 170/**
 171 * rpcauth_get_pseudoflavor - check if security flavor is supported
 172 * @flavor: a security flavor
 173 * @info: a GSS mech OID, quality of protection, and service value
 174 *
 175 * Verifies that an appropriate kernel module is available or already loaded.
 176 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 177 * not supported locally.
 178 */
 179rpc_authflavor_t
 180rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
 181{
 182        const struct rpc_authops *ops = rpcauth_get_authops(flavor);
 183        rpc_authflavor_t pseudoflavor;
 184
 185        if (!ops)
 186                return RPC_AUTH_MAXFLAVOR;
 187        pseudoflavor = flavor;
 188        if (ops->info2flavor != NULL)
 189                pseudoflavor = ops->info2flavor(info);
 190
 191        rpcauth_put_authops(ops);
 192        return pseudoflavor;
 193}
 194EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
 195
 196/**
 197 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 198 * @pseudoflavor: GSS pseudoflavor to match
 199 * @info: rpcsec_gss_info structure to fill in
 200 *
 201 * Returns zero and fills in "info" if pseudoflavor matches a
 202 * supported mechanism.
 203 */
 204int
 205rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
 206{
 207        rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
 208        const struct rpc_authops *ops;
 209        int result;
 210
 211        ops = rpcauth_get_authops(flavor);
 212        if (ops == NULL)
 213                return -ENOENT;
 214
 215        result = -ENOENT;
 216        if (ops->flavor2info != NULL)
 217                result = ops->flavor2info(pseudoflavor, info);
 218
 219        rpcauth_put_authops(ops);
 220        return result;
 221}
 222EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
 223
 224/**
 225 * rpcauth_list_flavors - discover registered flavors and pseudoflavors
 226 * @array: array to fill in
 227 * @size: size of "array"
 228 *
 229 * Returns the number of array items filled in, or a negative errno.
 230 *
 231 * The returned array is not sorted by any policy.  Callers should not
 232 * rely on the order of the items in the returned array.
 233 */
 234int
 235rpcauth_list_flavors(rpc_authflavor_t *array, int size)
 236{
 237        const struct rpc_authops *ops;
 238        rpc_authflavor_t flavor, pseudos[4];
 239        int i, len, result = 0;
 240
 241        rcu_read_lock();
 242        for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
 243                ops = rcu_dereference(auth_flavors[flavor]);
 244                if (result >= size) {
 245                        result = -ENOMEM;
 246                        break;
 247                }
 248
 249                if (ops == NULL)
 250                        continue;
 251                if (ops->list_pseudoflavors == NULL) {
 252                        array[result++] = ops->au_flavor;
 253                        continue;
 254                }
 255                len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
 256                if (len < 0) {
 257                        result = len;
 258                        break;
 259                }
 260                for (i = 0; i < len; i++) {
 261                        if (result >= size) {
 262                                result = -ENOMEM;
 263                                break;
 264                        }
 265                        array[result++] = pseudos[i];
 266                }
 267        }
 268        rcu_read_unlock();
 269        return result;
 270}
 271EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
 272
 273struct rpc_auth *
 274rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
 275{
 276        struct rpc_auth *auth = ERR_PTR(-EINVAL);
 277        const struct rpc_authops *ops;
 278        u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
 279
 280        ops = rpcauth_get_authops(flavor);
 281        if (ops == NULL)
 282                goto out;
 283
 284        auth = ops->create(args, clnt);
 285
 286        rpcauth_put_authops(ops);
 287        if (IS_ERR(auth))
 288                return auth;
 289        if (clnt->cl_auth)
 290                rpcauth_release(clnt->cl_auth);
 291        clnt->cl_auth = auth;
 292
 293out:
 294        return auth;
 295}
 296EXPORT_SYMBOL_GPL(rpcauth_create);
 297
 298void
 299rpcauth_release(struct rpc_auth *auth)
 300{
 301        if (!refcount_dec_and_test(&auth->au_count))
 302                return;
 303        auth->au_ops->destroy(auth);
 304}
 305
 306static DEFINE_SPINLOCK(rpc_credcache_lock);
 307
 308/*
 309 * On success, the caller is responsible for freeing the reference
 310 * held by the hashtable
 311 */
 312static bool
 313rpcauth_unhash_cred_locked(struct rpc_cred *cred)
 314{
 315        if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
 316                return false;
 317        hlist_del_rcu(&cred->cr_hash);
 318        return true;
 319}
 320
 321static bool
 322rpcauth_unhash_cred(struct rpc_cred *cred)
 323{
 324        spinlock_t *cache_lock;
 325        bool ret;
 326
 327        if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
 328                return false;
 329        cache_lock = &cred->cr_auth->au_credcache->lock;
 330        spin_lock(cache_lock);
 331        ret = rpcauth_unhash_cred_locked(cred);
 332        spin_unlock(cache_lock);
 333        return ret;
 334}
 335
 336/*
 337 * Initialize RPC credential cache
 338 */
 339int
 340rpcauth_init_credcache(struct rpc_auth *auth)
 341{
 342        struct rpc_cred_cache *new;
 343        unsigned int hashsize;
 344
 345        new = kmalloc(sizeof(*new), GFP_KERNEL);
 346        if (!new)
 347                goto out_nocache;
 348        new->hashbits = auth_hashbits;
 349        hashsize = 1U << new->hashbits;
 350        new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
 351        if (!new->hashtable)
 352                goto out_nohashtbl;
 353        spin_lock_init(&new->lock);
 354        auth->au_credcache = new;
 355        return 0;
 356out_nohashtbl:
 357        kfree(new);
 358out_nocache:
 359        return -ENOMEM;
 360}
 361EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 362
 363char *
 364rpcauth_stringify_acceptor(struct rpc_cred *cred)
 365{
 366        if (!cred->cr_ops->crstringify_acceptor)
 367                return NULL;
 368        return cred->cr_ops->crstringify_acceptor(cred);
 369}
 370EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
 371
 372/*
 373 * Destroy a list of credentials
 374 */
 375static inline
 376void rpcauth_destroy_credlist(struct list_head *head)
 377{
 378        struct rpc_cred *cred;
 379
 380        while (!list_empty(head)) {
 381                cred = list_entry(head->next, struct rpc_cred, cr_lru);
 382                list_del_init(&cred->cr_lru);
 383                put_rpccred(cred);
 384        }
 385}
 386
 387static void
 388rpcauth_lru_add_locked(struct rpc_cred *cred)
 389{
 390        if (!list_empty(&cred->cr_lru))
 391                return;
 392        number_cred_unused++;
 393        list_add_tail(&cred->cr_lru, &cred_unused);
 394}
 395
 396static void
 397rpcauth_lru_add(struct rpc_cred *cred)
 398{
 399        if (!list_empty(&cred->cr_lru))
 400                return;
 401        spin_lock(&rpc_credcache_lock);
 402        rpcauth_lru_add_locked(cred);
 403        spin_unlock(&rpc_credcache_lock);
 404}
 405
 406static void
 407rpcauth_lru_remove_locked(struct rpc_cred *cred)
 408{
 409        if (list_empty(&cred->cr_lru))
 410                return;
 411        number_cred_unused--;
 412        list_del_init(&cred->cr_lru);
 413}
 414
 415static void
 416rpcauth_lru_remove(struct rpc_cred *cred)
 417{
 418        if (list_empty(&cred->cr_lru))
 419                return;
 420        spin_lock(&rpc_credcache_lock);
 421        rpcauth_lru_remove_locked(cred);
 422        spin_unlock(&rpc_credcache_lock);
 423}
 424
 425/*
 426 * Clear the RPC credential cache, and delete those credentials
 427 * that are not referenced.
 428 */
 429void
 430rpcauth_clear_credcache(struct rpc_cred_cache *cache)
 431{
 432        LIST_HEAD(free);
 433        struct hlist_head *head;
 434        struct rpc_cred *cred;
 435        unsigned int hashsize = 1U << cache->hashbits;
 436        int             i;
 437
 438        spin_lock(&rpc_credcache_lock);
 439        spin_lock(&cache->lock);
 440        for (i = 0; i < hashsize; i++) {
 441                head = &cache->hashtable[i];
 442                while (!hlist_empty(head)) {
 443                        cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
 444                        rpcauth_unhash_cred_locked(cred);
 445                        /* Note: We now hold a reference to cred */
 446                        rpcauth_lru_remove_locked(cred);
 447                        list_add_tail(&cred->cr_lru, &free);
 448                }
 449        }
 450        spin_unlock(&cache->lock);
 451        spin_unlock(&rpc_credcache_lock);
 452        rpcauth_destroy_credlist(&free);
 453}
 454
 455/*
 456 * Destroy the RPC credential cache
 457 */
 458void
 459rpcauth_destroy_credcache(struct rpc_auth *auth)
 460{
 461        struct rpc_cred_cache *cache = auth->au_credcache;
 462
 463        if (cache) {
 464                auth->au_credcache = NULL;
 465                rpcauth_clear_credcache(cache);
 466                kfree(cache->hashtable);
 467                kfree(cache);
 468        }
 469}
 470EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
 471
 472
 473#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
 474
 475/*
 476 * Remove stale credentials. Avoid sleeping inside the loop.
 477 */
 478static long
 479rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
 480{
 481        struct rpc_cred *cred, *next;
 482        unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
 483        long freed = 0;
 484
 485        list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
 486
 487                if (nr_to_scan-- == 0)
 488                        break;
 489                if (refcount_read(&cred->cr_count) > 1) {
 490                        rpcauth_lru_remove_locked(cred);
 491                        continue;
 492                }
 493                /*
 494                 * Enforce a 60 second garbage collection moratorium
 495                 * Note that the cred_unused list must be time-ordered.
 496                 */
 497                if (!time_in_range(cred->cr_expire, expired, jiffies))
 498                        continue;
 499                if (!rpcauth_unhash_cred(cred))
 500                        continue;
 501
 502                rpcauth_lru_remove_locked(cred);
 503                freed++;
 504                list_add_tail(&cred->cr_lru, free);
 505        }
 506        return freed ? freed : SHRINK_STOP;
 507}
 508
 509static unsigned long
 510rpcauth_cache_do_shrink(int nr_to_scan)
 511{
 512        LIST_HEAD(free);
 513        unsigned long freed;
 514
 515        spin_lock(&rpc_credcache_lock);
 516        freed = rpcauth_prune_expired(&free, nr_to_scan);
 517        spin_unlock(&rpc_credcache_lock);
 518        rpcauth_destroy_credlist(&free);
 519
 520        return freed;
 521}
 522
 523/*
 524 * Run memory cache shrinker.
 525 */
 526static unsigned long
 527rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
 528
 529{
 530        if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
 531                return SHRINK_STOP;
 532
 533        /* nothing left, don't come back */
 534        if (list_empty(&cred_unused))
 535                return SHRINK_STOP;
 536
 537        return rpcauth_cache_do_shrink(sc->nr_to_scan);
 538}
 539
 540static unsigned long
 541rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
 542
 543{
 544        return number_cred_unused * sysctl_vfs_cache_pressure / 100;
 545}
 546
 547static void
 548rpcauth_cache_enforce_limit(void)
 549{
 550        unsigned long diff;
 551        unsigned int nr_to_scan;
 552
 553        if (number_cred_unused <= auth_max_cred_cachesize)
 554                return;
 555        diff = number_cred_unused - auth_max_cred_cachesize;
 556        nr_to_scan = 100;
 557        if (diff < nr_to_scan)
 558                nr_to_scan = diff;
 559        rpcauth_cache_do_shrink(nr_to_scan);
 560}
 561
/*
 * Look up a process' credentials in the authentication cache
 *
 * @auth: authentication handle whose cache is searched
 * @acred: credential to match via cr_ops->crmatch()
 * @flags: passed to crmatch()/crcreate(); RPCAUTH_LOOKUP_NEW also
 *         suppresses the cr_init() call at the end
 * @gfp: allocation flags passed to au_ops->crcreate()
 *
 * Returns a referenced rpc_cred, or an ERR_PTR() from crcreate()/cr_init().
 */
struct rpc_cred *
rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
		int flags, gfp_t gfp)
{
	LIST_HEAD(free);
	struct rpc_cred_cache *cache = auth->au_credcache;
	struct rpc_cred *cred = NULL,
			*entry, *new;
	unsigned int nr;

	nr = auth->au_ops->hash_cred(acred, cache->hashbits);

	/* Fast path: lockless RCU walk of the hash chain. */
	rcu_read_lock();
	hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		/* get_rpccred() can return NULL; keep looking in that case. */
		cred = get_rpccred(entry);
		if (cred)
			break;
	}
	rcu_read_unlock();

	if (cred != NULL)
		goto found;

	/* Slow path: create a cred outside the lock, then re-check under it. */
	new = auth->au_ops->crcreate(auth, acred, flags, gfp);
	if (IS_ERR(new)) {
		cred = new;
		goto out;
	}

	spin_lock(&cache->lock);
	hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
		if (!entry->cr_ops->crmatch(acred, entry, flags))
			continue;
		cred = get_rpccred(entry);
		if (cred)
			break;
	}
	if (cred == NULL) {
		/* Publish @new; the extra reference is the one the hash table holds. */
		cred = new;
		set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
		refcount_inc(&cred->cr_count);
		hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
	} else
		/* A matching entry appeared meanwhile: release @new after unlocking. */
		list_add_tail(&new->cr_lru, &free);
	spin_unlock(&cache->lock);
	rpcauth_cache_enforce_limit();
found:
	/*
	 * Creds still flagged RPCAUTH_CRED_NEW get cr_init() here, unless the
	 * caller passed RPCAUTH_LOOKUP_NEW.  On failure our reference is dropped.
	 */
	if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
	    cred->cr_ops->cr_init != NULL &&
	    !(flags & RPCAUTH_LOOKUP_NEW)) {
		int res = cred->cr_ops->cr_init(auth, cred);
		if (res < 0) {
			put_rpccred(cred);
			cred = ERR_PTR(res);
		}
	}
	rpcauth_destroy_credlist(&free);
out:
	return cred;
}
EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
 628
 629struct rpc_cred *
 630rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 631{
 632        struct auth_cred acred;
 633        struct rpc_cred *ret;
 634        const struct cred *cred = current_cred();
 635
 636        memset(&acred, 0, sizeof(acred));
 637        acred.cred = cred;
 638        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
 639        return ret;
 640}
 641EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
 642
 643void
 644rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 645                  struct rpc_auth *auth, const struct rpc_credops *ops)
 646{
 647        INIT_HLIST_NODE(&cred->cr_hash);
 648        INIT_LIST_HEAD(&cred->cr_lru);
 649        refcount_set(&cred->cr_count, 1);
 650        cred->cr_auth = auth;
 651        cred->cr_flags = 0;
 652        cred->cr_ops = ops;
 653        cred->cr_expire = jiffies;
 654        cred->cr_cred = get_cred(acred->cred);
 655}
 656EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 657
 658static struct rpc_cred *
 659rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 660{
 661        struct rpc_auth *auth = task->tk_client->cl_auth;
 662        struct auth_cred acred = {
 663                .cred = get_task_cred(&init_task),
 664        };
 665        struct rpc_cred *ret;
 666
 667        ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 668        put_cred(acred.cred);
 669        return ret;
 670}
 671
 672static struct rpc_cred *
 673rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags)
 674{
 675        struct rpc_auth *auth = task->tk_client->cl_auth;
 676        struct auth_cred acred = {
 677                .principal = task->tk_client->cl_principal,
 678                .cred = init_task.cred,
 679        };
 680
 681        if (!acred.principal)
 682                return NULL;
 683        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 684}
 685
 686static struct rpc_cred *
 687rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 688{
 689        struct rpc_auth *auth = task->tk_client->cl_auth;
 690
 691        return rpcauth_lookupcred(auth, lookupflags);
 692}
 693
 694static int
 695rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags)
 696{
 697        struct rpc_rqst *req = task->tk_rqstp;
 698        struct rpc_cred *new = NULL;
 699        int lookupflags = 0;
 700        struct rpc_auth *auth = task->tk_client->cl_auth;
 701        struct auth_cred acred = {
 702                .cred = cred,
 703        };
 704
 705        if (flags & RPC_TASK_ASYNC)
 706                lookupflags |= RPCAUTH_LOOKUP_NEW;
 707        if (task->tk_op_cred)
 708                /* Task must use exactly this rpc_cred */
 709                new = get_rpccred(task->tk_op_cred);
 710        else if (cred != NULL && cred != &machine_cred)
 711                new = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 712        else if (cred == &machine_cred)
 713                new = rpcauth_bind_machine_cred(task, lookupflags);
 714
 715        /* If machine cred couldn't be bound, try a root cred */
 716        if (new)
 717                ;
 718        else if (cred == &machine_cred || (flags & RPC_TASK_ROOTCREDS))
 719                new = rpcauth_bind_root_cred(task, lookupflags);
 720        else if (flags & RPC_TASK_NULLCREDS)
 721                new = authnull_ops.lookup_cred(NULL, NULL, 0);
 722        else
 723                new = rpcauth_bind_new_cred(task, lookupflags);
 724        if (IS_ERR(new))
 725                return PTR_ERR(new);
 726        put_rpccred(req->rq_cred);
 727        req->rq_cred = new;
 728        return 0;
 729}
 730
/*
 * put_rpccred - release a reference to an rpc_cred
 * @cred: credential to release; NULL is ignored
 *
 * If this leaves only the hash table's reference on a hashed cred:
 *  - an up-to-date cred is timestamped (cr_expire) and put on the unused LRU;
 *  - otherwise the cred is unhashed and the hash table's reference dropped.
 * When the count reaches zero, cr_ops->crdestroy() is called after the
 * RCU read-side section ends.
 */
void
put_rpccred(struct rpc_cred *cred)
{
	if (cred == NULL)
		return;
	rcu_read_lock();
	if (refcount_dec_and_test(&cred->cr_count))
		goto destroy;
	/* Nothing more to do unless only the hash table's reference is left. */
	if (refcount_read(&cred->cr_count) != 1 ||
	    !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
		goto out;
	if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
		cred->cr_expire = jiffies;
		rpcauth_lru_add(cred);
		/* Race breaker: if it was unhashed meanwhile, take it back off the LRU. */
		if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
			rpcauth_lru_remove(cred);
	} else if (rpcauth_unhash_cred(cred)) {
		/* We now own the hash table's reference; drop it as well. */
		rpcauth_lru_remove(cred);
		if (refcount_dec_and_test(&cred->cr_count))
			goto destroy;
	}
out:
	rcu_read_unlock();
	return;
destroy:
	rcu_read_unlock();
	cred->cr_ops->crdestroy(cred);
}
EXPORT_SYMBOL_GPL(put_rpccred);
 761
 762/**
 763 * rpcauth_marshcred - Append RPC credential to end of @xdr
 764 * @task: controlling RPC task
 765 * @xdr: xdr_stream containing initial portion of RPC Call header
 766 *
 767 * On success, an appropriate verifier is added to @xdr, @xdr is
 768 * updated to point past the verifier, and zero is returned.
 769 * Otherwise, @xdr is in an undefined state and a negative errno
 770 * is returned.
 771 */
 772int rpcauth_marshcred(struct rpc_task *task, struct xdr_stream *xdr)
 773{
 774        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 775
 776        return ops->crmarshal(task, xdr);
 777}
 778
 779/**
 780 * rpcauth_wrap_req_encode - XDR encode the RPC procedure
 781 * @task: controlling RPC task
 782 * @xdr: stream where on-the-wire bytes are to be marshalled
 783 *
 784 * On success, @xdr contains the encoded and wrapped message.
 785 * Otherwise, @xdr is in an undefined state.
 786 */
 787int rpcauth_wrap_req_encode(struct rpc_task *task, struct xdr_stream *xdr)
 788{
 789        kxdreproc_t encode = task->tk_msg.rpc_proc->p_encode;
 790
 791        encode(task->tk_rqstp, xdr, task->tk_msg.rpc_argp);
 792        return 0;
 793}
 794EXPORT_SYMBOL_GPL(rpcauth_wrap_req_encode);
 795
 796/**
 797 * rpcauth_wrap_req - XDR encode and wrap the RPC procedure
 798 * @task: controlling RPC task
 799 * @xdr: stream where on-the-wire bytes are to be marshalled
 800 *
 801 * On success, @xdr contains the encoded and wrapped message,
 802 * and zero is returned. Otherwise, @xdr is in an undefined
 803 * state and a negative errno is returned.
 804 */
 805int rpcauth_wrap_req(struct rpc_task *task, struct xdr_stream *xdr)
 806{
 807        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 808
 809        return ops->crwrap_req(task, xdr);
 810}
 811
 812/**
 813 * rpcauth_checkverf - Validate verifier in RPC Reply header
 814 * @task: controlling RPC task
 815 * @xdr: xdr_stream containing RPC Reply header
 816 *
 817 * On success, @xdr is updated to point past the verifier and
 818 * zero is returned. Otherwise, @xdr is in an undefined state
 819 * and a negative errno is returned.
 820 */
 821int
 822rpcauth_checkverf(struct rpc_task *task, struct xdr_stream *xdr)
 823{
 824        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 825
 826        return ops->crvalidate(task, xdr);
 827}
 828
 829/**
 830 * rpcauth_unwrap_resp_decode - Invoke XDR decode function
 831 * @task: controlling RPC task
 832 * @xdr: stream where the Reply message resides
 833 *
 834 * Returns zero on success; otherwise a negative errno is returned.
 835 */
 836int
 837rpcauth_unwrap_resp_decode(struct rpc_task *task, struct xdr_stream *xdr)
 838{
 839        kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode;
 840
 841        return decode(task->tk_rqstp, xdr, task->tk_msg.rpc_resp);
 842}
 843EXPORT_SYMBOL_GPL(rpcauth_unwrap_resp_decode);
 844
 845/**
 846 * rpcauth_unwrap_resp - Invoke unwrap and decode function for the cred
 847 * @task: controlling RPC task
 848 * @xdr: stream where the Reply message resides
 849 *
 850 * Returns zero on success; otherwise a negative errno is returned.
 851 */
 852int
 853rpcauth_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr)
 854{
 855        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 856
 857        return ops->crunwrap_resp(task, xdr);
 858}
 859
 860bool
 861rpcauth_xmit_need_reencode(struct rpc_task *task)
 862{
 863        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 864
 865        if (!cred || !cred->cr_ops->crneed_reencode)
 866                return false;
 867        return cred->cr_ops->crneed_reencode(task);
 868}
 869
 870int
 871rpcauth_refreshcred(struct rpc_task *task)
 872{
 873        struct rpc_cred *cred;
 874        int err;
 875
 876        cred = task->tk_rqstp->rq_cred;
 877        if (cred == NULL) {
 878                err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
 879                if (err < 0)
 880                        goto out;
 881                cred = task->tk_rqstp->rq_cred;
 882        }
 883
 884        err = cred->cr_ops->crrefresh(task);
 885out:
 886        if (err < 0)
 887                task->tk_status = err;
 888        return err;
 889}
 890
 891void
 892rpcauth_invalcred(struct rpc_task *task)
 893{
 894        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 895
 896        if (cred)
 897                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 898}
 899
 900int
 901rpcauth_uptodatecred(struct rpc_task *task)
 902{
 903        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 904
 905        return cred == NULL ||
 906                test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
 907}
 908
 909static struct shrinker rpc_cred_shrinker = {
 910        .count_objects = rpcauth_cache_shrink_count,
 911        .scan_objects = rpcauth_cache_shrink_scan,
 912        .seeks = DEFAULT_SEEKS,
 913};
 914
 915int __init rpcauth_init_module(void)
 916{
 917        int err;
 918
 919        err = rpc_init_authunix();
 920        if (err < 0)
 921                goto out1;
 922        err = register_shrinker(&rpc_cred_shrinker);
 923        if (err < 0)
 924                goto out2;
 925        return 0;
 926out2:
 927        rpc_destroy_authunix();
 928out1:
 929        return err;
 930}
 931
 932void rpcauth_remove_module(void)
 933{
 934        rpc_destroy_authunix();
 935        unregister_shrinker(&rpc_cred_shrinker);
 936}
 937