linux/net/sunrpc/auth.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * linux/net/sunrpc/auth.c
   4 *
   5 * Generic RPC client authentication API.
   6 *
   7 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
   8 */
   9
  10#include <linux/types.h>
  11#include <linux/sched.h>
  12#include <linux/cred.h>
  13#include <linux/module.h>
  14#include <linux/slab.h>
  15#include <linux/errno.h>
  16#include <linux/hash.h>
  17#include <linux/sunrpc/clnt.h>
  18#include <linux/sunrpc/gss_api.h>
  19#include <linux/spinlock.h>
  20
  21#include <trace/events/sunrpc.h>
  22
/*
 * Default log2 of the number of hash buckets in each per-auth credential
 * cache (1 << 4 == 16 buckets). Overridable through the
 * auth_hashtable_size module parameter.
 */
   23#define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)
/*
 * Per-rpc_auth credential cache: an array of (1 << hashbits) hash chains.
 * Lookups walk a chain under RCU; @lock serializes insertion into and
 * removal from the chains.
 */
   24struct rpc_cred_cache {
   25        struct hlist_head       *hashtable;
   26        unsigned int            hashbits;
   27        spinlock_t              lock;
   28};
   29
/* log2 of the hashtable size used when a new credential cache is created. */
   30static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
   31
/*
 * Registered authentication flavors, indexed by flavor number and read
 * under RCU. AUTH_NULL and AUTH_UNIX are built in; every other slot starts
 * empty and is claimed through rpcauth_register().
 * NOTE(review): the positional NULL below initializes the element after
 * RPC_AUTH_UNIX; it is redundant because static storage is zeroed anyway.
 */
   32static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
   33        [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
   34        [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
   35        NULL,                   /* others can be loadable modules */
   36};
   37
/*
 * Global LRU of hashed credentials that no longer have users besides the
 * cache, and the number of entries on it. Both are protected by
 * rpc_credcache_lock.
 */
   38static LIST_HEAD(cred_unused);
   39static unsigned long number_cred_unused;
   40
/*
 * Sentinel identity handed out by rpc_machine_cred(). rpcauth_bindcred()
 * compares a task's cred against its address to select a machine credential.
 */
   41static struct cred machine_cred = {
   42        .usage = ATOMIC_INIT(1),
   43#ifdef CONFIG_DEBUG_CREDENTIALS
   44        .magic = CRED_MAGIC,
   45#endif
   46};
  47
  48/*
  49 * Return the machine_cred pointer to be used whenever
  50 * the a generic machine credential is needed.
  51 */
  52const struct cred *rpc_machine_cred(void)
  53{
  54        return &machine_cred;
  55}
  56EXPORT_SYMBOL_GPL(rpc_machine_cred);
  57
  58#define MAX_HASHTABLE_BITS (14)
  59static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
  60{
  61        unsigned long num;
  62        unsigned int nbits;
  63        int ret;
  64
  65        if (!val)
  66                goto out_inval;
  67        ret = kstrtoul(val, 0, &num);
  68        if (ret)
  69                goto out_inval;
  70        nbits = fls(num - 1);
  71        if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
  72                goto out_inval;
  73        *(unsigned int *)kp->arg = nbits;
  74        return 0;
  75out_inval:
  76        return -EINVAL;
  77}
  78
  79static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
  80{
  81        unsigned int nbits;
  82
  83        nbits = *(unsigned int *)kp->arg;
  84        return sprintf(buffer, "%u\n", 1U << nbits);
  85}
  86
  87#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
  88
  89static const struct kernel_param_ops param_ops_hashtbl_sz = {
  90        .set = param_set_hashtbl_sz,
  91        .get = param_get_hashtbl_sz,
  92};
  93
  94module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
  95MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
  96
  97static unsigned long auth_max_cred_cachesize = ULONG_MAX;
  98module_param(auth_max_cred_cachesize, ulong, 0644);
  99MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
 100
 101static u32
 102pseudoflavor_to_flavor(u32 flavor) {
 103        if (flavor > RPC_AUTH_MAXFLAVOR)
 104                return RPC_AUTH_GSS;
 105        return flavor;
 106}
 107
 108int
 109rpcauth_register(const struct rpc_authops *ops)
 110{
 111        const struct rpc_authops *old;
 112        rpc_authflavor_t flavor;
 113
 114        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 115                return -EINVAL;
 116        old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
 117        if (old == NULL || old == ops)
 118                return 0;
 119        return -EPERM;
 120}
 121EXPORT_SYMBOL_GPL(rpcauth_register);
 122
 123int
 124rpcauth_unregister(const struct rpc_authops *ops)
 125{
 126        const struct rpc_authops *old;
 127        rpc_authflavor_t flavor;
 128
 129        if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
 130                return -EINVAL;
 131
 132        old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
 133        if (old == ops || old == NULL)
 134                return 0;
 135        return -EPERM;
 136}
 137EXPORT_SYMBOL_GPL(rpcauth_unregister);
 138
 139static const struct rpc_authops *
 140rpcauth_get_authops(rpc_authflavor_t flavor)
 141{
 142        const struct rpc_authops *ops;
 143
 144        if (flavor >= RPC_AUTH_MAXFLAVOR)
 145                return NULL;
 146
 147        rcu_read_lock();
 148        ops = rcu_dereference(auth_flavors[flavor]);
 149        if (ops == NULL) {
 150                rcu_read_unlock();
 151                request_module("rpc-auth-%u", flavor);
 152                rcu_read_lock();
 153                ops = rcu_dereference(auth_flavors[flavor]);
 154                if (ops == NULL)
 155                        goto out;
 156        }
 157        if (!try_module_get(ops->owner))
 158                ops = NULL;
 159out:
 160        rcu_read_unlock();
 161        return ops;
 162}
 163
 164static void
 165rpcauth_put_authops(const struct rpc_authops *ops)
 166{
 167        module_put(ops->owner);
 168}
 169
 170/**
 171 * rpcauth_get_pseudoflavor - check if security flavor is supported
 172 * @flavor: a security flavor
 173 * @info: a GSS mech OID, quality of protection, and service value
 174 *
 175 * Verifies that an appropriate kernel module is available or already loaded.
 176 * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
 177 * not supported locally.
 178 */
 179rpc_authflavor_t
 180rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
 181{
 182        const struct rpc_authops *ops = rpcauth_get_authops(flavor);
 183        rpc_authflavor_t pseudoflavor;
 184
 185        if (!ops)
 186                return RPC_AUTH_MAXFLAVOR;
 187        pseudoflavor = flavor;
 188        if (ops->info2flavor != NULL)
 189                pseudoflavor = ops->info2flavor(info);
 190
 191        rpcauth_put_authops(ops);
 192        return pseudoflavor;
 193}
 194EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
 195
 196/**
 197 * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
 198 * @pseudoflavor: GSS pseudoflavor to match
 199 * @info: rpcsec_gss_info structure to fill in
 200 *
 201 * Returns zero and fills in "info" if pseudoflavor matches a
 202 * supported mechanism.
 203 */
 204int
 205rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
 206{
 207        rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
 208        const struct rpc_authops *ops;
 209        int result;
 210
 211        ops = rpcauth_get_authops(flavor);
 212        if (ops == NULL)
 213                return -ENOENT;
 214
 215        result = -ENOENT;
 216        if (ops->flavor2info != NULL)
 217                result = ops->flavor2info(pseudoflavor, info);
 218
 219        rpcauth_put_authops(ops);
 220        return result;
 221}
 222EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
 223
 224struct rpc_auth *
 225rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
 226{
 227        struct rpc_auth *auth = ERR_PTR(-EINVAL);
 228        const struct rpc_authops *ops;
 229        u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
 230
 231        ops = rpcauth_get_authops(flavor);
 232        if (ops == NULL)
 233                goto out;
 234
 235        auth = ops->create(args, clnt);
 236
 237        rpcauth_put_authops(ops);
 238        if (IS_ERR(auth))
 239                return auth;
 240        if (clnt->cl_auth)
 241                rpcauth_release(clnt->cl_auth);
 242        clnt->cl_auth = auth;
 243
 244out:
 245        return auth;
 246}
 247EXPORT_SYMBOL_GPL(rpcauth_create);
 248
 249void
 250rpcauth_release(struct rpc_auth *auth)
 251{
 252        if (!refcount_dec_and_test(&auth->au_count))
 253                return;
 254        auth->au_ops->destroy(auth);
 255}
 256
/*
 * Global lock for the unused-credential LRU. It serializes changes to
 * cred_unused and number_cred_unused (see the rpcauth_lru_*_locked()
 * helpers). When a cache's ->lock is also needed, this lock is taken first.
 */
  257static DEFINE_SPINLOCK(rpc_credcache_lock);
 258
 259/*
 260 * On success, the caller is responsible for freeing the reference
 261 * held by the hashtable
 262 */
 263static bool
 264rpcauth_unhash_cred_locked(struct rpc_cred *cred)
 265{
 266        if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
 267                return false;
 268        hlist_del_rcu(&cred->cr_hash);
 269        return true;
 270}
 271
 272static bool
 273rpcauth_unhash_cred(struct rpc_cred *cred)
 274{
 275        spinlock_t *cache_lock;
 276        bool ret;
 277
 278        if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
 279                return false;
 280        cache_lock = &cred->cr_auth->au_credcache->lock;
 281        spin_lock(cache_lock);
 282        ret = rpcauth_unhash_cred_locked(cred);
 283        spin_unlock(cache_lock);
 284        return ret;
 285}
 286
 287/*
 288 * Initialize RPC credential cache
 289 */
 290int
 291rpcauth_init_credcache(struct rpc_auth *auth)
 292{
 293        struct rpc_cred_cache *new;
 294        unsigned int hashsize;
 295
 296        new = kmalloc(sizeof(*new), GFP_KERNEL);
 297        if (!new)
 298                goto out_nocache;
 299        new->hashbits = auth_hashbits;
 300        hashsize = 1U << new->hashbits;
 301        new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
 302        if (!new->hashtable)
 303                goto out_nohashtbl;
 304        spin_lock_init(&new->lock);
 305        auth->au_credcache = new;
 306        return 0;
 307out_nohashtbl:
 308        kfree(new);
 309out_nocache:
 310        return -ENOMEM;
 311}
 312EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
 313
 314char *
 315rpcauth_stringify_acceptor(struct rpc_cred *cred)
 316{
 317        if (!cred->cr_ops->crstringify_acceptor)
 318                return NULL;
 319        return cred->cr_ops->crstringify_acceptor(cred);
 320}
 321EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
 322
 323/*
 324 * Destroy a list of credentials
 325 */
 326static inline
 327void rpcauth_destroy_credlist(struct list_head *head)
 328{
 329        struct rpc_cred *cred;
 330
 331        while (!list_empty(head)) {
 332                cred = list_entry(head->next, struct rpc_cred, cr_lru);
 333                list_del_init(&cred->cr_lru);
 334                put_rpccred(cred);
 335        }
 336}
 337
 338static void
 339rpcauth_lru_add_locked(struct rpc_cred *cred)
 340{
 341        if (!list_empty(&cred->cr_lru))
 342                return;
 343        number_cred_unused++;
 344        list_add_tail(&cred->cr_lru, &cred_unused);
 345}
 346
 347static void
 348rpcauth_lru_add(struct rpc_cred *cred)
 349{
 350        if (!list_empty(&cred->cr_lru))
 351                return;
 352        spin_lock(&rpc_credcache_lock);
 353        rpcauth_lru_add_locked(cred);
 354        spin_unlock(&rpc_credcache_lock);
 355}
 356
 357static void
 358rpcauth_lru_remove_locked(struct rpc_cred *cred)
 359{
 360        if (list_empty(&cred->cr_lru))
 361                return;
 362        number_cred_unused--;
 363        list_del_init(&cred->cr_lru);
 364}
 365
 366static void
 367rpcauth_lru_remove(struct rpc_cred *cred)
 368{
 369        if (list_empty(&cred->cr_lru))
 370                return;
 371        spin_lock(&rpc_credcache_lock);
 372        rpcauth_lru_remove_locked(cred);
 373        spin_unlock(&rpc_credcache_lock);
 374}
 375
 376/*
 377 * Clear the RPC credential cache, and delete those credentials
 378 * that are not referenced.
 379 */
 380void
 381rpcauth_clear_credcache(struct rpc_cred_cache *cache)
 382{
 383        LIST_HEAD(free);
 384        struct hlist_head *head;
 385        struct rpc_cred *cred;
 386        unsigned int hashsize = 1U << cache->hashbits;
 387        int             i;
 388
 389        spin_lock(&rpc_credcache_lock);
 390        spin_lock(&cache->lock);
 391        for (i = 0; i < hashsize; i++) {
 392                head = &cache->hashtable[i];
 393                while (!hlist_empty(head)) {
 394                        cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
 395                        rpcauth_unhash_cred_locked(cred);
 396                        /* Note: We now hold a reference to cred */
 397                        rpcauth_lru_remove_locked(cred);
 398                        list_add_tail(&cred->cr_lru, &free);
 399                }
 400        }
 401        spin_unlock(&cache->lock);
 402        spin_unlock(&rpc_credcache_lock);
 403        rpcauth_destroy_credlist(&free);
 404}
 405
 406/*
 407 * Destroy the RPC credential cache
 408 */
 409void
 410rpcauth_destroy_credcache(struct rpc_auth *auth)
 411{
 412        struct rpc_cred_cache *cache = auth->au_credcache;
 413
 414        if (cache) {
 415                auth->au_credcache = NULL;
 416                rpcauth_clear_credcache(cache);
 417                kfree(cache->hashtable);
 418                kfree(cache);
 419        }
 420}
 421EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
 422
 423
 424#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
 425
 426/*
 427 * Remove stale credentials. Avoid sleeping inside the loop.
 428 */
 429static long
 430rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
 431{
 432        struct rpc_cred *cred, *next;
 433        unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
 434        long freed = 0;
 435
 436        list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
 437
 438                if (nr_to_scan-- == 0)
 439                        break;
 440                if (refcount_read(&cred->cr_count) > 1) {
 441                        rpcauth_lru_remove_locked(cred);
 442                        continue;
 443                }
 444                /*
 445                 * Enforce a 60 second garbage collection moratorium
 446                 * Note that the cred_unused list must be time-ordered.
 447                 */
 448                if (!time_in_range(cred->cr_expire, expired, jiffies))
 449                        continue;
 450                if (!rpcauth_unhash_cred(cred))
 451                        continue;
 452
 453                rpcauth_lru_remove_locked(cred);
 454                freed++;
 455                list_add_tail(&cred->cr_lru, free);
 456        }
 457        return freed ? freed : SHRINK_STOP;
 458}
 459
 460static unsigned long
 461rpcauth_cache_do_shrink(int nr_to_scan)
 462{
 463        LIST_HEAD(free);
 464        unsigned long freed;
 465
 466        spin_lock(&rpc_credcache_lock);
 467        freed = rpcauth_prune_expired(&free, nr_to_scan);
 468        spin_unlock(&rpc_credcache_lock);
 469        rpcauth_destroy_credlist(&free);
 470
 471        return freed;
 472}
 473
 474/*
 475 * Run memory cache shrinker.
 476 */
 477static unsigned long
 478rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
 479
 480{
 481        if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
 482                return SHRINK_STOP;
 483
 484        /* nothing left, don't come back */
 485        if (list_empty(&cred_unused))
 486                return SHRINK_STOP;
 487
 488        return rpcauth_cache_do_shrink(sc->nr_to_scan);
 489}
 490
 491static unsigned long
 492rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
 493
 494{
 495        return number_cred_unused * sysctl_vfs_cache_pressure / 100;
 496}
 497
 498static void
 499rpcauth_cache_enforce_limit(void)
 500{
 501        unsigned long diff;
 502        unsigned int nr_to_scan;
 503
 504        if (number_cred_unused <= auth_max_cred_cachesize)
 505                return;
 506        diff = number_cred_unused - auth_max_cred_cachesize;
 507        nr_to_scan = 100;
 508        if (diff < nr_to_scan)
 509                nr_to_scan = diff;
 510        rpcauth_cache_do_shrink(nr_to_scan);
 511}
 512
  513/*
  514 * Look up a process' credentials in the authentication cache
 *
 * rpcauth_lookup_credcache - find or create the cached rpc_cred for @acred
 * @auth: rpc_auth whose credential cache is searched
 * @acred: identity to match through ->crmatch
 * @flags: RPCAUTH_LOOKUP_* flags. They are passed to ->crmatch and
 *         ->crcreate; RPCAUTH_LOOKUP_NEW also suppresses ->cr_init.
 * @gfp: allocation flags forwarded to ->crcreate
 *
 * A lockless RCU pass over the hash chain is tried first. On a miss, a
 * candidate is created without holding cache->lock. The chain is then
 * searched again under the lock, because another task may have inserted a
 * match in the meantime; if so, the candidate is discarded.
 *
 * Returns a referenced rpc_cred, or an ERR_PTR from ->crcreate or ->cr_init.
  515 */
  516struct rpc_cred *
  517rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
  518                int flags, gfp_t gfp)
  519{
  520        LIST_HEAD(free);
  521        struct rpc_cred_cache *cache = auth->au_credcache;
  522        struct rpc_cred *cred = NULL,
  523                        *entry, *new;
  524        unsigned int nr;
  525
        /* Bucket index of @acred in this cache's hashtable. */
  526        nr = auth->au_ops->hash_cred(acred, cache->hashbits);
  527
        /*
         * Fast path: lockless search of the chain. get_rpccred() can
         * return NULL (presumably for an entry already being torn down);
         * such entries are skipped and the search continues.
         */
  528        rcu_read_lock();
  529        hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
  530                if (!entry->cr_ops->crmatch(acred, entry, flags))
  531                        continue;
  532                cred = get_rpccred(entry);
  533                if (cred)
  534                        break;
  535        }
  536        rcu_read_unlock();
  537
  538        if (cred != NULL)
  539                goto found;
  540
        /* Miss: build a candidate credential outside cache->lock. */
  541        new = auth->au_ops->crcreate(auth, acred, flags, gfp);
  542        if (IS_ERR(new)) {
  543                cred = new;
  544                goto out;
  545        }
  546
        /*
         * Slow path: search again under the lock, in case another task
         * inserted a matching credential after the RCU pass.
         */
  547        spin_lock(&cache->lock);
  548        hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
  549                if (!entry->cr_ops->crmatch(acred, entry, flags))
  550                        continue;
  551                cred = get_rpccred(entry);
  552                if (cred)
  553                        break;
  554        }
        /*
         * No match: publish the candidate. The extra reference taken here
         * belongs to the hashtable and is released when the cred is
         * unhashed.
         */
  555        if (cred == NULL) {
  556                cred = new;
  557                set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
  558                refcount_inc(&cred->cr_count);
  559                hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
  560        } else
                /* Lost the race: queue the candidate to be dropped after unlocking. */
  561                list_add_tail(&new->cr_lru, &free);
  562        spin_unlock(&cache->lock);
        /* Trim the unused LRU if it exceeds auth_max_cred_cachesize. */
  563        rpcauth_cache_enforce_limit();
  564found:
        /*
         * A credential still flagged RPCAUTH_CRED_NEW is initialised with
         * ->cr_init() unless the caller passed RPCAUTH_LOOKUP_NEW. On failure,
         * our reference is dropped and the error is returned instead.
         */
  565        if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
  566            cred->cr_ops->cr_init != NULL &&
  567            !(flags & RPCAUTH_LOOKUP_NEW)) {
  568                int res = cred->cr_ops->cr_init(auth, cred);
  569                if (res < 0) {
  570                        put_rpccred(cred);
  571                        cred = ERR_PTR(res);
  572                }
  573        }
        /* Drop the discarded candidate, if there was one. */
  574        rpcauth_destroy_credlist(&free);
  575out:
  576        return cred;
  577}
  578EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
 579
 580struct rpc_cred *
 581rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 582{
 583        struct auth_cred acred;
 584        struct rpc_cred *ret;
 585        const struct cred *cred = current_cred();
 586
 587        memset(&acred, 0, sizeof(acred));
 588        acred.cred = cred;
 589        ret = auth->au_ops->lookup_cred(auth, &acred, flags);
 590        return ret;
 591}
 592EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
 593
 594void
 595rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 596                  struct rpc_auth *auth, const struct rpc_credops *ops)
 597{
 598        INIT_HLIST_NODE(&cred->cr_hash);
 599        INIT_LIST_HEAD(&cred->cr_lru);
 600        refcount_set(&cred->cr_count, 1);
 601        cred->cr_auth = auth;
 602        cred->cr_flags = 0;
 603        cred->cr_ops = ops;
 604        cred->cr_expire = jiffies;
 605        cred->cr_cred = get_cred(acred->cred);
 606}
 607EXPORT_SYMBOL_GPL(rpcauth_init_cred);
 608
 609static struct rpc_cred *
 610rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
 611{
 612        struct rpc_auth *auth = task->tk_client->cl_auth;
 613        struct auth_cred acred = {
 614                .cred = get_task_cred(&init_task),
 615        };
 616        struct rpc_cred *ret;
 617
 618        if (RPC_IS_ASYNC(task))
 619                lookupflags |= RPCAUTH_LOOKUP_ASYNC;
 620        ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 621        put_cred(acred.cred);
 622        return ret;
 623}
 624
 625static struct rpc_cred *
 626rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags)
 627{
 628        struct rpc_auth *auth = task->tk_client->cl_auth;
 629        struct auth_cred acred = {
 630                .principal = task->tk_client->cl_principal,
 631                .cred = init_task.cred,
 632        };
 633
 634        if (!acred.principal)
 635                return NULL;
 636        if (RPC_IS_ASYNC(task))
 637                lookupflags |= RPCAUTH_LOOKUP_ASYNC;
 638        return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 639}
 640
 641static struct rpc_cred *
 642rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
 643{
 644        struct rpc_auth *auth = task->tk_client->cl_auth;
 645
 646        return rpcauth_lookupcred(auth, lookupflags);
 647}
 648
 649static int
 650rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags)
 651{
 652        struct rpc_rqst *req = task->tk_rqstp;
 653        struct rpc_cred *new = NULL;
 654        int lookupflags = 0;
 655        struct rpc_auth *auth = task->tk_client->cl_auth;
 656        struct auth_cred acred = {
 657                .cred = cred,
 658        };
 659
 660        if (flags & RPC_TASK_ASYNC)
 661                lookupflags |= RPCAUTH_LOOKUP_NEW | RPCAUTH_LOOKUP_ASYNC;
 662        if (task->tk_op_cred)
 663                /* Task must use exactly this rpc_cred */
 664                new = get_rpccred(task->tk_op_cred);
 665        else if (cred != NULL && cred != &machine_cred)
 666                new = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
 667        else if (cred == &machine_cred)
 668                new = rpcauth_bind_machine_cred(task, lookupflags);
 669
 670        /* If machine cred couldn't be bound, try a root cred */
 671        if (new)
 672                ;
 673        else if (cred == &machine_cred)
 674                new = rpcauth_bind_root_cred(task, lookupflags);
 675        else if (flags & RPC_TASK_NULLCREDS)
 676                new = authnull_ops.lookup_cred(NULL, NULL, 0);
 677        else
 678                new = rpcauth_bind_new_cred(task, lookupflags);
 679        if (IS_ERR(new))
 680                return PTR_ERR(new);
 681        put_rpccred(req->rq_cred);
 682        req->rq_cred = new;
 683        return 0;
 684}
 685
/*
 * put_rpccred - drop a reference to an rpc_cred
 * @cred: credential to release; NULL is ignored
 *
 * When the last reference goes away, the credential is destroyed with
 * ->crdestroy() after the RCU read-side section ends.
 *
 * When exactly one reference remains and the credential is still hashed,
 * that remaining reference is the one held by the hashtable:
 *  - an up-to-date credential is timestamped and parked on the unused LRU
 *    for later reclaim;
 *  - any other credential is unhashed and the hashtable's reference is
 *    dropped as well.
 */
  686void
  687put_rpccred(struct rpc_cred *cred)
  688{
  689        if (cred == NULL)
  690                return;
        /*
         * NOTE(review): the RCU read lock presumably keeps @cred's memory
         * valid across the checks below, matching the RCU hash-chain walk
         * in rpcauth_lookup_credcache(); confirm.
         */
  691        rcu_read_lock();
  692        if (refcount_dec_and_test(&cred->cr_count))
  693                goto destroy;
        /* Nothing more to do unless only the hashtable's reference is left. */
  694        if (refcount_read(&cred->cr_count) != 1 ||
  695            !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
  696                goto out;
  697        if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
                /* The timestamp is checked against the moratorium by the LRU pruner. */
  698                cred->cr_expire = jiffies;
  699                rpcauth_lru_add(cred);
  700                /* Race breaker: the cred may have been unhashed concurrently
                   after the checks above; if so, don't leave it on the LRU. */
  701                if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
  702                        rpcauth_lru_remove(cred);
  703        } else if (rpcauth_unhash_cred(cred)) {
                /* Stale credential: we now own the hashtable's reference, drop it too. */
  704                rpcauth_lru_remove(cred);
  705                if (refcount_dec_and_test(&cred->cr_count))
  706                        goto destroy;
  707        }
  708out:
  709        rcu_read_unlock();
  710        return;
  711destroy:
  712        rcu_read_unlock();
  713        cred->cr_ops->crdestroy(cred);
  714}
  715EXPORT_SYMBOL_GPL(put_rpccred);
 716
 717/**
 718 * rpcauth_marshcred - Append RPC credential to end of @xdr
 719 * @task: controlling RPC task
 720 * @xdr: xdr_stream containing initial portion of RPC Call header
 721 *
 722 * On success, an appropriate verifier is added to @xdr, @xdr is
 723 * updated to point past the verifier, and zero is returned.
 724 * Otherwise, @xdr is in an undefined state and a negative errno
 725 * is returned.
 726 */
 727int rpcauth_marshcred(struct rpc_task *task, struct xdr_stream *xdr)
 728{
 729        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 730
 731        return ops->crmarshal(task, xdr);
 732}
 733
 734/**
 735 * rpcauth_wrap_req_encode - XDR encode the RPC procedure
 736 * @task: controlling RPC task
 737 * @xdr: stream where on-the-wire bytes are to be marshalled
 738 *
 739 * On success, @xdr contains the encoded and wrapped message.
 740 * Otherwise, @xdr is in an undefined state.
 741 */
 742int rpcauth_wrap_req_encode(struct rpc_task *task, struct xdr_stream *xdr)
 743{
 744        kxdreproc_t encode = task->tk_msg.rpc_proc->p_encode;
 745
 746        encode(task->tk_rqstp, xdr, task->tk_msg.rpc_argp);
 747        return 0;
 748}
 749EXPORT_SYMBOL_GPL(rpcauth_wrap_req_encode);
 750
 751/**
 752 * rpcauth_wrap_req - XDR encode and wrap the RPC procedure
 753 * @task: controlling RPC task
 754 * @xdr: stream where on-the-wire bytes are to be marshalled
 755 *
 756 * On success, @xdr contains the encoded and wrapped message,
 757 * and zero is returned. Otherwise, @xdr is in an undefined
 758 * state and a negative errno is returned.
 759 */
 760int rpcauth_wrap_req(struct rpc_task *task, struct xdr_stream *xdr)
 761{
 762        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 763
 764        return ops->crwrap_req(task, xdr);
 765}
 766
 767/**
 768 * rpcauth_checkverf - Validate verifier in RPC Reply header
 769 * @task: controlling RPC task
 770 * @xdr: xdr_stream containing RPC Reply header
 771 *
 772 * On success, @xdr is updated to point past the verifier and
 773 * zero is returned. Otherwise, @xdr is in an undefined state
 774 * and a negative errno is returned.
 775 */
 776int
 777rpcauth_checkverf(struct rpc_task *task, struct xdr_stream *xdr)
 778{
 779        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 780
 781        return ops->crvalidate(task, xdr);
 782}
 783
 784/**
 785 * rpcauth_unwrap_resp_decode - Invoke XDR decode function
 786 * @task: controlling RPC task
 787 * @xdr: stream where the Reply message resides
 788 *
 789 * Returns zero on success; otherwise a negative errno is returned.
 790 */
 791int
 792rpcauth_unwrap_resp_decode(struct rpc_task *task, struct xdr_stream *xdr)
 793{
 794        kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode;
 795
 796        return decode(task->tk_rqstp, xdr, task->tk_msg.rpc_resp);
 797}
 798EXPORT_SYMBOL_GPL(rpcauth_unwrap_resp_decode);
 799
 800/**
 801 * rpcauth_unwrap_resp - Invoke unwrap and decode function for the cred
 802 * @task: controlling RPC task
 803 * @xdr: stream where the Reply message resides
 804 *
 805 * Returns zero on success; otherwise a negative errno is returned.
 806 */
 807int
 808rpcauth_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr)
 809{
 810        const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops;
 811
 812        return ops->crunwrap_resp(task, xdr);
 813}
 814
 815bool
 816rpcauth_xmit_need_reencode(struct rpc_task *task)
 817{
 818        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 819
 820        if (!cred || !cred->cr_ops->crneed_reencode)
 821                return false;
 822        return cred->cr_ops->crneed_reencode(task);
 823}
 824
 825int
 826rpcauth_refreshcred(struct rpc_task *task)
 827{
 828        struct rpc_cred *cred;
 829        int err;
 830
 831        cred = task->tk_rqstp->rq_cred;
 832        if (cred == NULL) {
 833                err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
 834                if (err < 0)
 835                        goto out;
 836                cred = task->tk_rqstp->rq_cred;
 837        }
 838
 839        err = cred->cr_ops->crrefresh(task);
 840out:
 841        if (err < 0)
 842                task->tk_status = err;
 843        return err;
 844}
 845
 846void
 847rpcauth_invalcred(struct rpc_task *task)
 848{
 849        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 850
 851        if (cred)
 852                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 853}
 854
 855int
 856rpcauth_uptodatecred(struct rpc_task *task)
 857{
 858        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 859
 860        return cred == NULL ||
 861                test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
 862}
 863
 864static struct shrinker rpc_cred_shrinker = {
 865        .count_objects = rpcauth_cache_shrink_count,
 866        .scan_objects = rpcauth_cache_shrink_scan,
 867        .seeks = DEFAULT_SEEKS,
 868};
 869
 870int __init rpcauth_init_module(void)
 871{
 872        int err;
 873
 874        err = rpc_init_authunix();
 875        if (err < 0)
 876                goto out1;
 877        err = register_shrinker(&rpc_cred_shrinker);
 878        if (err < 0)
 879                goto out2;
 880        return 0;
 881out2:
 882        rpc_destroy_authunix();
 883out1:
 884        return err;
 885}
 886
 887void rpcauth_remove_module(void)
 888{
 889        rpc_destroy_authunix();
 890        unregister_shrinker(&rpc_cred_shrinker);
 891}
 892