linux/net/xfrm/xfrm_policy.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0-only
   2/*
   3 * xfrm_policy.c
   4 *
   5 * Changes:
   6 *      Mitsuru KANDA @USAGI
   7 *      Kazunori MIYAZAWA @USAGI
   8 *      Kunihiro Ishiguro <kunihiro@ipinfusion.com>
   9 *              IPv6 support
  10 *      Kazunori MIYAZAWA @USAGI
  11 *      YOSHIFUJI Hideaki
  12 *              Split up af-specific portion
  13 *      Derek Atkins <derek@ihtfp.com>          Add the post_input processor
  14 *
  15 */
  16
  17#include <linux/err.h>
  18#include <linux/slab.h>
  19#include <linux/kmod.h>
  20#include <linux/list.h>
  21#include <linux/spinlock.h>
  22#include <linux/workqueue.h>
  23#include <linux/notifier.h>
  24#include <linux/netdevice.h>
  25#include <linux/netfilter.h>
  26#include <linux/module.h>
  27#include <linux/cache.h>
  28#include <linux/cpu.h>
  29#include <linux/audit.h>
  30#include <linux/rhashtable.h>
  31#include <linux/if_tunnel.h>
  32#include <net/dst.h>
  33#include <net/flow.h>
  34#include <net/xfrm.h>
  35#include <net/ip.h>
  36#if IS_ENABLED(CONFIG_IPV6_MIP6)
  37#include <net/mip6.h>
  38#endif
  39#ifdef CONFIG_XFRM_STATISTICS
  40#include <net/snmp.h>
  41#endif
  42
  43#include "xfrm_hash.h"
  44
  45#define XFRM_QUEUE_TMO_MIN ((unsigned)(HZ/10))
  46#define XFRM_QUEUE_TMO_MAX ((unsigned)(60*HZ))
  47#define XFRM_MAX_QUEUE_LEN      100
  48
  49struct xfrm_flo {
  50        struct dst_entry *dst_orig;
  51        u8 flags;
  52};
  53
  54/* prefixes smaller than this are stored in lists, not trees. */
  55#define INEXACT_PREFIXLEN_IPV4  16
  56#define INEXACT_PREFIXLEN_IPV6  48
  57
  58struct xfrm_pol_inexact_node {
  59        struct rb_node node;
  60        union {
  61                xfrm_address_t addr;
  62                struct rcu_head rcu;
  63        };
  64        u8 prefixlen;
  65
  66        struct rb_root root;
  67
  68        /* the policies matching this node, can be empty list */
  69        struct hlist_head hhead;
  70};
  71
  72/* xfrm inexact policy search tree:
  73 * xfrm_pol_inexact_bin = hash(dir,type,family,if_id);
  74 *  |
  75 * +---- root_d: sorted by daddr:prefix
  76 * |                 |
  77 * |        xfrm_pol_inexact_node
  78 * |                 |
  79 * |                 +- root: sorted by saddr/prefix
  80 * |                 |              |
  81 * |                 |         xfrm_pol_inexact_node
  82 * |                 |              |
  83 * |                 |              + root: unused
  84 * |                 |              |
  85 * |                 |              + hhead: saddr:daddr policies
  86 * |                 |
  87 * |                 +- coarse policies and all any:daddr policies
  88 * |
  89 * +---- root_s: sorted by saddr:prefix
  90 * |                 |
  91 * |        xfrm_pol_inexact_node
  92 * |                 |
  93 * |                 + root: unused
  94 * |                 |
  95 * |                 + hhead: saddr:any policies
  96 * |
  97 * +---- coarse policies and all any:any policies
  98 *
  99 * Lookups return four candidate lists:
 100 * 1. any:any list from top-level xfrm_pol_inexact_bin
 101 * 2. any:daddr list from daddr tree
 102 * 3. saddr:daddr list from 2nd level daddr tree
 103 * 4. saddr:any list from saddr tree
 104 *
 105 * This result set then needs to be searched for the policy with
 106 * the lowest priority.  If two results have same prio, youngest one wins.
 107 */
 108
 109struct xfrm_pol_inexact_key {
 110        possible_net_t net;
 111        u32 if_id;
 112        u16 family;
 113        u8 dir, type;
 114};
 115
 116struct xfrm_pol_inexact_bin {
 117        struct xfrm_pol_inexact_key k;
 118        struct rhash_head head;
 119        /* list containing '*:*' policies */
 120        struct hlist_head hhead;
 121
 122        seqcount_t count;
 123        /* tree sorted by daddr/prefix */
 124        struct rb_root root_d;
 125
 126        /* tree sorted by saddr/prefix */
 127        struct rb_root root_s;
 128
 129        /* slow path below */
 130        struct list_head inexact_bins;
 131        struct rcu_head rcu;
 132};
 133
 134enum xfrm_pol_inexact_candidate_type {
 135        XFRM_POL_CAND_BOTH,
 136        XFRM_POL_CAND_SADDR,
 137        XFRM_POL_CAND_DADDR,
 138        XFRM_POL_CAND_ANY,
 139
 140        XFRM_POL_CAND_MAX,
 141};
 142
 143struct xfrm_pol_inexact_candidates {
 144        struct hlist_head *res[XFRM_POL_CAND_MAX];
 145};
 146
 147static DEFINE_SPINLOCK(xfrm_if_cb_lock);
 148static struct xfrm_if_cb const __rcu *xfrm_if_cb __read_mostly;
 149
 150static DEFINE_SPINLOCK(xfrm_policy_afinfo_lock);
 151static struct xfrm_policy_afinfo const __rcu *xfrm_policy_afinfo[AF_INET6 + 1]
 152                                                __read_mostly;
 153
 154static struct kmem_cache *xfrm_dst_cache __ro_after_init;
 155static __read_mostly seqcount_t xfrm_policy_hash_generation;
 156
 157static struct rhashtable xfrm_policy_inexact_table;
 158static const struct rhashtable_params xfrm_pol_inexact_params;
 159
 160static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr);
 161static int stale_bundle(struct dst_entry *dst);
 162static int xfrm_bundle_ok(struct xfrm_dst *xdst);
 163static void xfrm_policy_queue_process(struct timer_list *t);
 164
 165static void __xfrm_policy_link(struct xfrm_policy *pol, int dir);
 166static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
 167                                                int dir);
 168
 169static struct xfrm_pol_inexact_bin *
 170xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family, u8 dir,
 171                           u32 if_id);
 172
 173static struct xfrm_pol_inexact_bin *
 174xfrm_policy_inexact_lookup_rcu(struct net *net,
 175                               u8 type, u16 family, u8 dir, u32 if_id);
 176static struct xfrm_policy *
 177xfrm_policy_insert_list(struct hlist_head *chain, struct xfrm_policy *policy,
 178                        bool excl);
 179static void xfrm_policy_insert_inexact_list(struct hlist_head *chain,
 180                                            struct xfrm_policy *policy);
 181
 182static bool
 183xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
 184                                    struct xfrm_pol_inexact_bin *b,
 185                                    const xfrm_address_t *saddr,
 186                                    const xfrm_address_t *daddr);
 187
 188static inline bool xfrm_pol_hold_rcu(struct xfrm_policy *policy)
 189{
 190        return refcount_inc_not_zero(&policy->refcnt);
 191}
 192
 193static inline bool
 194__xfrm4_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
 195{
 196        const struct flowi4 *fl4 = &fl->u.ip4;
 197
 198        return  addr4_match(fl4->daddr, sel->daddr.a4, sel->prefixlen_d) &&
 199                addr4_match(fl4->saddr, sel->saddr.a4, sel->prefixlen_s) &&
 200                !((xfrm_flowi_dport(fl, &fl4->uli) ^ sel->dport) & sel->dport_mask) &&
 201                !((xfrm_flowi_sport(fl, &fl4->uli) ^ sel->sport) & sel->sport_mask) &&
 202                (fl4->flowi4_proto == sel->proto || !sel->proto) &&
 203                (fl4->flowi4_oif == sel->ifindex || !sel->ifindex);
 204}
 205
 206static inline bool
 207__xfrm6_selector_match(const struct xfrm_selector *sel, const struct flowi *fl)
 208{
 209        const struct flowi6 *fl6 = &fl->u.ip6;
 210
 211        return  addr_match(&fl6->daddr, &sel->daddr, sel->prefixlen_d) &&
 212                addr_match(&fl6->saddr, &sel->saddr, sel->prefixlen_s) &&
 213                !((xfrm_flowi_dport(fl, &fl6->uli) ^ sel->dport) & sel->dport_mask) &&
 214                !((xfrm_flowi_sport(fl, &fl6->uli) ^ sel->sport) & sel->sport_mask) &&
 215                (fl6->flowi6_proto == sel->proto || !sel->proto) &&
 216                (fl6->flowi6_oif == sel->ifindex || !sel->ifindex);
 217}
 218
 219bool xfrm_selector_match(const struct xfrm_selector *sel, const struct flowi *fl,
 220                         unsigned short family)
 221{
 222        switch (family) {
 223        case AF_INET:
 224                return __xfrm4_selector_match(sel, fl);
 225        case AF_INET6:
 226                return __xfrm6_selector_match(sel, fl);
 227        }
 228        return false;
 229}
 230
 231static const struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family)
 232{
 233        const struct xfrm_policy_afinfo *afinfo;
 234
 235        if (unlikely(family >= ARRAY_SIZE(xfrm_policy_afinfo)))
 236                return NULL;
 237        rcu_read_lock();
 238        afinfo = rcu_dereference(xfrm_policy_afinfo[family]);
 239        if (unlikely(!afinfo))
 240                rcu_read_unlock();
 241        return afinfo;
 242}
 243
 244/* Called with rcu_read_lock(). */
 245static const struct xfrm_if_cb *xfrm_if_get_cb(void)
 246{
 247        return rcu_dereference(xfrm_if_cb);
 248}
 249
 250struct dst_entry *__xfrm_dst_lookup(struct net *net, int tos, int oif,
 251                                    const xfrm_address_t *saddr,
 252                                    const xfrm_address_t *daddr,
 253                                    int family, u32 mark)
 254{
 255        const struct xfrm_policy_afinfo *afinfo;
 256        struct dst_entry *dst;
 257
 258        afinfo = xfrm_policy_get_afinfo(family);
 259        if (unlikely(afinfo == NULL))
 260                return ERR_PTR(-EAFNOSUPPORT);
 261
 262        dst = afinfo->dst_lookup(net, tos, oif, saddr, daddr, mark);
 263
 264        rcu_read_unlock();
 265
 266        return dst;
 267}
 268EXPORT_SYMBOL(__xfrm_dst_lookup);
 269
 270static inline struct dst_entry *xfrm_dst_lookup(struct xfrm_state *x,
 271                                                int tos, int oif,
 272                                                xfrm_address_t *prev_saddr,
 273                                                xfrm_address_t *prev_daddr,
 274                                                int family, u32 mark)
 275{
 276        struct net *net = xs_net(x);
 277        xfrm_address_t *saddr = &x->props.saddr;
 278        xfrm_address_t *daddr = &x->id.daddr;
 279        struct dst_entry *dst;
 280
 281        if (x->type->flags & XFRM_TYPE_LOCAL_COADDR) {
 282                saddr = x->coaddr;
 283                daddr = prev_daddr;
 284        }
 285        if (x->type->flags & XFRM_TYPE_REMOTE_COADDR) {
 286                saddr = prev_saddr;
 287                daddr = x->coaddr;
 288        }
 289
 290        dst = __xfrm_dst_lookup(net, tos, oif, saddr, daddr, family, mark);
 291
 292        if (!IS_ERR(dst)) {
 293                if (prev_saddr != saddr)
 294                        memcpy(prev_saddr, saddr,  sizeof(*prev_saddr));
 295                if (prev_daddr != daddr)
 296                        memcpy(prev_daddr, daddr,  sizeof(*prev_daddr));
 297        }
 298
 299        return dst;
 300}
 301
 302static inline unsigned long make_jiffies(long secs)
 303{
 304        if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ)
 305                return MAX_SCHEDULE_TIMEOUT-1;
 306        else
 307                return secs*HZ;
 308}
 309
 310static void xfrm_policy_timer(struct timer_list *t)
 311{
 312        struct xfrm_policy *xp = from_timer(xp, t, timer);
 313        time64_t now = ktime_get_real_seconds();
 314        time64_t next = TIME64_MAX;
 315        int warn = 0;
 316        int dir;
 317
 318        read_lock(&xp->lock);
 319
 320        if (unlikely(xp->walk.dead))
 321                goto out;
 322
 323        dir = xfrm_policy_id2dir(xp->index);
 324
 325        if (xp->lft.hard_add_expires_seconds) {
 326                time64_t tmo = xp->lft.hard_add_expires_seconds +
 327                        xp->curlft.add_time - now;
 328                if (tmo <= 0)
 329                        goto expired;
 330                if (tmo < next)
 331                        next = tmo;
 332        }
 333        if (xp->lft.hard_use_expires_seconds) {
 334                time64_t tmo = xp->lft.hard_use_expires_seconds +
 335                        (xp->curlft.use_time ? : xp->curlft.add_time) - now;
 336                if (tmo <= 0)
 337                        goto expired;
 338                if (tmo < next)
 339                        next = tmo;
 340        }
 341        if (xp->lft.soft_add_expires_seconds) {
 342                time64_t tmo = xp->lft.soft_add_expires_seconds +
 343                        xp->curlft.add_time - now;
 344                if (tmo <= 0) {
 345                        warn = 1;
 346                        tmo = XFRM_KM_TIMEOUT;
 347                }
 348                if (tmo < next)
 349                        next = tmo;
 350        }
 351        if (xp->lft.soft_use_expires_seconds) {
 352                time64_t tmo = xp->lft.soft_use_expires_seconds +
 353                        (xp->curlft.use_time ? : xp->curlft.add_time) - now;
 354                if (tmo <= 0) {
 355                        warn = 1;
 356                        tmo = XFRM_KM_TIMEOUT;
 357                }
 358                if (tmo < next)
 359                        next = tmo;
 360        }
 361
 362        if (warn)
 363                km_policy_expired(xp, dir, 0, 0);
 364        if (next != TIME64_MAX &&
 365            !mod_timer(&xp->timer, jiffies + make_jiffies(next)))
 366                xfrm_pol_hold(xp);
 367
 368out:
 369        read_unlock(&xp->lock);
 370        xfrm_pol_put(xp);
 371        return;
 372
 373expired:
 374        read_unlock(&xp->lock);
 375        if (!xfrm_policy_delete(xp, dir))
 376                km_policy_expired(xp, dir, 1, 0);
 377        xfrm_pol_put(xp);
 378}
 379
 380/* Allocate xfrm_policy. Not used here, it is supposed to be used by pfkeyv2
 381 * SPD calls.
 382 */
 383
 384struct xfrm_policy *xfrm_policy_alloc(struct net *net, gfp_t gfp)
 385{
 386        struct xfrm_policy *policy;
 387
 388        policy = kzalloc(sizeof(struct xfrm_policy), gfp);
 389
 390        if (policy) {
 391                write_pnet(&policy->xp_net, net);
 392                INIT_LIST_HEAD(&policy->walk.all);
 393                INIT_HLIST_NODE(&policy->bydst_inexact_list);
 394                INIT_HLIST_NODE(&policy->bydst);
 395                INIT_HLIST_NODE(&policy->byidx);
 396                rwlock_init(&policy->lock);
 397                refcount_set(&policy->refcnt, 1);
 398                skb_queue_head_init(&policy->polq.hold_queue);
 399                timer_setup(&policy->timer, xfrm_policy_timer, 0);
 400                timer_setup(&policy->polq.hold_timer,
 401                            xfrm_policy_queue_process, 0);
 402        }
 403        return policy;
 404}
 405EXPORT_SYMBOL(xfrm_policy_alloc);
 406
 407static void xfrm_policy_destroy_rcu(struct rcu_head *head)
 408{
 409        struct xfrm_policy *policy = container_of(head, struct xfrm_policy, rcu);
 410
 411        security_xfrm_policy_free(policy->security);
 412        kfree(policy);
 413}
 414
 415/* Destroy xfrm_policy: descendant resources must be released to this moment. */
 416
 417void xfrm_policy_destroy(struct xfrm_policy *policy)
 418{
 419        BUG_ON(!policy->walk.dead);
 420
 421        if (del_timer(&policy->timer) || del_timer(&policy->polq.hold_timer))
 422                BUG();
 423
 424        call_rcu(&policy->rcu, xfrm_policy_destroy_rcu);
 425}
 426EXPORT_SYMBOL(xfrm_policy_destroy);
 427
 428/* Rule must be locked. Release descendant resources, announce
 429 * entry dead. The rule must be unlinked from lists to the moment.
 430 */
 431
 432static void xfrm_policy_kill(struct xfrm_policy *policy)
 433{
 434        policy->walk.dead = 1;
 435
 436        atomic_inc(&policy->genid);
 437
 438        if (del_timer(&policy->polq.hold_timer))
 439                xfrm_pol_put(policy);
 440        skb_queue_purge(&policy->polq.hold_queue);
 441
 442        if (del_timer(&policy->timer))
 443                xfrm_pol_put(policy);
 444
 445        xfrm_pol_put(policy);
 446}
 447
 448static unsigned int xfrm_policy_hashmax __read_mostly = 1 * 1024 * 1024;
 449
 450static inline unsigned int idx_hash(struct net *net, u32 index)
 451{
 452        return __idx_hash(index, net->xfrm.policy_idx_hmask);
 453}
 454
 455/* calculate policy hash thresholds */
 456static void __get_hash_thresh(struct net *net,
 457                              unsigned short family, int dir,
 458                              u8 *dbits, u8 *sbits)
 459{
 460        switch (family) {
 461        case AF_INET:
 462                *dbits = net->xfrm.policy_bydst[dir].dbits4;
 463                *sbits = net->xfrm.policy_bydst[dir].sbits4;
 464                break;
 465
 466        case AF_INET6:
 467                *dbits = net->xfrm.policy_bydst[dir].dbits6;
 468                *sbits = net->xfrm.policy_bydst[dir].sbits6;
 469                break;
 470
 471        default:
 472                *dbits = 0;
 473                *sbits = 0;
 474        }
 475}
 476
 477static struct hlist_head *policy_hash_bysel(struct net *net,
 478                                            const struct xfrm_selector *sel,
 479                                            unsigned short family, int dir)
 480{
 481        unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
 482        unsigned int hash;
 483        u8 dbits;
 484        u8 sbits;
 485
 486        __get_hash_thresh(net, family, dir, &dbits, &sbits);
 487        hash = __sel_hash(sel, family, hmask, dbits, sbits);
 488
 489        if (hash == hmask + 1)
 490                return NULL;
 491
 492        return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
 493                     lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
 494}
 495
 496static struct hlist_head *policy_hash_direct(struct net *net,
 497                                             const xfrm_address_t *daddr,
 498                                             const xfrm_address_t *saddr,
 499                                             unsigned short family, int dir)
 500{
 501        unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
 502        unsigned int hash;
 503        u8 dbits;
 504        u8 sbits;
 505
 506        __get_hash_thresh(net, family, dir, &dbits, &sbits);
 507        hash = __addr_hash(daddr, saddr, family, hmask, dbits, sbits);
 508
 509        return rcu_dereference_check(net->xfrm.policy_bydst[dir].table,
 510                     lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash;
 511}
 512
 513static void xfrm_dst_hash_transfer(struct net *net,
 514                                   struct hlist_head *list,
 515                                   struct hlist_head *ndsttable,
 516                                   unsigned int nhashmask,
 517                                   int dir)
 518{
 519        struct hlist_node *tmp, *entry0 = NULL;
 520        struct xfrm_policy *pol;
 521        unsigned int h0 = 0;
 522        u8 dbits;
 523        u8 sbits;
 524
 525redo:
 526        hlist_for_each_entry_safe(pol, tmp, list, bydst) {
 527                unsigned int h;
 528
 529                __get_hash_thresh(net, pol->family, dir, &dbits, &sbits);
 530                h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr,
 531                                pol->family, nhashmask, dbits, sbits);
 532                if (!entry0) {
 533                        hlist_del_rcu(&pol->bydst);
 534                        hlist_add_head_rcu(&pol->bydst, ndsttable + h);
 535                        h0 = h;
 536                } else {
 537                        if (h != h0)
 538                                continue;
 539                        hlist_del_rcu(&pol->bydst);
 540                        hlist_add_behind_rcu(&pol->bydst, entry0);
 541                }
 542                entry0 = &pol->bydst;
 543        }
 544        if (!hlist_empty(list)) {
 545                entry0 = NULL;
 546                goto redo;
 547        }
 548}
 549
 550static void xfrm_idx_hash_transfer(struct hlist_head *list,
 551                                   struct hlist_head *nidxtable,
 552                                   unsigned int nhashmask)
 553{
 554        struct hlist_node *tmp;
 555        struct xfrm_policy *pol;
 556
 557        hlist_for_each_entry_safe(pol, tmp, list, byidx) {
 558                unsigned int h;
 559
 560                h = __idx_hash(pol->index, nhashmask);
 561                hlist_add_head(&pol->byidx, nidxtable+h);
 562        }
 563}
 564
 565static unsigned long xfrm_new_hash_mask(unsigned int old_hmask)
 566{
 567        return ((old_hmask + 1) << 1) - 1;
 568}
 569
 570static void xfrm_bydst_resize(struct net *net, int dir)
 571{
 572        unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
 573        unsigned int nhashmask = xfrm_new_hash_mask(hmask);
 574        unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
 575        struct hlist_head *ndst = xfrm_hash_alloc(nsize);
 576        struct hlist_head *odst;
 577        int i;
 578
 579        if (!ndst)
 580                return;
 581
 582        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
 583        write_seqcount_begin(&xfrm_policy_hash_generation);
 584
 585        odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table,
 586                                lockdep_is_held(&net->xfrm.xfrm_policy_lock));
 587
 588        for (i = hmask; i >= 0; i--)
 589                xfrm_dst_hash_transfer(net, odst + i, ndst, nhashmask, dir);
 590
 591        rcu_assign_pointer(net->xfrm.policy_bydst[dir].table, ndst);
 592        net->xfrm.policy_bydst[dir].hmask = nhashmask;
 593
 594        write_seqcount_end(&xfrm_policy_hash_generation);
 595        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 596
 597        synchronize_rcu();
 598
 599        xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head));
 600}
 601
 602static void xfrm_byidx_resize(struct net *net, int total)
 603{
 604        unsigned int hmask = net->xfrm.policy_idx_hmask;
 605        unsigned int nhashmask = xfrm_new_hash_mask(hmask);
 606        unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head);
 607        struct hlist_head *oidx = net->xfrm.policy_byidx;
 608        struct hlist_head *nidx = xfrm_hash_alloc(nsize);
 609        int i;
 610
 611        if (!nidx)
 612                return;
 613
 614        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
 615
 616        for (i = hmask; i >= 0; i--)
 617                xfrm_idx_hash_transfer(oidx + i, nidx, nhashmask);
 618
 619        net->xfrm.policy_byidx = nidx;
 620        net->xfrm.policy_idx_hmask = nhashmask;
 621
 622        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
 623
 624        xfrm_hash_free(oidx, (hmask + 1) * sizeof(struct hlist_head));
 625}
 626
 627static inline int xfrm_bydst_should_resize(struct net *net, int dir, int *total)
 628{
 629        unsigned int cnt = net->xfrm.policy_count[dir];
 630        unsigned int hmask = net->xfrm.policy_bydst[dir].hmask;
 631
 632        if (total)
 633                *total += cnt;
 634
 635        if ((hmask + 1) < xfrm_policy_hashmax &&
 636            cnt > hmask)
 637                return 1;
 638
 639        return 0;
 640}
 641
 642static inline int xfrm_byidx_should_resize(struct net *net, int total)
 643{
 644        unsigned int hmask = net->xfrm.policy_idx_hmask;
 645
 646        if ((hmask + 1) < xfrm_policy_hashmax &&
 647            total > hmask)
 648                return 1;
 649
 650        return 0;
 651}
 652
 653void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si)
 654{
 655        si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN];
 656        si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT];
 657        si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD];
 658        si->inscnt = net->xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX];
 659        si->outscnt = net->xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX];
 660        si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX];
 661        si->spdhcnt = net->xfrm.policy_idx_hmask;
 662        si->spdhmcnt = xfrm_policy_hashmax;
 663}
 664EXPORT_SYMBOL(xfrm_spd_getinfo);
 665
 666static DEFINE_MUTEX(hash_resize_mutex);
 667static void xfrm_hash_resize(struct work_struct *work)
 668{
 669        struct net *net = container_of(work, struct net, xfrm.policy_hash_work);
 670        int dir, total;
 671
 672        mutex_lock(&hash_resize_mutex);
 673
 674        total = 0;
 675        for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
 676                if (xfrm_bydst_should_resize(net, dir, &total))
 677                        xfrm_bydst_resize(net, dir);
 678        }
 679        if (xfrm_byidx_should_resize(net, total))
 680                xfrm_byidx_resize(net, total);
 681
 682        mutex_unlock(&hash_resize_mutex);
 683}
 684
 685/* Make sure *pol can be inserted into fastbin.
 686 * Useful to check that later insert requests will be sucessful
 687 * (provided xfrm_policy_lock is held throughout).
 688 */
 689static struct xfrm_pol_inexact_bin *
 690xfrm_policy_inexact_alloc_bin(const struct xfrm_policy *pol, u8 dir)
 691{
 692        struct xfrm_pol_inexact_bin *bin, *prev;
 693        struct xfrm_pol_inexact_key k = {
 694                .family = pol->family,
 695                .type = pol->type,
 696                .dir = dir,
 697                .if_id = pol->if_id,
 698        };
 699        struct net *net = xp_net(pol);
 700
 701        lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
 702
 703        write_pnet(&k.net, net);
 704        bin = rhashtable_lookup_fast(&xfrm_policy_inexact_table, &k,
 705                                     xfrm_pol_inexact_params);
 706        if (bin)
 707                return bin;
 708
 709        bin = kzalloc(sizeof(*bin), GFP_ATOMIC);
 710        if (!bin)
 711                return NULL;
 712
 713        bin->k = k;
 714        INIT_HLIST_HEAD(&bin->hhead);
 715        bin->root_d = RB_ROOT;
 716        bin->root_s = RB_ROOT;
 717        seqcount_init(&bin->count);
 718
 719        prev = rhashtable_lookup_get_insert_key(&xfrm_policy_inexact_table,
 720                                                &bin->k, &bin->head,
 721                                                xfrm_pol_inexact_params);
 722        if (!prev) {
 723                list_add(&bin->inexact_bins, &net->xfrm.inexact_bins);
 724                return bin;
 725        }
 726
 727        kfree(bin);
 728
 729        return IS_ERR(prev) ? NULL : prev;
 730}
 731
 732static bool xfrm_pol_inexact_addr_use_any_list(const xfrm_address_t *addr,
 733                                               int family, u8 prefixlen)
 734{
 735        if (xfrm_addr_any(addr, family))
 736                return true;
 737
 738        if (family == AF_INET6 && prefixlen < INEXACT_PREFIXLEN_IPV6)
 739                return true;
 740
 741        if (family == AF_INET && prefixlen < INEXACT_PREFIXLEN_IPV4)
 742                return true;
 743
 744        return false;
 745}
 746
 747static bool
 748xfrm_policy_inexact_insert_use_any_list(const struct xfrm_policy *policy)
 749{
 750        const xfrm_address_t *addr;
 751        bool saddr_any, daddr_any;
 752        u8 prefixlen;
 753
 754        addr = &policy->selector.saddr;
 755        prefixlen = policy->selector.prefixlen_s;
 756
 757        saddr_any = xfrm_pol_inexact_addr_use_any_list(addr,
 758                                                       policy->family,
 759                                                       prefixlen);
 760        addr = &policy->selector.daddr;
 761        prefixlen = policy->selector.prefixlen_d;
 762        daddr_any = xfrm_pol_inexact_addr_use_any_list(addr,
 763                                                       policy->family,
 764                                                       prefixlen);
 765        return saddr_any && daddr_any;
 766}
 767
 768static void xfrm_pol_inexact_node_init(struct xfrm_pol_inexact_node *node,
 769                                       const xfrm_address_t *addr, u8 prefixlen)
 770{
 771        node->addr = *addr;
 772        node->prefixlen = prefixlen;
 773}
 774
 775static struct xfrm_pol_inexact_node *
 776xfrm_pol_inexact_node_alloc(const xfrm_address_t *addr, u8 prefixlen)
 777{
 778        struct xfrm_pol_inexact_node *node;
 779
 780        node = kzalloc(sizeof(*node), GFP_ATOMIC);
 781        if (node)
 782                xfrm_pol_inexact_node_init(node, addr, prefixlen);
 783
 784        return node;
 785}
 786
 787static int xfrm_policy_addr_delta(const xfrm_address_t *a,
 788                                  const xfrm_address_t *b,
 789                                  u8 prefixlen, u16 family)
 790{
 791        unsigned int pdw, pbi;
 792        int delta = 0;
 793
 794        switch (family) {
 795        case AF_INET:
 796                if (sizeof(long) == 4 && prefixlen == 0)
 797                        return ntohl(a->a4) - ntohl(b->a4);
 798                return (ntohl(a->a4) & ((~0UL << (32 - prefixlen)))) -
 799                       (ntohl(b->a4) & ((~0UL << (32 - prefixlen))));
 800        case AF_INET6:
 801                pdw = prefixlen >> 5;
 802                pbi = prefixlen & 0x1f;
 803
 804                if (pdw) {
 805                        delta = memcmp(a->a6, b->a6, pdw << 2);
 806                        if (delta)
 807                                return delta;
 808                }
 809                if (pbi) {
 810                        u32 mask = ~0u << (32 - pbi);
 811
 812                        delta = (ntohl(a->a6[pdw]) & mask) -
 813                                (ntohl(b->a6[pdw]) & mask);
 814                }
 815                break;
 816        default:
 817                break;
 818        }
 819
 820        return delta;
 821}
 822
 823static void xfrm_policy_inexact_list_reinsert(struct net *net,
 824                                              struct xfrm_pol_inexact_node *n,
 825                                              u16 family)
 826{
 827        unsigned int matched_s, matched_d;
 828        struct xfrm_policy *policy, *p;
 829
 830        matched_s = 0;
 831        matched_d = 0;
 832
 833        list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
 834                struct hlist_node *newpos = NULL;
 835                bool matches_s, matches_d;
 836
 837                if (!policy->bydst_reinsert)
 838                        continue;
 839
 840                WARN_ON_ONCE(policy->family != family);
 841
 842                policy->bydst_reinsert = false;
 843                hlist_for_each_entry(p, &n->hhead, bydst) {
 844                        if (policy->priority > p->priority)
 845                                newpos = &p->bydst;
 846                        else if (policy->priority == p->priority &&
 847                                 policy->pos > p->pos)
 848                                newpos = &p->bydst;
 849                        else
 850                                break;
 851                }
 852
 853                if (newpos)
 854                        hlist_add_behind_rcu(&policy->bydst, newpos);
 855                else
 856                        hlist_add_head_rcu(&policy->bydst, &n->hhead);
 857
 858                /* paranoia checks follow.
 859                 * Check that the reinserted policy matches at least
 860                 * saddr or daddr for current node prefix.
 861                 *
 862                 * Matching both is fine, matching saddr in one policy
 863                 * (but not daddr) and then matching only daddr in another
 864                 * is a bug.
 865                 */
 866                matches_s = xfrm_policy_addr_delta(&policy->selector.saddr,
 867                                                   &n->addr,
 868                                                   n->prefixlen,
 869                                                   family) == 0;
 870                matches_d = xfrm_policy_addr_delta(&policy->selector.daddr,
 871                                                   &n->addr,
 872                                                   n->prefixlen,
 873                                                   family) == 0;
 874                if (matches_s && matches_d)
 875                        continue;
 876
 877                WARN_ON_ONCE(!matches_s && !matches_d);
 878                if (matches_s)
 879                        matched_s++;
 880                if (matches_d)
 881                        matched_d++;
 882                WARN_ON_ONCE(matched_s && matched_d);
 883        }
 884}
 885
/*
 * Insert inexact tree node @n into rb tree @new.  Used by
 * xfrm_policy_inexact_node_merge() to move the nodes of a dismantled
 * subtree into another node's subtree.
 *
 * Nodes are compared on the shorter of the two prefix lengths.  When an
 * existing node compares equal, @n is folded into it instead of being
 * linked:
 *  - @n's policies are flagged ->bydst_reinsert and unlinked,
 *  - the existing node's prefixlen is lowered to the shorter prefix,
 *  - xfrm_policy_inexact_list_reinsert() moves the flagged policies onto
 *    the existing node's list,
 *  - @n is freed via RCU.
 * If the existing node's prefix got shorter, it may now overlap further
 * nodes, so it is erased and inserted again from the root.
 *
 * NOTE(review): the policy lock is not asserted here; callers in this
 * file reach it with net->xfrm.xfrm_policy_lock held.
 */
static void xfrm_policy_inexact_node_reinsert(struct net *net,
					      struct xfrm_pol_inexact_node *n,
					      struct rb_root *new,
					      u16 family)
{
	struct xfrm_pol_inexact_node *node;
	struct rb_node **p, *parent;

	/* we should not have another subtree here */
	WARN_ON_ONCE(!RB_EMPTY_ROOT(&n->root));
restart:
	parent = NULL;
	p = &new->rb_node;
	while (*p) {
		u8 prefixlen;
		int delta;

		parent = *p;
		node = rb_entry(*p, struct xfrm_pol_inexact_node, node);

		/* compare on the prefix both nodes have in common */
		prefixlen = min(node->prefixlen, n->prefixlen);

		delta = xfrm_policy_addr_delta(&n->addr, &node->addr,
					       prefixlen, family);
		if (delta < 0) {
			p = &parent->rb_left;
		} else if (delta > 0) {
			p = &parent->rb_right;
		} else {
			/* @n overlaps @node: merge @n into @node */
			bool same_prefixlen = node->prefixlen == n->prefixlen;
			struct xfrm_policy *tmp;

			hlist_for_each_entry(tmp, &n->hhead, bydst) {
				tmp->bydst_reinsert = true;
				hlist_del_rcu(&tmp->bydst);
			}

			node->prefixlen = prefixlen;

			xfrm_policy_inexact_list_reinsert(net, node, family);

			/* @node's key is unchanged: its tree position is still valid */
			if (same_prefixlen) {
				kfree_rcu(n, rcu);
				return;
			}

			/* @node's prefix got shorter and may now cover other
			 * nodes: take it out and insert it again, merging as
			 * needed.
			 */
			rb_erase(*p, new);
			kfree_rcu(n, rcu);
			n = node;
			goto restart;
		}
	}

	rb_link_node_rcu(&n->node, parent, p);
	rb_insert_color(&n->node, new);
}
 942
 943/* merge nodes v and n */
 944static void xfrm_policy_inexact_node_merge(struct net *net,
 945                                           struct xfrm_pol_inexact_node *v,
 946                                           struct xfrm_pol_inexact_node *n,
 947                                           u16 family)
 948{
 949        struct xfrm_pol_inexact_node *node;
 950        struct xfrm_policy *tmp;
 951        struct rb_node *rnode;
 952
 953        /* To-be-merged node v has a subtree.
 954         *
 955         * Dismantle it and insert its nodes to n->root.
 956         */
 957        while ((rnode = rb_first(&v->root)) != NULL) {
 958                node = rb_entry(rnode, struct xfrm_pol_inexact_node, node);
 959                rb_erase(&node->node, &v->root);
 960                xfrm_policy_inexact_node_reinsert(net, node, &n->root,
 961                                                  family);
 962        }
 963
 964        hlist_for_each_entry(tmp, &v->hhead, bydst) {
 965                tmp->bydst_reinsert = true;
 966                hlist_del_rcu(&tmp->bydst);
 967        }
 968
 969        xfrm_policy_inexact_list_reinsert(net, n, family);
 970}
 971
/*
 * Find or create the node for @addr/@prefixlen in inexact tree @root.
 *
 * While walking down, @addr is compared against each node on that node's
 * own prefix length:
 *  - if a node already covers @addr with a prefix no longer than
 *    @prefixlen, that node is returned;
 *  - if instead the new, shorter prefix covers a node, the node is erased.
 *    The first such node is re-keyed to @addr/@prefixlen and kept
 *    ("cached") for reuse.  Any later ones are merged into it with
 *    xfrm_policy_inexact_node_merge() and freed via RCU.  The walk then
 *    restarts from the root.
 * Finally the cached node, or a freshly allocated one, is linked at the
 * position the walk ended on.
 *
 * @dir is not referenced in this function body.
 *
 * Returns the node, or NULL if a new node could not be allocated.
 */
static struct xfrm_pol_inexact_node *
xfrm_policy_inexact_insert_node(struct net *net,
				struct rb_root *root,
				xfrm_address_t *addr,
				u16 family, u8 prefixlen, u8 dir)
{
	struct xfrm_pol_inexact_node *cached = NULL;
	struct rb_node **p, *parent = NULL;
	struct xfrm_pol_inexact_node *node;

	p = &root->rb_node;
	while (*p) {
		int delta;

		parent = *p;
		node = rb_entry(*p, struct xfrm_pol_inexact_node, node);

		/* compare on the existing node's prefix length */
		delta = xfrm_policy_addr_delta(addr, &node->addr,
					       node->prefixlen,
					       family);
		/* @node already covers @addr/@prefixlen: reuse it */
		if (delta == 0 && prefixlen >= node->prefixlen) {
			WARN_ON_ONCE(cached); /* ipsec policies got lost */
			return node;
		}

		if (delta < 0)
			p = &parent->rb_left;
		else
			p = &parent->rb_right;

		/* new prefix is shorter: check whether it covers @node */
		if (prefixlen < node->prefixlen) {
			delta = xfrm_policy_addr_delta(addr, &node->addr,
						       prefixlen,
						       family);
			if (delta)
				continue;

			/* This node is a subnet of the new prefix. It needs
			 * to be removed and re-inserted with the smaller
			 * prefix and all nodes that are now also covered
			 * by the reduced prefixlen.
			 */
			rb_erase(&node->node, root);

			if (!cached) {
				xfrm_pol_inexact_node_init(node, addr,
							   prefixlen);
				cached = node;
			} else {
				/* This node also falls within the new
				 * prefixlen. Merge the to-be-reinserted
				 * node and this one.
				 */
				xfrm_policy_inexact_node_merge(net, node,
							       cached, family);
				kfree_rcu(node, rcu);
			}

			/* restart */
			p = &root->rb_node;
			parent = NULL;
		}
	}

	/* link the reused node, or allocate a fresh one */
	node = cached;
	if (!node) {
		node = xfrm_pol_inexact_node_alloc(addr, prefixlen);
		if (!node)
			return NULL;
	}

	rb_link_node_rcu(&node->node, parent, p);
	rb_insert_color(&node->node, root);

	return node;
}
1048
1049static void xfrm_policy_inexact_gc_tree(struct rb_root *r, bool rm)
1050{
1051        struct xfrm_pol_inexact_node *node;
1052        struct rb_node *rn = rb_first(r);
1053
1054        while (rn) {
1055                node = rb_entry(rn, struct xfrm_pol_inexact_node, node);
1056
1057                xfrm_policy_inexact_gc_tree(&node->root, rm);
1058                rn = rb_next(rn);
1059
1060                if (!hlist_empty(&node->hhead) || !RB_EMPTY_ROOT(&node->root)) {
1061                        WARN_ON_ONCE(rm);
1062                        continue;
1063                }
1064
1065                rb_erase(&node->node, r);
1066                kfree_rcu(node, rcu);
1067        }
1068}
1069
1070static void __xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b, bool net_exit)
1071{
1072        write_seqcount_begin(&b->count);
1073        xfrm_policy_inexact_gc_tree(&b->root_d, net_exit);
1074        xfrm_policy_inexact_gc_tree(&b->root_s, net_exit);
1075        write_seqcount_end(&b->count);
1076
1077        if (!RB_EMPTY_ROOT(&b->root_d) || !RB_EMPTY_ROOT(&b->root_s) ||
1078            !hlist_empty(&b->hhead)) {
1079                WARN_ON_ONCE(net_exit);
1080                return;
1081        }
1082
1083        if (rhashtable_remove_fast(&xfrm_policy_inexact_table, &b->head,
1084                                   xfrm_pol_inexact_params) == 0) {
1085                list_del(&b->inexact_bins);
1086                kfree_rcu(b, rcu);
1087        }
1088}
1089
1090static void xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b)
1091{
1092        struct net *net = read_pnet(&b->k.net);
1093
1094        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1095        __xfrm_policy_inexact_prune_bin(b, false);
1096        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1097}
1098
1099static void __xfrm_policy_inexact_flush(struct net *net)
1100{
1101        struct xfrm_pol_inexact_bin *bin, *t;
1102
1103        lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1104
1105        list_for_each_entry_safe(bin, t, &net->xfrm.inexact_bins, inexact_bins)
1106                __xfrm_policy_inexact_prune_bin(bin, false);
1107}
1108
1109static struct hlist_head *
1110xfrm_policy_inexact_alloc_chain(struct xfrm_pol_inexact_bin *bin,
1111                                struct xfrm_policy *policy, u8 dir)
1112{
1113        struct xfrm_pol_inexact_node *n;
1114        struct net *net;
1115
1116        net = xp_net(policy);
1117        lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1118
1119        if (xfrm_policy_inexact_insert_use_any_list(policy))
1120                return &bin->hhead;
1121
1122        if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.daddr,
1123                                               policy->family,
1124                                               policy->selector.prefixlen_d)) {
1125                write_seqcount_begin(&bin->count);
1126                n = xfrm_policy_inexact_insert_node(net,
1127                                                    &bin->root_s,
1128                                                    &policy->selector.saddr,
1129                                                    policy->family,
1130                                                    policy->selector.prefixlen_s,
1131                                                    dir);
1132                write_seqcount_end(&bin->count);
1133                if (!n)
1134                        return NULL;
1135
1136                return &n->hhead;
1137        }
1138
1139        /* daddr is fixed */
1140        write_seqcount_begin(&bin->count);
1141        n = xfrm_policy_inexact_insert_node(net,
1142                                            &bin->root_d,
1143                                            &policy->selector.daddr,
1144                                            policy->family,
1145                                            policy->selector.prefixlen_d, dir);
1146        write_seqcount_end(&bin->count);
1147        if (!n)
1148                return NULL;
1149
1150        /* saddr is wildcard */
1151        if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.saddr,
1152                                               policy->family,
1153                                               policy->selector.prefixlen_s))
1154                return &n->hhead;
1155
1156        write_seqcount_begin(&bin->count);
1157        n = xfrm_policy_inexact_insert_node(net,
1158                                            &n->root,
1159                                            &policy->selector.saddr,
1160                                            policy->family,
1161                                            policy->selector.prefixlen_s, dir);
1162        write_seqcount_end(&bin->count);
1163        if (!n)
1164                return NULL;
1165
1166        return &n->hhead;
1167}
1168
1169static struct xfrm_policy *
1170xfrm_policy_inexact_insert(struct xfrm_policy *policy, u8 dir, int excl)
1171{
1172        struct xfrm_pol_inexact_bin *bin;
1173        struct xfrm_policy *delpol;
1174        struct hlist_head *chain;
1175        struct net *net;
1176
1177        bin = xfrm_policy_inexact_alloc_bin(policy, dir);
1178        if (!bin)
1179                return ERR_PTR(-ENOMEM);
1180
1181        net = xp_net(policy);
1182        lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
1183
1184        chain = xfrm_policy_inexact_alloc_chain(bin, policy, dir);
1185        if (!chain) {
1186                __xfrm_policy_inexact_prune_bin(bin, false);
1187                return ERR_PTR(-ENOMEM);
1188        }
1189
1190        delpol = xfrm_policy_insert_list(chain, policy, excl);
1191        if (delpol && excl) {
1192                __xfrm_policy_inexact_prune_bin(bin, false);
1193                return ERR_PTR(-EEXIST);
1194        }
1195
1196        chain = &net->xfrm.policy_inexact[dir];
1197        xfrm_policy_insert_inexact_list(chain, policy);
1198
1199        if (delpol)
1200                __xfrm_policy_inexact_prune_bin(bin, false);
1201
1202        return delpol;
1203}
1204
1205static void xfrm_hash_rebuild(struct work_struct *work)
1206{
1207        struct net *net = container_of(work, struct net,
1208                                       xfrm.policy_hthresh.work);
1209        unsigned int hmask;
1210        struct xfrm_policy *pol;
1211        struct xfrm_policy *policy;
1212        struct hlist_head *chain;
1213        struct hlist_head *odst;
1214        struct hlist_node *newpos;
1215        int i;
1216        int dir;
1217        unsigned seq;
1218        u8 lbits4, rbits4, lbits6, rbits6;
1219
1220        mutex_lock(&hash_resize_mutex);
1221
1222        /* read selector prefixlen thresholds */
1223        do {
1224                seq = read_seqbegin(&net->xfrm.policy_hthresh.lock);
1225
1226                lbits4 = net->xfrm.policy_hthresh.lbits4;
1227                rbits4 = net->xfrm.policy_hthresh.rbits4;
1228                lbits6 = net->xfrm.policy_hthresh.lbits6;
1229                rbits6 = net->xfrm.policy_hthresh.rbits6;
1230        } while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq));
1231
1232        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1233        write_seqcount_begin(&xfrm_policy_hash_generation);
1234
1235        /* make sure that we can insert the indirect policies again before
1236         * we start with destructive action.
1237         */
1238        list_for_each_entry(policy, &net->xfrm.policy_all, walk.all) {
1239                struct xfrm_pol_inexact_bin *bin;
1240                u8 dbits, sbits;
1241
1242                dir = xfrm_policy_id2dir(policy->index);
1243                if (policy->walk.dead || dir >= XFRM_POLICY_MAX)
1244                        continue;
1245
1246                if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) {
1247                        if (policy->family == AF_INET) {
1248                                dbits = rbits4;
1249                                sbits = lbits4;
1250                        } else {
1251                                dbits = rbits6;
1252                                sbits = lbits6;
1253                        }
1254                } else {
1255                        if (policy->family == AF_INET) {
1256                                dbits = lbits4;
1257                                sbits = rbits4;
1258                        } else {
1259                                dbits = lbits6;
1260                                sbits = rbits6;
1261                        }
1262                }
1263
1264                if (policy->selector.prefixlen_d < dbits ||
1265                    policy->selector.prefixlen_s < sbits)
1266                        continue;
1267
1268                bin = xfrm_policy_inexact_alloc_bin(policy, dir);
1269                if (!bin)
1270                        goto out_unlock;
1271
1272                if (!xfrm_policy_inexact_alloc_chain(bin, policy, dir))
1273                        goto out_unlock;
1274        }
1275
1276        /* reset the bydst and inexact table in all directions */
1277        for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
1278                struct hlist_node *n;
1279
1280                hlist_for_each_entry_safe(policy, n,
1281                                          &net->xfrm.policy_inexact[dir],
1282                                          bydst_inexact_list) {
1283                        hlist_del_rcu(&policy->bydst);
1284                        hlist_del_init(&policy->bydst_inexact_list);
1285                }
1286
1287                hmask = net->xfrm.policy_bydst[dir].hmask;
1288                odst = net->xfrm.policy_bydst[dir].table;
1289                for (i = hmask; i >= 0; i--) {
1290                        hlist_for_each_entry_safe(policy, n, odst + i, bydst)
1291                                hlist_del_rcu(&policy->bydst);
1292                }
1293                if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) {
1294                        /* dir out => dst = remote, src = local */
1295                        net->xfrm.policy_bydst[dir].dbits4 = rbits4;
1296                        net->xfrm.policy_bydst[dir].sbits4 = lbits4;
1297                        net->xfrm.policy_bydst[dir].dbits6 = rbits6;
1298                        net->xfrm.policy_bydst[dir].sbits6 = lbits6;
1299                } else {
1300                        /* dir in/fwd => dst = local, src = remote */
1301                        net->xfrm.policy_bydst[dir].dbits4 = lbits4;
1302                        net->xfrm.policy_bydst[dir].sbits4 = rbits4;
1303                        net->xfrm.policy_bydst[dir].dbits6 = lbits6;
1304                        net->xfrm.policy_bydst[dir].sbits6 = rbits6;
1305                }
1306        }
1307
1308        /* re-insert all policies by order of creation */
1309        list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
1310                if (policy->walk.dead)
1311                        continue;
1312                dir = xfrm_policy_id2dir(policy->index);
1313                if (dir >= XFRM_POLICY_MAX) {
1314                        /* skip socket policies */
1315                        continue;
1316                }
1317                newpos = NULL;
1318                chain = policy_hash_bysel(net, &policy->selector,
1319                                          policy->family, dir);
1320
1321                if (!chain) {
1322                        void *p = xfrm_policy_inexact_insert(policy, dir, 0);
1323
1324                        WARN_ONCE(IS_ERR(p), "reinsert: %ld\n", PTR_ERR(p));
1325                        continue;
1326                }
1327
1328                hlist_for_each_entry(pol, chain, bydst) {
1329                        if (policy->priority >= pol->priority)
1330                                newpos = &pol->bydst;
1331                        else
1332                                break;
1333                }
1334                if (newpos)
1335                        hlist_add_behind_rcu(&policy->bydst, newpos);
1336                else
1337                        hlist_add_head_rcu(&policy->bydst, chain);
1338        }
1339
1340out_unlock:
1341        __xfrm_policy_inexact_flush(net);
1342        write_seqcount_end(&xfrm_policy_hash_generation);
1343        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1344
1345        mutex_unlock(&hash_resize_mutex);
1346}
1347
1348void xfrm_policy_hash_rebuild(struct net *net)
1349{
1350        schedule_work(&net->xfrm.policy_hthresh.work);
1351}
1352EXPORT_SYMBOL(xfrm_policy_hash_rebuild);
1353
1354/* Generate new index... KAME seems to generate them ordered by cost
1355 * of an absolute inpredictability of ordering of rules. This will not pass. */
1356static u32 xfrm_gen_index(struct net *net, int dir, u32 index)
1357{
1358        static u32 idx_generator;
1359
1360        for (;;) {
1361                struct hlist_head *list;
1362                struct xfrm_policy *p;
1363                u32 idx;
1364                int found;
1365
1366                if (!index) {
1367                        idx = (idx_generator | dir);
1368                        idx_generator += 8;
1369                } else {
1370                        idx = index;
1371                        index = 0;
1372                }
1373
1374                if (idx == 0)
1375                        idx = 8;
1376                list = net->xfrm.policy_byidx + idx_hash(net, idx);
1377                found = 0;
1378                hlist_for_each_entry(p, list, byidx) {
1379                        if (p->index == idx) {
1380                                found = 1;
1381                                break;
1382                        }
1383                }
1384                if (!found)
1385                        return idx;
1386        }
1387}
1388
1389static inline int selector_cmp(struct xfrm_selector *s1, struct xfrm_selector *s2)
1390{
1391        u32 *p1 = (u32 *) s1;
1392        u32 *p2 = (u32 *) s2;
1393        int len = sizeof(struct xfrm_selector) / sizeof(u32);
1394        int i;
1395
1396        for (i = 0; i < len; i++) {
1397                if (p1[i] != p2[i])
1398                        return 1;
1399        }
1400
1401        return 0;
1402}
1403
1404static void xfrm_policy_requeue(struct xfrm_policy *old,
1405                                struct xfrm_policy *new)
1406{
1407        struct xfrm_policy_queue *pq = &old->polq;
1408        struct sk_buff_head list;
1409
1410        if (skb_queue_empty(&pq->hold_queue))
1411                return;
1412
1413        __skb_queue_head_init(&list);
1414
1415        spin_lock_bh(&pq->hold_queue.lock);
1416        skb_queue_splice_init(&pq->hold_queue, &list);
1417        if (del_timer(&pq->hold_timer))
1418                xfrm_pol_put(old);
1419        spin_unlock_bh(&pq->hold_queue.lock);
1420
1421        pq = &new->polq;
1422
1423        spin_lock_bh(&pq->hold_queue.lock);
1424        skb_queue_splice(&list, &pq->hold_queue);
1425        pq->timeout = XFRM_QUEUE_TMO_MIN;
1426        if (!mod_timer(&pq->hold_timer, jiffies))
1427                xfrm_pol_hold(new);
1428        spin_unlock_bh(&pq->hold_queue.lock);
1429}
1430
1431static bool xfrm_policy_mark_match(struct xfrm_policy *policy,
1432                                   struct xfrm_policy *pol)
1433{
1434        u32 mark = policy->mark.v & policy->mark.m;
1435
1436        if (policy->mark.v == pol->mark.v && policy->mark.m == pol->mark.m)
1437                return true;
1438
1439        if ((mark & pol->mark.m) == pol->mark.v &&
1440            policy->priority == pol->priority)
1441                return true;
1442
1443        return false;
1444}
1445
1446static u32 xfrm_pol_bin_key(const void *data, u32 len, u32 seed)
1447{
1448        const struct xfrm_pol_inexact_key *k = data;
1449        u32 a = k->type << 24 | k->dir << 16 | k->family;
1450
1451        return jhash_3words(a, k->if_id, net_hash_mix(read_pnet(&k->net)),
1452                            seed);
1453}
1454
1455static u32 xfrm_pol_bin_obj(const void *data, u32 len, u32 seed)
1456{
1457        const struct xfrm_pol_inexact_bin *b = data;
1458
1459        return xfrm_pol_bin_key(&b->k, 0, seed);
1460}
1461
1462static int xfrm_pol_bin_cmp(struct rhashtable_compare_arg *arg,
1463                            const void *ptr)
1464{
1465        const struct xfrm_pol_inexact_key *key = arg->key;
1466        const struct xfrm_pol_inexact_bin *b = ptr;
1467        int ret;
1468
1469        if (!net_eq(read_pnet(&b->k.net), read_pnet(&key->net)))
1470                return -1;
1471
1472        ret = b->k.dir ^ key->dir;
1473        if (ret)
1474                return ret;
1475
1476        ret = b->k.type ^ key->type;
1477        if (ret)
1478                return ret;
1479
1480        ret = b->k.family ^ key->family;
1481        if (ret)
1482                return ret;
1483
1484        return b->k.if_id ^ key->if_id;
1485}
1486
1487static const struct rhashtable_params xfrm_pol_inexact_params = {
1488        .head_offset            = offsetof(struct xfrm_pol_inexact_bin, head),
1489        .hashfn                 = xfrm_pol_bin_key,
1490        .obj_hashfn             = xfrm_pol_bin_obj,
1491        .obj_cmpfn              = xfrm_pol_bin_cmp,
1492        .automatic_shrinking    = true,
1493};
1494
/*
 * Link @policy into the per-direction inexact list @chain (through
 * ->bydst_inexact_list), then renumber ->pos of every entry.
 *
 * Placement: after the last visited entry that does not match @policy
 * and whose priority is <= @policy->priority; at the head if there is
 * none.
 *
 * An entry matching @policy (type, if_id, selector, mark via
 * xfrm_policy_mark_match(), security context) is only remembered in
 * delpol to end the scan early.  It is not unlinked here.  A second
 * matching entry triggers a WARN.
 *
 * ->pos is the entry's index in this list.  Lookups that collect
 * several candidates (e.g. xfrm_policy_bysel_ctx()) prefer the lowest
 * ->pos.
 */
static void xfrm_policy_insert_inexact_list(struct hlist_head *chain,
					    struct xfrm_policy *policy)
{
	struct xfrm_policy *pol, *delpol = NULL;
	struct hlist_node *newpos = NULL;
	int i = 0;

	hlist_for_each_entry(pol, chain, bydst_inexact_list) {
		if (pol->type == policy->type &&
		    pol->if_id == policy->if_id &&
		    !selector_cmp(&pol->selector, &policy->selector) &&
		    xfrm_policy_mark_match(policy, pol) &&
		    xfrm_sec_ctx_match(pol->security, policy->security) &&
		    !WARN_ON(delpol)) {
			delpol = pol;
			if (policy->priority > pol->priority)
				continue;
		} else if (policy->priority >= pol->priority) {
			newpos = &pol->bydst_inexact_list;
			continue;
		}
		/* past the insertion point and the match was seen: stop */
		if (delpol)
			break;
	}

	if (newpos)
		hlist_add_behind_rcu(&policy->bydst_inexact_list, newpos);
	else
		hlist_add_head_rcu(&policy->bydst_inexact_list, chain);

	/* renumber so ->pos reflects the list order */
	hlist_for_each_entry(pol, chain, bydst_inexact_list) {
		pol->pos = i;
		i++;
	}
}
1530
/*
 * Link @policy into @chain (through ->bydst), keeping the chain ordered
 * by priority.
 *
 * Placement: after the last visited entry that does not match @policy
 * and whose priority is <= @policy->priority; at the head if there is
 * none.
 *
 * An entry "matches" when type, if_id, selector, mark (see
 * xfrm_policy_mark_match()) and security context are equal.  A second
 * match triggers a WARN.
 *
 * Returns:
 *  - ERR_PTR(-EEXIST) if @excl is set and a match exists (@policy is
 *    not linked in that case);
 *  - otherwise the matching entry, which the caller is to replace and
 *    which is still on the chain;
 *  - NULL if there was no match.
 */
static struct xfrm_policy *xfrm_policy_insert_list(struct hlist_head *chain,
						   struct xfrm_policy *policy,
						   bool excl)
{
	struct xfrm_policy *pol, *newpos = NULL, *delpol = NULL;

	hlist_for_each_entry(pol, chain, bydst) {
		if (pol->type == policy->type &&
		    pol->if_id == policy->if_id &&
		    !selector_cmp(&pol->selector, &policy->selector) &&
		    xfrm_policy_mark_match(policy, pol) &&
		    xfrm_sec_ctx_match(pol->security, policy->security) &&
		    !WARN_ON(delpol)) {
			if (excl)
				return ERR_PTR(-EEXIST);
			delpol = pol;
			if (policy->priority > pol->priority)
				continue;
		} else if (policy->priority >= pol->priority) {
			newpos = pol;
			continue;
		}
		/* past the insertion point and the match was seen: stop */
		if (delpol)
			break;
	}

	if (newpos)
		hlist_add_behind_rcu(&policy->bydst, &newpos->bydst);
	else
		hlist_add_head_rcu(&policy->bydst, chain);

	return delpol;
}
1564
/*
 * Insert @policy for direction @dir.
 *
 * An existing policy with matching type, if_id, selector, mark and
 * security context is replaced, unless @excl is set.  With @excl set,
 * such a duplicate makes the call fail with -EEXIST.
 *
 * Selectors that hash to a bydst chain go into that chain.  Others (when
 * policy_hash_bysel() returns NULL) go into the inexact bins.
 *
 * When a policy is replaced:
 *  - its held packets are moved to @policy;
 *  - it is unlinked;
 *  - @policy takes over its index;
 *  - it is killed after the policy lock has been dropped.
 * Otherwise a new index is generated, and a bydst table resize may be
 * scheduled.
 *
 * Returns 0 on success or a negative errno (-EEXIST, -ENOMEM).
 */
int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl)
{
	struct net *net = xp_net(policy);
	struct xfrm_policy *delpol;
	struct hlist_head *chain;

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	chain = policy_hash_bysel(net, &policy->selector, policy->family, dir);
	if (chain)
		delpol = xfrm_policy_insert_list(chain, policy, excl);
	else
		delpol = xfrm_policy_inexact_insert(policy, dir, excl);

	if (IS_ERR(delpol)) {
		spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
		return PTR_ERR(delpol);
	}

	__xfrm_policy_link(policy, dir);

	/* After previous checking, family can either be AF_INET or AF_INET6 */
	if (policy->family == AF_INET)
		rt_genid_bump_ipv4(net);
	else
		rt_genid_bump_ipv6(net);

	if (delpol) {
		xfrm_policy_requeue(delpol, policy);
		__xfrm_policy_unlink(delpol, dir);
	}
	/* a replacement keeps the index of the policy it replaces */
	policy->index = delpol ? delpol->index : xfrm_gen_index(net, dir, policy->index);
	hlist_add_head(&policy->byidx, net->xfrm.policy_byidx+idx_hash(net, policy->index));
	policy->curlft.add_time = ktime_get_real_seconds();
	policy->curlft.use_time = 0;
	/* take a reference for the timer only if it was not already pending */
	if (!mod_timer(&policy->timer, jiffies + HZ))
		xfrm_pol_hold(policy);
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);

	/* done after dropping the policy lock */
	if (delpol)
		xfrm_policy_kill(delpol);
	else if (xfrm_bydst_should_resize(net, dir, NULL))
		schedule_work(&net->xfrm.policy_hash_work);

	return 0;
}
EXPORT_SYMBOL(xfrm_policy_insert);
1611
1612static struct xfrm_policy *
1613__xfrm_policy_bysel_ctx(struct hlist_head *chain, u32 mark, u32 if_id,
1614                        u8 type, int dir,
1615                        struct xfrm_selector *sel,
1616                        struct xfrm_sec_ctx *ctx)
1617{
1618        struct xfrm_policy *pol;
1619
1620        if (!chain)
1621                return NULL;
1622
1623        hlist_for_each_entry(pol, chain, bydst) {
1624                if (pol->type == type &&
1625                    pol->if_id == if_id &&
1626                    (mark & pol->mark.m) == pol->mark.v &&
1627                    !selector_cmp(sel, &pol->selector) &&
1628                    xfrm_sec_ctx_match(ctx, pol->security))
1629                        return pol;
1630        }
1631
1632        return NULL;
1633}
1634
/*
 * Look up the policy for @dir whose type, if_id, selector and security
 * context equal the given ones and whose mark/mask accepts @mark.
 * Optionally delete it.
 *
 * Hashed selectors are searched in their bydst chain.  Inexact ones are
 * searched in every candidate list of the matching inexact bin; among
 * several hits the one with the lowest ->pos wins.
 *
 * On a hit a reference is taken on the policy.  If @delete is set:
 *  - the LSM is consulted via security_xfrm_policy_delete();
 *  - if it refuses, *@err holds its error and the policy is returned
 *    still linked (with the reference taken above);
 *  - otherwise the policy is unlinked and, once the lock is dropped,
 *    killed, and the inexact bin that was searched (if any) is pruned.
 *
 * Returns the policy or NULL; *@err is 0 unless the LSM refused deletion.
 */
struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u32 if_id,
					  u8 type, int dir,
					  struct xfrm_selector *sel,
					  struct xfrm_sec_ctx *ctx, int delete,
					  int *err)
{
	struct xfrm_pol_inexact_bin *bin = NULL;
	struct xfrm_policy *pol, *ret = NULL;
	struct hlist_head *chain;

	*err = 0;
	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	chain = policy_hash_bysel(net, sel, sel->family, dir);
	if (!chain) {
		/* inexact selector: search the bin's candidate lists */
		struct xfrm_pol_inexact_candidates cand;
		int i;

		bin = xfrm_policy_inexact_lookup(net, type,
						 sel->family, dir, if_id);
		if (!bin) {
			spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
			return NULL;
		}

		if (!xfrm_policy_find_inexact_candidates(&cand, bin,
							 &sel->saddr,
							 &sel->daddr)) {
			spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
			return NULL;
		}

		pol = NULL;
		for (i = 0; i < ARRAY_SIZE(cand.res); i++) {
			struct xfrm_policy *tmp;

			tmp = __xfrm_policy_bysel_ctx(cand.res[i], mark,
						      if_id, type, dir,
						      sel, ctx);
			if (!tmp)
				continue;

			/* prefer the hit that sorts first (lowest ->pos) */
			if (!pol || tmp->pos < pol->pos)
				pol = tmp;
		}
	} else {
		pol = __xfrm_policy_bysel_ctx(chain, mark, if_id, type, dir,
					      sel, ctx);
	}

	if (pol) {
		xfrm_pol_hold(pol);
		if (delete) {
			*err = security_xfrm_policy_delete(pol->security);
			if (*err) {
				/* LSM refused: return the policy, still linked */
				spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
				return pol;
			}
			__xfrm_policy_unlink(pol, dir);
		}
		ret = pol;
	}
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);

	if (ret && delete)
		xfrm_policy_kill(ret);
	if (bin && delete)
		xfrm_policy_inexact_prune_bin(bin);
	return ret;
}
EXPORT_SYMBOL(xfrm_policy_bysel_ctx);
1705
1706struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u32 if_id,
1707                                     u8 type, int dir, u32 id, int delete,
1708                                     int *err)
1709{
1710        struct xfrm_policy *pol, *ret;
1711        struct hlist_head *chain;
1712
1713        *err = -ENOENT;
1714        if (xfrm_policy_id2dir(id) != dir)
1715                return NULL;
1716
1717        *err = 0;
1718        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1719        chain = net->xfrm.policy_byidx + idx_hash(net, id);
1720        ret = NULL;
1721        hlist_for_each_entry(pol, chain, byidx) {
1722                if (pol->type == type && pol->index == id &&
1723                    pol->if_id == if_id &&
1724                    (mark & pol->mark.m) == pol->mark.v) {
1725                        xfrm_pol_hold(pol);
1726                        if (delete) {
1727                                *err = security_xfrm_policy_delete(
1728                                                                pol->security);
1729                                if (*err) {
1730                                        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1731                                        return pol;
1732                                }
1733                                __xfrm_policy_unlink(pol, dir);
1734                        }
1735                        ret = pol;
1736                        break;
1737                }
1738        }
1739        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1740
1741        if (ret && delete)
1742                xfrm_policy_kill(ret);
1743        return ret;
1744}
1745EXPORT_SYMBOL(xfrm_policy_byid);
1746
1747#ifdef CONFIG_SECURITY_NETWORK_XFRM
1748static inline int
1749xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid)
1750{
1751        struct xfrm_policy *pol;
1752        int err = 0;
1753
1754        list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1755                if (pol->walk.dead ||
1756                    xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX ||
1757                    pol->type != type)
1758                        continue;
1759
1760                err = security_xfrm_policy_delete(pol->security);
1761                if (err) {
1762                        xfrm_audit_policy_delete(pol, 0, task_valid);
1763                        return err;
1764                }
1765        }
1766        return err;
1767}
1768#else
1769static inline int
1770xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid)
1771{
1772        return 0;
1773}
1774#endif
1775
1776int xfrm_policy_flush(struct net *net, u8 type, bool task_valid)
1777{
1778        int dir, err = 0, cnt = 0;
1779        struct xfrm_policy *pol;
1780
1781        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1782
1783        err = xfrm_policy_flush_secctx_check(net, type, task_valid);
1784        if (err)
1785                goto out;
1786
1787again:
1788        list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) {
1789                dir = xfrm_policy_id2dir(pol->index);
1790                if (pol->walk.dead ||
1791                    dir >= XFRM_POLICY_MAX ||
1792                    pol->type != type)
1793                        continue;
1794
1795                __xfrm_policy_unlink(pol, dir);
1796                spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1797                cnt++;
1798                xfrm_audit_policy_delete(pol, 1, task_valid);
1799                xfrm_policy_kill(pol);
1800                spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1801                goto again;
1802        }
1803        if (cnt)
1804                __xfrm_policy_inexact_flush(net);
1805        else
1806                err = -ESRCH;
1807out:
1808        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1809        return err;
1810}
1811EXPORT_SYMBOL(xfrm_policy_flush);
1812
1813int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk,
1814                     int (*func)(struct xfrm_policy *, int, int, void*),
1815                     void *data)
1816{
1817        struct xfrm_policy *pol;
1818        struct xfrm_policy_walk_entry *x;
1819        int error = 0;
1820
1821        if (walk->type >= XFRM_POLICY_TYPE_MAX &&
1822            walk->type != XFRM_POLICY_TYPE_ANY)
1823                return -EINVAL;
1824
1825        if (list_empty(&walk->walk.all) && walk->seq != 0)
1826                return 0;
1827
1828        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
1829        if (list_empty(&walk->walk.all))
1830                x = list_first_entry(&net->xfrm.policy_all, struct xfrm_policy_walk_entry, all);
1831        else
1832                x = list_first_entry(&walk->walk.all,
1833                                     struct xfrm_policy_walk_entry, all);
1834
1835        list_for_each_entry_from(x, &net->xfrm.policy_all, all) {
1836                if (x->dead)
1837                        continue;
1838                pol = container_of(x, struct xfrm_policy, walk);
1839                if (walk->type != XFRM_POLICY_TYPE_ANY &&
1840                    walk->type != pol->type)
1841                        continue;
1842                error = func(pol, xfrm_policy_id2dir(pol->index),
1843                             walk->seq, data);
1844                if (error) {
1845                        list_move_tail(&walk->walk.all, &x->all);
1846                        goto out;
1847                }
1848                walk->seq++;
1849        }
1850        if (walk->seq == 0) {
1851                error = -ENOENT;
1852                goto out;
1853        }
1854        list_del_init(&walk->walk.all);
1855out:
1856        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1857        return error;
1858}
1859EXPORT_SYMBOL(xfrm_policy_walk);
1860
1861void xfrm_policy_walk_init(struct xfrm_policy_walk *walk, u8 type)
1862{
1863        INIT_LIST_HEAD(&walk->walk.all);
1864        walk->walk.dead = 1;
1865        walk->type = type;
1866        walk->seq = 0;
1867}
1868EXPORT_SYMBOL(xfrm_policy_walk_init);
1869
1870void xfrm_policy_walk_done(struct xfrm_policy_walk *walk, struct net *net)
1871{
1872        if (list_empty(&walk->walk.all))
1873                return;
1874
1875        spin_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */
1876        list_del(&walk->walk.all);
1877        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
1878}
1879EXPORT_SYMBOL(xfrm_policy_walk_done);
1880
1881/*
1882 * Find policy to apply to this flow.
1883 *
1884 * Returns 0 if policy found, else an -errno.
1885 */
1886static int xfrm_policy_match(const struct xfrm_policy *pol,
1887                             const struct flowi *fl,
1888                             u8 type, u16 family, int dir, u32 if_id)
1889{
1890        const struct xfrm_selector *sel = &pol->selector;
1891        int ret = -ESRCH;
1892        bool match;
1893
1894        if (pol->family != family ||
1895            pol->if_id != if_id ||
1896            (fl->flowi_mark & pol->mark.m) != pol->mark.v ||
1897            pol->type != type)
1898                return ret;
1899
1900        match = xfrm_selector_match(sel, fl, family);
1901        if (match)
1902                ret = security_xfrm_policy_lookup(pol->security, fl->flowi_secid,
1903                                                  dir);
1904        return ret;
1905}
1906
1907static struct xfrm_pol_inexact_node *
1908xfrm_policy_lookup_inexact_addr(const struct rb_root *r,
1909                                seqcount_t *count,
1910                                const xfrm_address_t *addr, u16 family)
1911{
1912        const struct rb_node *parent;
1913        int seq;
1914
1915again:
1916        seq = read_seqcount_begin(count);
1917
1918        parent = rcu_dereference_raw(r->rb_node);
1919        while (parent) {
1920                struct xfrm_pol_inexact_node *node;
1921                int delta;
1922
1923                node = rb_entry(parent, struct xfrm_pol_inexact_node, node);
1924
1925                delta = xfrm_policy_addr_delta(addr, &node->addr,
1926                                               node->prefixlen, family);
1927                if (delta < 0) {
1928                        parent = rcu_dereference_raw(parent->rb_left);
1929                        continue;
1930                } else if (delta > 0) {
1931                        parent = rcu_dereference_raw(parent->rb_right);
1932                        continue;
1933                }
1934
1935                return node;
1936        }
1937
1938        if (read_seqcount_retry(count, seq))
1939                goto again;
1940
1941        return NULL;
1942}
1943
1944static bool
1945xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
1946                                    struct xfrm_pol_inexact_bin *b,
1947                                    const xfrm_address_t *saddr,
1948                                    const xfrm_address_t *daddr)
1949{
1950        struct xfrm_pol_inexact_node *n;
1951        u16 family;
1952
1953        if (!b)
1954                return false;
1955
1956        family = b->k.family;
1957        memset(cand, 0, sizeof(*cand));
1958        cand->res[XFRM_POL_CAND_ANY] = &b->hhead;
1959
1960        n = xfrm_policy_lookup_inexact_addr(&b->root_d, &b->count, daddr,
1961                                            family);
1962        if (n) {
1963                cand->res[XFRM_POL_CAND_DADDR] = &n->hhead;
1964                n = xfrm_policy_lookup_inexact_addr(&n->root, &b->count, saddr,
1965                                                    family);
1966                if (n)
1967                        cand->res[XFRM_POL_CAND_BOTH] = &n->hhead;
1968        }
1969
1970        n = xfrm_policy_lookup_inexact_addr(&b->root_s, &b->count, saddr,
1971                                            family);
1972        if (n)
1973                cand->res[XFRM_POL_CAND_SADDR] = &n->hhead;
1974
1975        return true;
1976}
1977
1978static struct xfrm_pol_inexact_bin *
1979xfrm_policy_inexact_lookup_rcu(struct net *net, u8 type, u16 family,
1980                               u8 dir, u32 if_id)
1981{
1982        struct xfrm_pol_inexact_key k = {
1983                .family = family,
1984                .type = type,
1985                .dir = dir,
1986                .if_id = if_id,
1987        };
1988
1989        write_pnet(&k.net, net);
1990
1991        return rhashtable_lookup(&xfrm_policy_inexact_table, &k,
1992                                 xfrm_pol_inexact_params);
1993}
1994
1995static struct xfrm_pol_inexact_bin *
1996xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family,
1997                           u8 dir, u32 if_id)
1998{
1999        struct xfrm_pol_inexact_bin *bin;
2000
2001        lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
2002
2003        rcu_read_lock();
2004        bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id);
2005        rcu_read_unlock();
2006
2007        return bin;
2008}
2009
2010static struct xfrm_policy *
2011__xfrm_policy_eval_candidates(struct hlist_head *chain,
2012                              struct xfrm_policy *prefer,
2013                              const struct flowi *fl,
2014                              u8 type, u16 family, int dir, u32 if_id)
2015{
2016        u32 priority = prefer ? prefer->priority : ~0u;
2017        struct xfrm_policy *pol;
2018
2019        if (!chain)
2020                return NULL;
2021
2022        hlist_for_each_entry_rcu(pol, chain, bydst) {
2023                int err;
2024
2025                if (pol->priority > priority)
2026                        break;
2027
2028                err = xfrm_policy_match(pol, fl, type, family, dir, if_id);
2029                if (err) {
2030                        if (err != -ESRCH)
2031                                return ERR_PTR(err);
2032
2033                        continue;
2034                }
2035
2036                if (prefer) {
2037                        /* matches.  Is it older than *prefer? */
2038                        if (pol->priority == priority &&
2039                            prefer->pos < pol->pos)
2040                                return prefer;
2041                }
2042
2043                return pol;
2044        }
2045
2046        return NULL;
2047}
2048
2049static struct xfrm_policy *
2050xfrm_policy_eval_candidates(struct xfrm_pol_inexact_candidates *cand,
2051                            struct xfrm_policy *prefer,
2052                            const struct flowi *fl,
2053                            u8 type, u16 family, int dir, u32 if_id)
2054{
2055        struct xfrm_policy *tmp;
2056        int i;
2057
2058        for (i = 0; i < ARRAY_SIZE(cand->res); i++) {
2059                tmp = __xfrm_policy_eval_candidates(cand->res[i],
2060                                                    prefer,
2061                                                    fl, type, family, dir,
2062                                                    if_id);
2063                if (!tmp)
2064                        continue;
2065
2066                if (IS_ERR(tmp))
2067                        return tmp;
2068                prefer = tmp;
2069        }
2070
2071        return prefer;
2072}
2073
/*
 * Find the policy of @type (for @if_id) that applies to flow @fl in
 * direction @dir.
 *
 * Protection is rcu_read_lock() only.  The exact-match bydst chain is
 * searched first.  Its first match, if any, is then handed to
 * xfrm_policy_eval_candidates() as the preferred result, which may replace
 * it with a policy from the inexact bin.
 *
 * The lookup restarts if xfrm_policy_hash_generation changed while it ran,
 * or if a reference on the chosen policy cannot be taken.
 *
 * Returns a policy with a reference held, NULL if nothing matches (or the
 * flow has no addresses for @family), or an ERR_PTR() propagated from
 * xfrm_policy_match().
 */
static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type,
						     const struct flowi *fl,
						     u16 family, u8 dir,
						     u32 if_id)
{
	struct xfrm_pol_inexact_candidates cand;
	const xfrm_address_t *daddr, *saddr;
	struct xfrm_pol_inexact_bin *bin;
	struct xfrm_policy *pol, *ret;
	struct hlist_head *chain;
	unsigned int sequence;
	int err;

	daddr = xfrm_flowi_daddr(fl, family);
	saddr = xfrm_flowi_saddr(fl, family);
	if (unlikely(!daddr || !saddr))
		return NULL;

	rcu_read_lock();
 retry:
	/* Pick the hash bucket against a stable hash-generation snapshot. */
	do {
		sequence = read_seqcount_begin(&xfrm_policy_hash_generation);
		chain = policy_hash_direct(net, daddr, saddr, family, dir);
	} while (read_seqcount_retry(&xfrm_policy_hash_generation, sequence));

	/* First exact match wins; -ESRCH means "no match", other errors abort. */
	ret = NULL;
	hlist_for_each_entry_rcu(pol, chain, bydst) {
		err = xfrm_policy_match(pol, fl, type, family, dir, if_id);
		if (err) {
			if (err == -ESRCH)
				continue;
			else {
				ret = ERR_PTR(err);
				goto fail;
			}
		} else {
			ret = pol;
			break;
		}
	}
	/* Inexact policies may still be preferred over the exact match. */
	bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id);
	if (!bin || !xfrm_policy_find_inexact_candidates(&cand, bin, saddr,
							 daddr))
		goto skip_inexact;

	pol = xfrm_policy_eval_candidates(&cand, ret, fl, type,
					  family, dir, if_id);
	if (pol) {
		ret = pol;
		if (IS_ERR(pol))
			goto fail;
	}

skip_inexact:
	/* Hash generation moved during the walk (presumably a rehash): redo. */
	if (read_seqcount_retry(&xfrm_policy_hash_generation, sequence))
		goto retry;

	/* Reference could not be taken (policy likely being freed): redo. */
	if (ret && !xfrm_pol_hold_rcu(ret))
		goto retry;
fail:
	rcu_read_unlock();

	return ret;
}
2138
2139static struct xfrm_policy *xfrm_policy_lookup(struct net *net,
2140                                              const struct flowi *fl,
2141                                              u16 family, u8 dir, u32 if_id)
2142{
2143#ifdef CONFIG_XFRM_SUB_POLICY
2144        struct xfrm_policy *pol;
2145
2146        pol = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_SUB, fl, family,
2147                                        dir, if_id);
2148        if (pol != NULL)
2149                return pol;
2150#endif
2151        return xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN, fl, family,
2152                                         dir, if_id);
2153}
2154
2155static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir,
2156                                                 const struct flowi *fl,
2157                                                 u16 family, u32 if_id)
2158{
2159        struct xfrm_policy *pol;
2160
2161        rcu_read_lock();
2162 again:
2163        pol = rcu_dereference(sk->sk_policy[dir]);
2164        if (pol != NULL) {
2165                bool match;
2166                int err = 0;
2167
2168                if (pol->family != family) {
2169                        pol = NULL;
2170                        goto out;
2171                }
2172
2173                match = xfrm_selector_match(&pol->selector, fl, family);
2174                if (match) {
2175                        if ((sk->sk_mark & pol->mark.m) != pol->mark.v ||
2176                            pol->if_id != if_id) {
2177                                pol = NULL;
2178                                goto out;
2179                        }
2180                        err = security_xfrm_policy_lookup(pol->security,
2181                                                      fl->flowi_secid,
2182                                                      dir);
2183                        if (!err) {
2184                                if (!xfrm_pol_hold_rcu(pol))
2185                                        goto again;
2186                        } else if (err == -ESRCH) {
2187                                pol = NULL;
2188                        } else {
2189                                pol = ERR_PTR(err);
2190                        }
2191                } else
2192                        pol = NULL;
2193        }
2194out:
2195        rcu_read_unlock();
2196        return pol;
2197}
2198
2199static void __xfrm_policy_link(struct xfrm_policy *pol, int dir)
2200{
2201        struct net *net = xp_net(pol);
2202
2203        list_add(&pol->walk.all, &net->xfrm.policy_all);
2204        net->xfrm.policy_count[dir]++;
2205        xfrm_pol_hold(pol);
2206}
2207
2208static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol,
2209                                                int dir)
2210{
2211        struct net *net = xp_net(pol);
2212
2213        if (list_empty(&pol->walk.all))
2214                return NULL;
2215
2216        /* Socket policies are not hashed. */
2217        if (!hlist_unhashed(&pol->bydst)) {
2218                hlist_del_rcu(&pol->bydst);
2219                hlist_del_init(&pol->bydst_inexact_list);
2220                hlist_del(&pol->byidx);
2221        }
2222
2223        list_del_init(&pol->walk.all);
2224        net->xfrm.policy_count[dir]--;
2225
2226        return pol;
2227}
2228
2229static void xfrm_sk_policy_link(struct xfrm_policy *pol, int dir)
2230{
2231        __xfrm_policy_link(pol, XFRM_POLICY_MAX + dir);
2232}
2233
2234static void xfrm_sk_policy_unlink(struct xfrm_policy *pol, int dir)
2235{
2236        __xfrm_policy_unlink(pol, XFRM_POLICY_MAX + dir);
2237}
2238
2239int xfrm_policy_delete(struct xfrm_policy *pol, int dir)
2240{
2241        struct net *net = xp_net(pol);
2242
2243        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2244        pol = __xfrm_policy_unlink(pol, dir);
2245        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2246        if (pol) {
2247                xfrm_policy_kill(pol);
2248                return 0;
2249        }
2250        return -ENOENT;
2251}
2252EXPORT_SYMBOL(xfrm_policy_delete);
2253
2254int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol)
2255{
2256        struct net *net = sock_net(sk);
2257        struct xfrm_policy *old_pol;
2258
2259#ifdef CONFIG_XFRM_SUB_POLICY
2260        if (pol && pol->type != XFRM_POLICY_TYPE_MAIN)
2261                return -EINVAL;
2262#endif
2263
2264        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2265        old_pol = rcu_dereference_protected(sk->sk_policy[dir],
2266                                lockdep_is_held(&net->xfrm.xfrm_policy_lock));
2267        if (pol) {
2268                pol->curlft.add_time = ktime_get_real_seconds();
2269                pol->index = xfrm_gen_index(net, XFRM_POLICY_MAX+dir, 0);
2270                xfrm_sk_policy_link(pol, dir);
2271        }
2272        rcu_assign_pointer(sk->sk_policy[dir], pol);
2273        if (old_pol) {
2274                if (pol)
2275                        xfrm_policy_requeue(old_pol, pol);
2276
2277                /* Unlinking succeeds always. This is the only function
2278                 * allowed to delete or replace socket policy.
2279                 */
2280                xfrm_sk_policy_unlink(old_pol, dir);
2281        }
2282        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2283
2284        if (old_pol) {
2285                xfrm_policy_kill(old_pol);
2286        }
2287        return 0;
2288}
2289
2290static struct xfrm_policy *clone_policy(const struct xfrm_policy *old, int dir)
2291{
2292        struct xfrm_policy *newp = xfrm_policy_alloc(xp_net(old), GFP_ATOMIC);
2293        struct net *net = xp_net(old);
2294
2295        if (newp) {
2296                newp->selector = old->selector;
2297                if (security_xfrm_policy_clone(old->security,
2298                                               &newp->security)) {
2299                        kfree(newp);
2300                        return NULL;  /* ENOMEM */
2301                }
2302                newp->lft = old->lft;
2303                newp->curlft = old->curlft;
2304                newp->mark = old->mark;
2305                newp->if_id = old->if_id;
2306                newp->action = old->action;
2307                newp->flags = old->flags;
2308                newp->xfrm_nr = old->xfrm_nr;
2309                newp->index = old->index;
2310                newp->type = old->type;
2311                newp->family = old->family;
2312                memcpy(newp->xfrm_vec, old->xfrm_vec,
2313                       newp->xfrm_nr*sizeof(struct xfrm_tmpl));
2314                spin_lock_bh(&net->xfrm.xfrm_policy_lock);
2315                xfrm_sk_policy_link(newp, dir);
2316                spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
2317                xfrm_pol_put(newp);
2318        }
2319        return newp;
2320}
2321
2322int __xfrm_sk_clone_policy(struct sock *sk, const struct sock *osk)
2323{
2324        const struct xfrm_policy *p;
2325        struct xfrm_policy *np;
2326        int i, ret = 0;
2327
2328        rcu_read_lock();
2329        for (i = 0; i < 2; i++) {
2330                p = rcu_dereference(osk->sk_policy[i]);
2331                if (p) {
2332                        np = clone_policy(p, i);
2333                        if (unlikely(!np)) {
2334                                ret = -ENOMEM;
2335                                break;
2336                        }
2337                        rcu_assign_pointer(sk->sk_policy[i], np);
2338                }
2339        }
2340        rcu_read_unlock();
2341        return ret;
2342}
2343
2344static int
2345xfrm_get_saddr(struct net *net, int oif, xfrm_address_t *local,
2346               xfrm_address_t *remote, unsigned short family, u32 mark)
2347{
2348        int err;
2349        const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2350
2351        if (unlikely(afinfo == NULL))
2352                return -EINVAL;
2353        err = afinfo->get_saddr(net, oif, local, remote, mark);
2354        rcu_read_unlock();
2355        return err;
2356}
2357
2358/* Resolve list of templates for the flow, given policy. */
2359
2360static int
2361xfrm_tmpl_resolve_one(struct xfrm_policy *policy, const struct flowi *fl,
2362                      struct xfrm_state **xfrm, unsigned short family)
2363{
2364        struct net *net = xp_net(policy);
2365        int nx;
2366        int i, error;
2367        xfrm_address_t *daddr = xfrm_flowi_daddr(fl, family);
2368        xfrm_address_t *saddr = xfrm_flowi_saddr(fl, family);
2369        xfrm_address_t tmp;
2370
2371        for (nx = 0, i = 0; i < policy->xfrm_nr; i++) {
2372                struct xfrm_state *x;
2373                xfrm_address_t *remote = daddr;
2374                xfrm_address_t *local  = saddr;
2375                struct xfrm_tmpl *tmpl = &policy->xfrm_vec[i];
2376
2377                if (tmpl->mode == XFRM_MODE_TUNNEL ||
2378                    tmpl->mode == XFRM_MODE_BEET) {
2379                        remote = &tmpl->id.daddr;
2380                        local = &tmpl->saddr;
2381                        if (xfrm_addr_any(local, tmpl->encap_family)) {
2382                                error = xfrm_get_saddr(net, fl->flowi_oif,
2383                                                       &tmp, remote,
2384                                                       tmpl->encap_family, 0);
2385                                if (error)
2386                                        goto fail;
2387                                local = &tmp;
2388                        }
2389                }
2390
2391                x = xfrm_state_find(remote, local, fl, tmpl, policy, &error,
2392                                    family, policy->if_id);
2393
2394                if (x && x->km.state == XFRM_STATE_VALID) {
2395                        xfrm[nx++] = x;
2396                        daddr = remote;
2397                        saddr = local;
2398                        continue;
2399                }
2400                if (x) {
2401                        error = (x->km.state == XFRM_STATE_ERROR ?
2402                                 -EINVAL : -EAGAIN);
2403                        xfrm_state_put(x);
2404                } else if (error == -ESRCH) {
2405                        error = -EAGAIN;
2406                }
2407
2408                if (!tmpl->optional)
2409                        goto fail;
2410        }
2411        return nx;
2412
2413fail:
2414        for (nx--; nx >= 0; nx--)
2415                xfrm_state_put(xfrm[nx]);
2416        return error;
2417}
2418
2419static int
2420xfrm_tmpl_resolve(struct xfrm_policy **pols, int npols, const struct flowi *fl,
2421                  struct xfrm_state **xfrm, unsigned short family)
2422{
2423        struct xfrm_state *tp[XFRM_MAX_DEPTH];
2424        struct xfrm_state **tpp = (npols > 1) ? tp : xfrm;
2425        int cnx = 0;
2426        int error;
2427        int ret;
2428        int i;
2429
2430        for (i = 0; i < npols; i++) {
2431                if (cnx + pols[i]->xfrm_nr >= XFRM_MAX_DEPTH) {
2432                        error = -ENOBUFS;
2433                        goto fail;
2434                }
2435
2436                ret = xfrm_tmpl_resolve_one(pols[i], fl, &tpp[cnx], family);
2437                if (ret < 0) {
2438                        error = ret;
2439                        goto fail;
2440                } else
2441                        cnx += ret;
2442        }
2443
2444        /* found states are sorted for outbound processing */
2445        if (npols > 1)
2446                xfrm_state_sort(xfrm, tpp, cnx, family);
2447
2448        return cnx;
2449
2450 fail:
2451        for (cnx--; cnx >= 0; cnx--)
2452                xfrm_state_put(tpp[cnx]);
2453        return error;
2454
2455}
2456
2457static int xfrm_get_tos(const struct flowi *fl, int family)
2458{
2459        if (family == AF_INET)
2460                return IPTOS_RT_MASK & fl->u.ip4.flowi4_tos;
2461
2462        return 0;
2463}
2464
2465static inline struct xfrm_dst *xfrm_alloc_dst(struct net *net, int family)
2466{
2467        const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2468        struct dst_ops *dst_ops;
2469        struct xfrm_dst *xdst;
2470
2471        if (!afinfo)
2472                return ERR_PTR(-EINVAL);
2473
2474        switch (family) {
2475        case AF_INET:
2476                dst_ops = &net->xfrm.xfrm4_dst_ops;
2477                break;
2478#if IS_ENABLED(CONFIG_IPV6)
2479        case AF_INET6:
2480                dst_ops = &net->xfrm.xfrm6_dst_ops;
2481                break;
2482#endif
2483        default:
2484                BUG();
2485        }
2486        xdst = dst_alloc(dst_ops, NULL, 1, DST_OBSOLETE_NONE, 0);
2487
2488        if (likely(xdst)) {
2489                struct dst_entry *dst = &xdst->u.dst;
2490
2491                memset(dst + 1, 0, sizeof(*xdst) - sizeof(*dst));
2492        } else
2493                xdst = ERR_PTR(-ENOBUFS);
2494
2495        rcu_read_unlock();
2496
2497        return xdst;
2498}
2499
2500static void xfrm_init_path(struct xfrm_dst *path, struct dst_entry *dst,
2501                           int nfheader_len)
2502{
2503        if (dst->ops->family == AF_INET6) {
2504                struct rt6_info *rt = (struct rt6_info *)dst;
2505                path->path_cookie = rt6_get_cookie(rt);
2506                path->u.rt6.rt6i_nfheader_len = nfheader_len;
2507        }
2508}
2509
2510static inline int xfrm_fill_dst(struct xfrm_dst *xdst, struct net_device *dev,
2511                                const struct flowi *fl)
2512{
2513        const struct xfrm_policy_afinfo *afinfo =
2514                xfrm_policy_get_afinfo(xdst->u.dst.ops->family);
2515        int err;
2516
2517        if (!afinfo)
2518                return -EINVAL;
2519
2520        err = afinfo->fill_dst(xdst, dev, fl);
2521
2522        rcu_read_unlock();
2523
2524        return err;
2525}
2526
2527
/* Allocate a chain of dst_entries, attach the known xfrm states and
 * calculate all the metrics. In short, build a bundle.
 */
2531
2532static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy,
2533                                            struct xfrm_state **xfrm,
2534                                            struct xfrm_dst **bundle,
2535                                            int nx,
2536                                            const struct flowi *fl,
2537                                            struct dst_entry *dst)
2538{
2539        const struct xfrm_state_afinfo *afinfo;
2540        const struct xfrm_mode *inner_mode;
2541        struct net *net = xp_net(policy);
2542        unsigned long now = jiffies;
2543        struct net_device *dev;
2544        struct xfrm_dst *xdst_prev = NULL;
2545        struct xfrm_dst *xdst0 = NULL;
2546        int i = 0;
2547        int err;
2548        int header_len = 0;
2549        int nfheader_len = 0;
2550        int trailer_len = 0;
2551        int tos;
2552        int family = policy->selector.family;
2553        xfrm_address_t saddr, daddr;
2554
2555        xfrm_flowi_addr_get(fl, &saddr, &daddr, family);
2556
2557        tos = xfrm_get_tos(fl, family);
2558
2559        dst_hold(dst);
2560
2561        for (; i < nx; i++) {
2562                struct xfrm_dst *xdst = xfrm_alloc_dst(net, family);
2563                struct dst_entry *dst1 = &xdst->u.dst;
2564
2565                err = PTR_ERR(xdst);
2566                if (IS_ERR(xdst)) {
2567                        dst_release(dst);
2568                        goto put_states;
2569                }
2570
2571                bundle[i] = xdst;
2572                if (!xdst_prev)
2573                        xdst0 = xdst;
2574                else
2575                        /* Ref count is taken during xfrm_alloc_dst()
2576                         * No need to do dst_clone() on dst1
2577                         */
2578                        xfrm_dst_set_child(xdst_prev, &xdst->u.dst);
2579
2580                if (xfrm[i]->sel.family == AF_UNSPEC) {
2581                        inner_mode = xfrm_ip2inner_mode(xfrm[i],
2582                                                        xfrm_af2proto(family));
2583                        if (!inner_mode) {
2584                                err = -EAFNOSUPPORT;
2585                                dst_release(dst);
2586                                goto put_states;
2587                        }
2588                } else
2589                        inner_mode = &xfrm[i]->inner_mode;
2590
2591                xdst->route = dst;
2592                dst_copy_metrics(dst1, dst);
2593
2594                if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) {
2595                        __u32 mark = 0;
2596
2597                        if (xfrm[i]->props.smark.v || xfrm[i]->props.smark.m)
2598                                mark = xfrm_smark_get(fl->flowi_mark, xfrm[i]);
2599
2600                        family = xfrm[i]->props.family;
2601                        dst = xfrm_dst_lookup(xfrm[i], tos, fl->flowi_oif,
2602                                              &saddr, &daddr, family, mark);
2603                        err = PTR_ERR(dst);
2604                        if (IS_ERR(dst))
2605                                goto put_states;
2606                } else
2607                        dst_hold(dst);
2608
2609                dst1->xfrm = xfrm[i];
2610                xdst->xfrm_genid = xfrm[i]->genid;
2611
2612                dst1->obsolete = DST_OBSOLETE_FORCE_CHK;
2613                dst1->flags |= DST_HOST;
2614                dst1->lastuse = now;
2615
2616                dst1->input = dst_discard;
2617
2618                rcu_read_lock();
2619                afinfo = xfrm_state_afinfo_get_rcu(inner_mode->family);
2620                if (likely(afinfo))
2621                        dst1->output = afinfo->output;
2622                else
2623                        dst1->output = dst_discard_out;
2624                rcu_read_unlock();
2625
2626                xdst_prev = xdst;
2627
2628                header_len += xfrm[i]->props.header_len;
2629                if (xfrm[i]->type->flags & XFRM_TYPE_NON_FRAGMENT)
2630                        nfheader_len += xfrm[i]->props.header_len;
2631                trailer_len += xfrm[i]->props.trailer_len;
2632        }
2633
2634        xfrm_dst_set_child(xdst_prev, dst);
2635        xdst0->path = dst;
2636
2637        err = -ENODEV;
2638        dev = dst->dev;
2639        if (!dev)
2640                goto free_dst;
2641
2642        xfrm_init_path(xdst0, dst, nfheader_len);
2643        xfrm_init_pmtu(bundle, nx);
2644
2645        for (xdst_prev = xdst0; xdst_prev != (struct xfrm_dst *)dst;
2646             xdst_prev = (struct xfrm_dst *) xfrm_dst_child(&xdst_prev->u.dst)) {
2647                err = xfrm_fill_dst(xdst_prev, dev, fl);
2648                if (err)
2649                        goto free_dst;
2650
2651                xdst_prev->u.dst.header_len = header_len;
2652                xdst_prev->u.dst.trailer_len = trailer_len;
2653                header_len -= xdst_prev->u.dst.xfrm->props.header_len;
2654                trailer_len -= xdst_prev->u.dst.xfrm->props.trailer_len;
2655        }
2656
2657        return &xdst0->u.dst;
2658
2659put_states:
2660        for (; i < nx; i++)
2661                xfrm_state_put(xfrm[i]);
2662free_dst:
2663        if (xdst0)
2664                dst_release_immediate(&xdst0->u.dst);
2665
2666        return ERR_PTR(err);
2667}
2668
2669static int xfrm_expand_policies(const struct flowi *fl, u16 family,
2670                                struct xfrm_policy **pols,
2671                                int *num_pols, int *num_xfrms)
2672{
2673        int i;
2674
2675        if (*num_pols == 0 || !pols[0]) {
2676                *num_pols = 0;
2677                *num_xfrms = 0;
2678                return 0;
2679        }
2680        if (IS_ERR(pols[0]))
2681                return PTR_ERR(pols[0]);
2682
2683        *num_xfrms = pols[0]->xfrm_nr;
2684
2685#ifdef CONFIG_XFRM_SUB_POLICY
2686        if (pols[0] && pols[0]->action == XFRM_POLICY_ALLOW &&
2687            pols[0]->type != XFRM_POLICY_TYPE_MAIN) {
2688                pols[1] = xfrm_policy_lookup_bytype(xp_net(pols[0]),
2689                                                    XFRM_POLICY_TYPE_MAIN,
2690                                                    fl, family,
2691                                                    XFRM_POLICY_OUT,
2692                                                    pols[0]->if_id);
2693                if (pols[1]) {
2694                        if (IS_ERR(pols[1])) {
2695                                xfrm_pols_put(pols, *num_pols);
2696                                return PTR_ERR(pols[1]);
2697                        }
2698                        (*num_pols)++;
2699                        (*num_xfrms) += pols[1]->xfrm_nr;
2700                }
2701        }
2702#endif
2703        for (i = 0; i < *num_pols; i++) {
2704                if (pols[i]->action != XFRM_POLICY_ALLOW) {
2705                        *num_xfrms = -1;
2706                        break;
2707                }
2708        }
2709
2710        return 0;
2711
2712}
2713
2714static struct xfrm_dst *
2715xfrm_resolve_and_create_bundle(struct xfrm_policy **pols, int num_pols,
2716                               const struct flowi *fl, u16 family,
2717                               struct dst_entry *dst_orig)
2718{
2719        struct net *net = xp_net(pols[0]);
2720        struct xfrm_state *xfrm[XFRM_MAX_DEPTH];
2721        struct xfrm_dst *bundle[XFRM_MAX_DEPTH];
2722        struct xfrm_dst *xdst;
2723        struct dst_entry *dst;
2724        int err;
2725
2726        /* Try to instantiate a bundle */
2727        err = xfrm_tmpl_resolve(pols, num_pols, fl, xfrm, family);
2728        if (err <= 0) {
2729                if (err == 0)
2730                        return NULL;
2731
2732                if (err != -EAGAIN)
2733                        XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR);
2734                return ERR_PTR(err);
2735        }
2736
2737        dst = xfrm_bundle_create(pols[0], xfrm, bundle, err, fl, dst_orig);
2738        if (IS_ERR(dst)) {
2739                XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTBUNDLEGENERROR);
2740                return ERR_CAST(dst);
2741        }
2742
2743        xdst = (struct xfrm_dst *)dst;
2744        xdst->num_xfrms = err;
2745        xdst->num_pols = num_pols;
2746        memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols);
2747        xdst->policy_genid = atomic_read(&pols[0]->genid);
2748
2749        return xdst;
2750}
2751
/* Timer callback for a policy's hold queue (pol->polq.hold_timer).
 *
 * Flow of one run:
 * - Re-run the lookup for the first queued packet with XFRM_LOOKUP_QUEUE.
 * - If that lookup fails, purge the queue.
 * - If it still yields a queueing dummy bundle (DST_XFRM_QUEUE), re-arm
 *   the timer with a doubled timeout.  Once pq->timeout has reached
 *   XFRM_QUEUE_TMO_MAX, purge the queue instead.
 * - Otherwise take every queued packet off the queue, look it up again
 *   and send it with dst_output().  Packets whose lookup fails are
 *   freed.
 *
 * Every exit path ends in xfrm_pol_put(pol), balancing the xfrm_pol_hold()
 * taken whenever mod_timer() arms an inactive hold_timer.
 */
static void xfrm_policy_queue_process(struct timer_list *t)
{
        struct sk_buff *skb;
        struct sock *sk;
        struct dst_entry *dst;
        struct xfrm_policy *pol = from_timer(pol, t, polq.hold_timer);
        struct net *net = xp_net(pol);
        struct xfrm_policy_queue *pq = &pol->polq;
        struct flowi fl;
        struct sk_buff_head list;

        /* Peek at (do not dequeue) the oldest packet and decode its flow
         * while the queue lock is held.
         */
        spin_lock(&pq->hold_queue.lock);
        skb = skb_peek(&pq->hold_queue);
        if (!skb) {
                spin_unlock(&pq->hold_queue.lock);
                goto out;
        }
        dst = skb_dst(skb);
        sk = skb->sk;
        xfrm_decode_session(skb, &fl, dst->ops->family);
        spin_unlock(&pq->hold_queue.lock);

        /* xfrm_lookup() consumes a reference on the route it is given. */
        dst_hold(xfrm_dst_path(dst));
        dst = xfrm_lookup(net, xfrm_dst_path(dst), &fl, sk, XFRM_LOOKUP_QUEUE);
        if (IS_ERR(dst))
                goto purge_queue;

        if (dst->flags & DST_XFRM_QUEUE) {
                /* Still no states: back off exponentially, up to a limit. */
                dst_release(dst);

                if (pq->timeout >= XFRM_QUEUE_TMO_MAX)
                        goto purge_queue;

                pq->timeout = pq->timeout << 1;
                if (!mod_timer(&pq->hold_timer, jiffies + pq->timeout))
                        xfrm_pol_hold(pol);
                goto out;
        }

        dst_release(dst);

        /* A real bundle exists now: move all held packets to a private
         * list so they can be sent without holding the queue lock.
         */
        __skb_queue_head_init(&list);

        spin_lock(&pq->hold_queue.lock);
        pq->timeout = 0;
        skb_queue_splice_init(&pq->hold_queue, &list);
        spin_unlock(&pq->hold_queue.lock);

        while (!skb_queue_empty(&list)) {
                skb = __skb_dequeue(&list);

                xfrm_decode_session(skb, &fl, skb_dst(skb)->ops->family);
                dst_hold(xfrm_dst_path(skb_dst(skb)));
                dst = xfrm_lookup(net, xfrm_dst_path(skb_dst(skb)), &fl, skb->sk, 0);
                if (IS_ERR(dst)) {
                        kfree_skb(skb);
                        continue;
                }

                /* Swap the queueing dst for the freshly looked-up one. */
                nf_reset(skb);
                skb_dst_drop(skb);
                skb_dst_set(skb, dst);

                dst_output(net, skb->sk, skb);
        }

out:
        xfrm_pol_put(pol);
        return;

purge_queue:
        pq->timeout = 0;
        skb_queue_purge(&pq->hold_queue);
        xfrm_pol_put(pol);
}
2827
2828static int xdst_queue_output(struct net *net, struct sock *sk, struct sk_buff *skb)
2829{
2830        unsigned long sched_next;
2831        struct dst_entry *dst = skb_dst(skb);
2832        struct xfrm_dst *xdst = (struct xfrm_dst *) dst;
2833        struct xfrm_policy *pol = xdst->pols[0];
2834        struct xfrm_policy_queue *pq = &pol->polq;
2835
2836        if (unlikely(skb_fclone_busy(sk, skb))) {
2837                kfree_skb(skb);
2838                return 0;
2839        }
2840
2841        if (pq->hold_queue.qlen > XFRM_MAX_QUEUE_LEN) {
2842                kfree_skb(skb);
2843                return -EAGAIN;
2844        }
2845
2846        skb_dst_force(skb);
2847
2848        spin_lock_bh(&pq->hold_queue.lock);
2849
2850        if (!pq->timeout)
2851                pq->timeout = XFRM_QUEUE_TMO_MIN;
2852
2853        sched_next = jiffies + pq->timeout;
2854
2855        if (del_timer(&pq->hold_timer)) {
2856                if (time_before(pq->hold_timer.expires, sched_next))
2857                        sched_next = pq->hold_timer.expires;
2858                xfrm_pol_put(pol);
2859        }
2860
2861        __skb_queue_tail(&pq->hold_queue, skb);
2862        if (!mod_timer(&pq->hold_timer, sched_next))
2863                xfrm_pol_hold(pol);
2864
2865        spin_unlock_bh(&pq->hold_queue.lock);
2866
2867        return 0;
2868}
2869
2870static struct xfrm_dst *xfrm_create_dummy_bundle(struct net *net,
2871                                                 struct xfrm_flo *xflo,
2872                                                 const struct flowi *fl,
2873                                                 int num_xfrms,
2874                                                 u16 family)
2875{
2876        int err;
2877        struct net_device *dev;
2878        struct dst_entry *dst;
2879        struct dst_entry *dst1;
2880        struct xfrm_dst *xdst;
2881
2882        xdst = xfrm_alloc_dst(net, family);
2883        if (IS_ERR(xdst))
2884                return xdst;
2885
2886        if (!(xflo->flags & XFRM_LOOKUP_QUEUE) ||
2887            net->xfrm.sysctl_larval_drop ||
2888            num_xfrms <= 0)
2889                return xdst;
2890
2891        dst = xflo->dst_orig;
2892        dst1 = &xdst->u.dst;
2893        dst_hold(dst);
2894        xdst->route = dst;
2895
2896        dst_copy_metrics(dst1, dst);
2897
2898        dst1->obsolete = DST_OBSOLETE_FORCE_CHK;
2899        dst1->flags |= DST_HOST | DST_XFRM_QUEUE;
2900        dst1->lastuse = jiffies;
2901
2902        dst1->input = dst_discard;
2903        dst1->output = xdst_queue_output;
2904
2905        dst_hold(dst);
2906        xfrm_dst_set_child(xdst, dst);
2907        xdst->path = dst;
2908
2909        xfrm_init_path((struct xfrm_dst *)dst1, dst, 0);
2910
2911        err = -ENODEV;
2912        dev = dst->dev;
2913        if (!dev)
2914                goto free_dst;
2915
2916        err = xfrm_fill_dst(xdst, dev, fl);
2917        if (err)
2918                goto free_dst;
2919
2920out:
2921        return xdst;
2922
2923free_dst:
2924        dst_release(dst1);
2925        xdst = ERR_PTR(err);
2926        goto out;
2927}
2928
2929static struct xfrm_dst *xfrm_bundle_lookup(struct net *net,
2930                                           const struct flowi *fl,
2931                                           u16 family, u8 dir,
2932                                           struct xfrm_flo *xflo, u32 if_id)
2933{
2934        struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
2935        int num_pols = 0, num_xfrms = 0, err;
2936        struct xfrm_dst *xdst;
2937
2938        /* Resolve policies to use if we couldn't get them from
2939         * previous cache entry */
2940        num_pols = 1;
2941        pols[0] = xfrm_policy_lookup(net, fl, family, dir, if_id);
2942        err = xfrm_expand_policies(fl, family, pols,
2943                                           &num_pols, &num_xfrms);
2944        if (err < 0)
2945                goto inc_error;
2946        if (num_pols == 0)
2947                return NULL;
2948        if (num_xfrms <= 0)
2949                goto make_dummy_bundle;
2950
2951        xdst = xfrm_resolve_and_create_bundle(pols, num_pols, fl, family,
2952                                              xflo->dst_orig);
2953        if (IS_ERR(xdst)) {
2954                err = PTR_ERR(xdst);
2955                if (err == -EREMOTE) {
2956                        xfrm_pols_put(pols, num_pols);
2957                        return NULL;
2958                }
2959
2960                if (err != -EAGAIN)
2961                        goto error;
2962                goto make_dummy_bundle;
2963        } else if (xdst == NULL) {
2964                num_xfrms = 0;
2965                goto make_dummy_bundle;
2966        }
2967
2968        return xdst;
2969
2970make_dummy_bundle:
2971        /* We found policies, but there's no bundles to instantiate:
2972         * either because the policy blocks, has no transformations or
2973         * we could not build template (no xfrm_states).*/
2974        xdst = xfrm_create_dummy_bundle(net, xflo, fl, num_xfrms, family);
2975        if (IS_ERR(xdst)) {
2976                xfrm_pols_put(pols, num_pols);
2977                return ERR_CAST(xdst);
2978        }
2979        xdst->num_pols = num_pols;
2980        xdst->num_xfrms = num_xfrms;
2981        memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols);
2982
2983        return xdst;
2984
2985inc_error:
2986        XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR);
2987error:
2988        xfrm_pols_put(pols, num_pols);
2989        return ERR_PTR(err);
2990}
2991
2992static struct dst_entry *make_blackhole(struct net *net, u16 family,
2993                                        struct dst_entry *dst_orig)
2994{
2995        const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family);
2996        struct dst_entry *ret;
2997
2998        if (!afinfo) {
2999                dst_release(dst_orig);
3000                return ERR_PTR(-EINVAL);
3001        } else {
3002                ret = afinfo->blackhole_route(net, dst_orig);
3003        }
3004        rcu_read_unlock();
3005
3006        return ret;
3007}
3008
/* Finds/creates a bundle for the given flow and if_id.
 *
 * At the moment we eat a raw IP route, mostly to speed up lookups on
 * interfaces with IPsec disabled.
 *
 * xfrm_lookup() uses an if_id of 0 by default and is provided for
 * compatibility.
 *
 * Handling of the caller's reference on @dst_orig:
 *  - flow transformed:            released, the bundle is returned;
 *  - flow untransformed/no policy: @dst_orig itself is returned;
 *  - error:                       released unless
 *                                 XFRM_LOOKUP_KEEP_DST_REF is set.
 *
 * Returns the dst to use, or an ERR_PTR():
 *   -EPERM    a policy blocks the flow;
 *   -ENOENT   XFRM_LOOKUP_ICMP was requested but there is no policy, or
 *             pols[0] lacks XFRM_POLICY_ICMP;
 *   -EREMOTE  states are missing and sysctl_larval_drop is set;
 *   -EAGAIN   states are missing (larval_drop off);
 *   or an error from policy resolution / bundle creation.
 */
struct dst_entry *xfrm_lookup_with_ifid(struct net *net,
                                        struct dst_entry *dst_orig,
                                        const struct flowi *fl,
                                        const struct sock *sk,
                                        int flags, u32 if_id)
{
        struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
        struct xfrm_dst *xdst;
        struct dst_entry *dst, *route;
        u16 family = dst_orig->ops->family;
        u8 dir = XFRM_POLICY_OUT;
        /* drop_pols: how many pols[] references this function must put
         * itself.  It is only set on the socket-policy path when no bundle
         * took over the policies.
         */
        int i, err, num_pols, num_xfrms = 0, drop_pols = 0;

        dst = NULL;
        xdst = NULL;
        route = NULL;

        /* A per-socket outbound policy, if present, is tried first. */
        sk = sk_const_to_full_sk(sk);
        if (sk && sk->sk_policy[XFRM_POLICY_OUT]) {
                num_pols = 1;
                pols[0] = xfrm_sk_policy_lookup(sk, XFRM_POLICY_OUT, fl, family,
                                                if_id);
                err = xfrm_expand_policies(fl, family, pols,
                                           &num_pols, &num_xfrms);
                if (err < 0)
                        goto dropdst;

                if (num_pols) {
                        if (num_xfrms <= 0) {
                                drop_pols = num_pols;
                                goto no_transform;
                        }

                        xdst = xfrm_resolve_and_create_bundle(
                                        pols, num_pols, fl,
                                        family, dst_orig);

                        if (IS_ERR(xdst)) {
                                xfrm_pols_put(pols, num_pols);
                                err = PTR_ERR(xdst);
                                /* -EREMOTE: treat as "no policy". */
                                if (err == -EREMOTE)
                                        goto nopol;

                                goto dropdst;
                        } else if (xdst == NULL) {
                                num_xfrms = 0;
                                drop_pols = num_pols;
                                goto no_transform;
                        }

                        route = xdst->route;
                }
        }

        /* No bundle from a socket policy: do the per-netns lookup. */
        if (xdst == NULL) {
                struct xfrm_flo xflo;

                xflo.dst_orig = dst_orig;
                xflo.flags = flags;

                /* To accelerate a bit: skip routes marked DST_NOXFRM, and
                 * netns without any outbound policy.
                 */
                if ((dst_orig->flags & DST_NOXFRM) ||
                    !net->xfrm.policy_count[XFRM_POLICY_OUT])
                        goto nopol;

                xdst = xfrm_bundle_lookup(net, fl, family, dir, &xflo, if_id);
                if (xdst == NULL)
                        goto nopol;
                if (IS_ERR(xdst)) {
                        err = PTR_ERR(xdst);
                        goto dropdst;
                }

                /* The bundle keeps the policy references in xdst->pols;
                 * only the pointers are copied here (drop_pols stays 0).
                 */
                num_pols = xdst->num_pols;
                num_xfrms = xdst->num_xfrms;
                memcpy(pols, xdst->pols, sizeof(struct xfrm_policy *) * num_pols);
                route = xdst->route;
        }

        dst = &xdst->u.dst;
        if (route == NULL && num_xfrms > 0) {
                /* The only case when xfrm_bundle_lookup() returns a
                 * bundle with null route, is when the template could
                 * not be resolved. It means policies are there, but
                 * bundle could not be created, since we don't yet
                 * have the xfrm_state's. We need to wait for KM to
                 * negotiate new SA's or bail out with error.*/
                if (net->xfrm.sysctl_larval_drop) {
                        XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES);
                        err = -EREMOTE;
                        goto error;
                }

                err = -EAGAIN;

                XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES);
                goto error;
        }

no_transform:
        if (num_pols == 0)
                goto nopol;

        if ((flags & XFRM_LOOKUP_ICMP) &&
            !(pols[0]->flags & XFRM_POLICY_ICMP)) {
                err = -ENOENT;
                goto error;
        }

        for (i = 0; i < num_pols; i++)
                pols[i]->curlft.use_time = ktime_get_real_seconds();

        if (num_xfrms < 0) {
                /* Prohibit the flow */
                XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLBLOCK);
                err = -EPERM;
                goto error;
        } else if (num_xfrms > 0) {
                /* Flow transformed */
                dst_release(dst_orig);
        } else {
                /* Flow passes untransformed */
                dst_release(dst);
                dst = dst_orig;
        }
ok:
        xfrm_pols_put(pols, drop_pols);
        /* Flag dsts whose first state is in tunnel mode. */
        if (dst && dst->xfrm &&
            dst->xfrm->props.mode == XFRM_MODE_TUNNEL)
                dst->flags |= DST_XFRM_TUNNEL;
        return dst;

nopol:
        if (!(flags & XFRM_LOOKUP_ICMP)) {
                dst = dst_orig;
                goto ok;
        }
        err = -ENOENT;
error:
        /* dst is still NULL if no bundle was obtained; dst_release()
         * tolerates NULL.
         */
        dst_release(dst);
dropdst:
        if (!(flags & XFRM_LOOKUP_KEEP_DST_REF))
                dst_release(dst_orig);
        xfrm_pols_put(pols, drop_pols);
        return ERR_PTR(err);
}
EXPORT_SYMBOL(xfrm_lookup_with_ifid);
3164
3165/* Main function: finds/creates a bundle for given flow.
3166 *
3167 * At the moment we eat a raw IP route. Mostly to speed up lookups
3168 * on interfaces with disabled IPsec.
3169 */
3170struct dst_entry *xfrm_lookup(struct net *net, struct dst_entry *dst_orig,
3171                              const struct flowi *fl, const struct sock *sk,
3172                              int flags)
3173{
3174        return xfrm_lookup_with_ifid(net, dst_orig, fl, sk, flags, 0);
3175}
3176EXPORT_SYMBOL(xfrm_lookup);
3177
3178/* Callers of xfrm_lookup_route() must ensure a call to dst_output().
3179 * Otherwise we may send out blackholed packets.
3180 */
3181struct dst_entry *xfrm_lookup_route(struct net *net, struct dst_entry *dst_orig,
3182                                    const struct flowi *fl,
3183                                    const struct sock *sk, int flags)
3184{
3185        struct dst_entry *dst = xfrm_lookup(net, dst_orig, fl, sk,
3186                                            flags | XFRM_LOOKUP_QUEUE |
3187                                            XFRM_LOOKUP_KEEP_DST_REF);
3188
3189        if (IS_ERR(dst) && PTR_ERR(dst) == -EREMOTE)
3190                return make_blackhole(net, dst_orig->ops->family, dst_orig);
3191
3192        if (IS_ERR(dst))
3193                dst_release(dst_orig);
3194
3195        return dst;
3196}
3197EXPORT_SYMBOL(xfrm_lookup_route);
3198
3199static inline int
3200xfrm_secpath_reject(int idx, struct sk_buff *skb, const struct flowi *fl)
3201{
3202        struct sec_path *sp = skb_sec_path(skb);
3203        struct xfrm_state *x;
3204
3205        if (!sp || idx < 0 || idx >= sp->len)
3206                return 0;
3207        x = sp->xvec[idx];
3208        if (!x->type->reject)
3209                return 0;
3210        return x->type->reject(x, skb, fl);
3211}
3212
3213/* When skb is transformed back to its "native" form, we have to
3214 * check policy restrictions. At the moment we make this in maximally
3215 * stupid way. Shame on me. :-) Of course, connected sockets must
3216 * have policy cached at them.
3217 */
3218
3219static inline int
3220xfrm_state_ok(const struct xfrm_tmpl *tmpl, const struct xfrm_state *x,
3221              unsigned short family)
3222{
3223        if (xfrm_state_kern(x))
3224                return tmpl->optional && !xfrm_state_addr_cmp(tmpl, x, tmpl->encap_family);
3225        return  x->id.proto == tmpl->id.proto &&
3226                (x->id.spi == tmpl->id.spi || !tmpl->id.spi) &&
3227                (x->props.reqid == tmpl->reqid || !tmpl->reqid) &&
3228                x->props.mode == tmpl->mode &&
3229                (tmpl->allalgs || (tmpl->aalgos & (1<<x->props.aalgo)) ||
3230                 !(xfrm_id_proto_match(tmpl->id.proto, IPSEC_PROTO_ANY))) &&
3231                !(x->props.mode != XFRM_MODE_TRANSPORT &&
3232                  xfrm_state_addr_cmp(tmpl, x, family));
3233}
3234
3235/*
3236 * 0 or more than 0 is returned when validation is succeeded (either bypass
3237 * because of optional transport mode, or next index of the mathced secpath
3238 * state with the template.
3239 * -1 is returned when no matching template is found.
3240 * Otherwise "-2 - errored_index" is returned.
3241 */
3242static inline int
3243xfrm_policy_ok(const struct xfrm_tmpl *tmpl, const struct sec_path *sp, int start,
3244               unsigned short family)
3245{
3246        int idx = start;
3247
3248        if (tmpl->optional) {
3249                if (tmpl->mode == XFRM_MODE_TRANSPORT)
3250                        return start;
3251        } else
3252                start = -1;
3253        for (; idx < sp->len; idx++) {
3254                if (xfrm_state_ok(tmpl, sp->xvec[idx], family))
3255                        return ++idx;
3256                if (sp->xvec[idx]->props.mode != XFRM_MODE_TRANSPORT) {
3257                        if (start == -1)
3258                                start = -2-idx;
3259                        break;
3260                }
3261        }
3262        return start;
3263}
3264
/* Fill the IPv4 part of @fl (fl->u.ip4) from the packet in @skb.
 *
 * The flow is zeroed first.  The mark, protocol, TOS and addresses are
 * then copied from the skb and IP header.  With @reverse set, source and
 * destination (addresses and ports) are swapped, and the oif is taken from
 * skb->skb_iif instead of the dst's device.
 *
 * Transport-layer fields (ports, ICMP type/code, IPsec SPI, GRE key) are
 * decoded only for non-fragmented packets.  They are decoded only if the
 * needed bytes are already before skb->data or can be made linear with
 * pskb_may_pull().
 */
static void
decode_session4(struct sk_buff *skb, struct flowi *fl, bool reverse)
{
        const struct iphdr *iph = ip_hdr(skb);
        int ihl = iph->ihl;
        u8 *xprth = skb_network_header(skb) + ihl * 4;
        struct flowi4 *fl4 = &fl->u.ip4;
        int oif = 0;

        if (skb_dst(skb) && skb_dst(skb)->dev)
                oif = skb_dst(skb)->dev->ifindex;

        memset(fl4, 0, sizeof(struct flowi4));
        fl4->flowi4_mark = skb->mark;
        fl4->flowi4_oif = reverse ? skb->skb_iif : oif;

        fl4->flowi4_proto = iph->protocol;
        fl4->daddr = reverse ? iph->saddr : iph->daddr;
        fl4->saddr = reverse ? iph->daddr : iph->saddr;
        fl4->flowi4_tos = iph->tos;

        /* In every case below, pskb_may_pull() may reallocate the header,
         * so the transport pointer is recomputed from
         * skb_network_header() after the length check.
         */
        if (!ip_is_fragment(iph)) {
                switch (iph->protocol) {
                case IPPROTO_UDP:
                case IPPROTO_UDPLITE:
                case IPPROTO_TCP:
                case IPPROTO_SCTP:
                case IPPROTO_DCCP:
                        /* Source and destination ports: first 4 bytes. */
                        if (xprth + 4 < skb->data ||
                            pskb_may_pull(skb, xprth + 4 - skb->data)) {
                                __be16 *ports;

                                xprth = skb_network_header(skb) + ihl * 4;
                                ports = (__be16 *)xprth;

                                fl4->fl4_sport = ports[!!reverse];
                                fl4->fl4_dport = ports[!reverse];
                        }
                        break;
                case IPPROTO_ICMP:
                        /* ICMP type and code: first 2 bytes. */
                        if (xprth + 2 < skb->data ||
                            pskb_may_pull(skb, xprth + 2 - skb->data)) {
                                u8 *icmp;

                                xprth = skb_network_header(skb) + ihl * 4;
                                icmp = xprth;

                                fl4->fl4_icmp_type = icmp[0];
                                fl4->fl4_icmp_code = icmp[1];
                        }
                        break;
                case IPPROTO_ESP:
                        /* ESP: the SPI is the first 32-bit word. */
                        if (xprth + 4 < skb->data ||
                            pskb_may_pull(skb, xprth + 4 - skb->data)) {
                                __be32 *ehdr;

                                xprth = skb_network_header(skb) + ihl * 4;
                                ehdr = (__be32 *)xprth;

                                fl4->fl4_ipsec_spi = ehdr[0];
                        }
                        break;
                case IPPROTO_AH:
                        /* AH: the SPI is the second 32-bit word. */
                        if (xprth + 8 < skb->data ||
                            pskb_may_pull(skb, xprth + 8 - skb->data)) {
                                __be32 *ah_hdr;

                                xprth = skb_network_header(skb) + ihl * 4;
                                ah_hdr = (__be32 *)xprth;

                                fl4->fl4_ipsec_spi = ah_hdr[1];
                        }
                        break;
                case IPPROTO_COMP:
                        /* IPComp: the 16-bit CPI (second half-word) is
                         * widened into the 32-bit SPI field.
                         */
                        if (xprth + 4 < skb->data ||
                            pskb_may_pull(skb, xprth + 4 - skb->data)) {
                                __be16 *ipcomp_hdr;

                                xprth = skb_network_header(skb) + ihl * 4;
                                ipcomp_hdr = (__be16 *)xprth;

                                fl4->fl4_ipsec_spi = htonl(ntohs(ipcomp_hdr[1]));
                        }
                        break;
                case IPPROTO_GRE:
                        /* GRE: the key follows the flags/protocol word.
                         * When GRE_CSUM is set, a checksum word precedes
                         * it, hence the 12-byte pull.
                         */
                        if (xprth + 12 < skb->data ||
                            pskb_may_pull(skb, xprth + 12 - skb->data)) {
                                __be16 *greflags;
                                __be32 *gre_hdr;

                                xprth = skb_network_header(skb) + ihl * 4;
                                greflags = (__be16 *)xprth;
                                gre_hdr = (__be32 *)xprth;

                                if (greflags[0] & GRE_KEY) {
                                        if (greflags[0] & GRE_CSUM)
                                                gre_hdr++;
                                        fl4->fl4_gre_key = gre_hdr[1];
                                }
                        }
                        break;
                default:
                        /* Already zero from the memset() above. */
                        fl4->fl4_ipsec_spi = 0;
                        break;
                }
        }
}
3372
3373#if IS_ENABLED(CONFIG_IPV6)
3374static void
3375decode_session6(struct sk_buff *skb, struct flowi *fl, bool reverse)
3376{
3377        struct flowi6 *fl6 = &fl->u.ip6;
3378        int onlyproto = 0;
3379        const struct ipv6hdr *hdr = ipv6_hdr(skb);
3380        u32 offset = sizeof(*hdr);
3381        struct ipv6_opt_hdr *exthdr;
3382        const unsigned char *nh = skb_network_header(skb);
3383        u16 nhoff = IP6CB(skb)->nhoff;
3384        int oif = 0;
3385        u8 nexthdr;
3386
3387        if (!nhoff)
3388                nhoff = offsetof(struct ipv6hdr, nexthdr);
3389
3390        nexthdr = nh[nhoff];
3391
3392        if (skb_dst(skb) && skb_dst(skb)->dev)
3393                oif = skb_dst(skb)->dev->ifindex;
3394
3395        memset(fl6, 0, sizeof(struct flowi6));
3396        fl6->flowi6_mark = skb->mark;
3397        fl6->flowi6_oif = reverse ? skb->skb_iif : oif;
3398
3399        fl6->daddr = reverse ? hdr->saddr : hdr->daddr;
3400        fl6->saddr = reverse ? hdr->daddr : hdr->saddr;
3401
3402        while (nh + offset + sizeof(*exthdr) < skb->data ||
3403               pskb_may_pull(skb, nh + offset + sizeof(*exthdr) - skb->data)) {
3404                nh = skb_network_header(skb);
3405                exthdr = (struct ipv6_opt_hdr *)(nh + offset);
3406
3407                switch (nexthdr) {
3408                case NEXTHDR_FRAGMENT:
3409                        onlyproto = 1;
3410                        /* fall through */
3411                case NEXTHDR_ROUTING:
3412                case NEXTHDR_HOP:
3413                case NEXTHDR_DEST:
3414                        offset += ipv6_optlen(exthdr);
3415                        nexthdr = exthdr->nexthdr;
3416                        exthdr = (struct ipv6_opt_hdr *)(nh + offset);
3417                        break;
3418                case IPPROTO_UDP:
3419                case IPPROTO_UDPLITE:
3420                case IPPROTO_TCP:
3421                case IPPROTO_SCTP:
3422                case IPPROTO_DCCP:
3423                        if (!onlyproto && (nh + offset + 4 < skb->data ||
3424                             pskb_may_pull(skb, nh + offset + 4 - skb->data))) {
3425                                __be16 *ports;
3426
3427                                nh = skb_network_header(skb);
3428                                ports = (__be16 *)(nh + offset);
3429                                fl6->fl6_sport = ports[!!reverse];
3430                                fl6->fl6_dport = ports[!reverse];
3431                        }
3432                        fl6->flowi6_proto = nexthdr;
3433                        return;
3434                case IPPROTO_ICMPV6:
3435                        if (!onlyproto && (nh + offset + 2 < skb->data ||
3436                            pskb_may_pull(skb, nh + offset + 2 - skb->data))) {
3437                                u8 *icmp;
3438
3439                                nh = skb_network_header(skb);
3440                                icmp = (u8 *)(nh + offset);
3441                                fl6->fl6_icmp_type = icmp[0];
3442                                fl6->fl6_icmp_code = icmp[1];
3443                        }
3444                        fl6->flowi6_proto = nexthdr;
3445                        return;
3446#if IS_ENABLED(CONFIG_IPV6_MIP6)
3447                case IPPROTO_MH:
3448                        offset += ipv6_optlen(exthdr);
3449                        if (!onlyproto && (nh + offset + 3 < skb->data ||
3450                            pskb_may_pull(skb, nh + offset + 3 - skb->data))) {
3451                                struct ip6_mh *mh;
3452
3453                                nh = skb_network_header(skb);
3454                                mh = (struct ip6_mh *)(nh + offset);
3455                                fl6->fl6_mh_type = mh->ip6mh_type;
3456                        }
3457                        fl6->flowi6_proto = nexthdr;
3458                        return;
3459#endif
3460                /* XXX Why are there these headers? */
3461                case IPPROTO_AH:
3462                case IPPROTO_ESP:
3463                case IPPROTO_COMP:
3464                default:
3465                        fl6->fl6_ipsec_spi = 0;
3466                        fl6->flowi6_proto = nexthdr;
3467                        return;
3468                }
3469        }
3470}
3471#endif
3472
3473int __xfrm_decode_session(struct sk_buff *skb, struct flowi *fl,
3474                          unsigned int family, int reverse)
3475{
3476        switch (family) {
3477        case AF_INET:
3478                decode_session4(skb, fl, reverse);
3479                break;
3480#if IS_ENABLED(CONFIG_IPV6)
3481        case AF_INET6:
3482                decode_session6(skb, fl, reverse);
3483                break;
3484#endif
3485        default:
3486                return -EAFNOSUPPORT;
3487        }
3488
3489        return security_xfrm_decode_session(skb, &fl->flowi_secid);
3490}
3491EXPORT_SYMBOL(__xfrm_decode_session);
3492
3493static inline int secpath_has_nontransport(const struct sec_path *sp, int k, int *idxp)
3494{
3495        for (; k < sp->len; k++) {
3496                if (sp->xvec[k]->props.mode != XFRM_MODE_TRANSPORT) {
3497                        *idxp = k;
3498                        return 1;
3499                }
3500        }
3501
3502        return 0;
3503}
3504
3505int __xfrm_policy_check(struct sock *sk, int dir, struct sk_buff *skb,
3506                        unsigned short family)
3507{
3508        struct net *net = dev_net(skb->dev);
3509        struct xfrm_policy *pol;
3510        struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX];
3511        int npols = 0;
3512        int xfrm_nr;
3513        int pi;
3514        int reverse;
3515        struct flowi fl;
3516        int xerr_idx = -1;
3517        const struct xfrm_if_cb *ifcb;
3518        struct sec_path *sp;
3519        struct xfrm_if *xi;
3520        u32 if_id = 0;
3521
3522        rcu_read_lock();
3523        ifcb = xfrm_if_get_cb();
3524
3525        if (ifcb) {
3526                xi = ifcb->decode_session(skb, family);
3527                if (xi) {
3528                        if_id = xi->p.if_id;
3529                        net = xi->net;
3530                }
3531        }
3532        rcu_read_unlock();
3533
3534        reverse = dir & ~XFRM_POLICY_MASK;
3535        dir &= XFRM_POLICY_MASK;
3536
3537        if (__xfrm_decode_session(skb, &fl, family, reverse) < 0) {
3538                XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR);
3539                return 0;
3540        }
3541
3542        nf_nat_decode_session(skb, &fl, family);
3543
3544        /* First, check used SA against their selectors. */
3545        sp = skb_sec_path(skb);
3546        if (sp) {
3547                int i;
3548
3549                for (i = sp->len - 1; i >= 0; i--) {
3550                        struct xfrm_state *x = sp->xvec[i];
3551                        if (!xfrm_selector_match(&x->sel, &fl, family)) {
3552                                XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMISMATCH);
3553                                return 0;
3554                        }
3555                }
3556        }
3557
3558        pol = NULL;
3559        sk = sk_to_full_sk(sk);
3560        if (sk && sk->sk_policy[dir]) {
3561                pol = xfrm_sk_policy_lookup(sk, dir, &fl, family, if_id);
3562                if (IS_ERR(pol)) {
3563                        XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3564                        return 0;
3565                }
3566        }
3567
3568        if (!pol)
3569                pol = xfrm_policy_lookup(net, &fl, family, dir, if_id);
3570
3571        if (IS_ERR(pol)) {
3572                XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3573                return 0;
3574        }
3575
3576        if (!pol) {
3577                if (sp && secpath_has_nontransport(sp, 0, &xerr_idx)) {
3578                        xfrm_secpath_reject(xerr_idx, skb, &fl);
3579                        XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS);
3580                        return 0;
3581                }
3582                return 1;
3583        }
3584
3585        pol->curlft.use_time = ktime_get_real_seconds();
3586
3587        pols[0] = pol;
3588        npols++;
3589#ifdef CONFIG_XFRM_SUB_POLICY
3590        if (pols[0]->type != XFRM_POLICY_TYPE_MAIN) {
3591                pols[1] = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN,
3592                                                    &fl, family,
3593                                                    XFRM_POLICY_IN, if_id);
3594                if (pols[1]) {
3595                        if (IS_ERR(pols[1])) {
3596                                XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR);
3597                                return 0;
3598                        }
3599                        pols[1]->curlft.use_time = ktime_get_real_seconds();
3600                        npols++;
3601                }
3602        }
3603#endif
3604
3605        if (pol->action == XFRM_POLICY_ALLOW) {
3606                static struct sec_path dummy;
3607                struct xfrm_tmpl *tp[XFRM_MAX_DEPTH];
3608                struct xfrm_tmpl *stp[XFRM_MAX_DEPTH];
3609                struct xfrm_tmpl **tpp = tp;
3610                int ti = 0;
3611                int i, k;
3612
3613                sp = skb_sec_path(skb);
3614                if (!sp)
3615                        sp = &dummy;
3616
3617                for (pi = 0; pi < npols; pi++) {
3618                        if (pols[pi] != pol &&
3619                            pols[pi]->action != XFRM_POLICY_ALLOW) {
3620                                XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK);
3621                                goto reject;
3622                        }
3623                        if (ti + pols[pi]->xfrm_nr >= XFRM_MAX_DEPTH) {
3624                                XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR);
3625                                goto reject_error;
3626                        }
3627                        for (i = 0; i < pols[pi]->xfrm_nr; i++)
3628                                tpp[ti++] = &pols[pi]->xfrm_vec[i];
3629                }
3630                xfrm_nr = ti;
3631                if (npols > 1) {
3632                        xfrm_tmpl_sort(stp, tpp, xfrm_nr, family);
3633                        tpp = stp;
3634                }
3635
3636                /* For each tunnel xfrm, find the first matching tmpl.
3637                 * For each tmpl before that, find corresponding xfrm.
3638                 * Order is _important_. Later we will implement
3639                 * some barriers, but at the moment barriers
3640                 * are implied between each two transformations.
3641                 */
3642                for (i = xfrm_nr-1, k = 0; i >= 0; i--) {
3643                        k = xfrm_policy_ok(tpp[i], sp, k, family);
3644                        if (k < 0) {
3645                                if (k < -1)
3646                                        /* "-2 - errored_index" returned */
3647                                        xerr_idx = -(2+k);
3648                                XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH);
3649                                goto reject;
3650                        }
3651                }
3652
3653                if (secpath_has_nontransport(sp, k, &xerr_idx)) {
3654                        XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH);
3655                        goto reject;
3656                }
3657
3658                xfrm_pols_put(pols, npols);
3659                return 1;
3660        }
3661        XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK);
3662
3663reject:
3664        xfrm_secpath_reject(xerr_idx, skb, &fl);
3665reject_error:
3666        xfrm_pols_put(pols, npols);
3667        return 0;
3668}
3669EXPORT_SYMBOL(__xfrm_policy_check);
3670
3671int __xfrm_route_forward(struct sk_buff *skb, unsigned short family)
3672{
3673        struct net *net = dev_net(skb->dev);
3674        struct flowi fl;
3675        struct dst_entry *dst;
3676        int res = 1;
3677
3678        if (xfrm_decode_session(skb, &fl, family) < 0) {
3679                XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
3680                return 0;
3681        }
3682
3683        skb_dst_force(skb);
3684        if (!skb_dst(skb)) {
3685                XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR);
3686                return 0;
3687        }
3688
3689        dst = xfrm_lookup(net, skb_dst(skb), &fl, NULL, XFRM_LOOKUP_QUEUE);
3690        if (IS_ERR(dst)) {
3691                res = 0;
3692                dst = NULL;
3693        }
3694        skb_dst_set(skb, dst);
3695        return res;
3696}
3697EXPORT_SYMBOL(__xfrm_route_forward);
3698
3699/* Optimize later using cookies and generation ids. */
3700
3701static struct dst_entry *xfrm_dst_check(struct dst_entry *dst, u32 cookie)
3702{
3703        /* Code (such as __xfrm4_bundle_create()) sets dst->obsolete
3704         * to DST_OBSOLETE_FORCE_CHK to force all XFRM destinations to
3705         * get validated by dst_ops->check on every use.  We do this
3706         * because when a normal route referenced by an XFRM dst is
3707         * obsoleted we do not go looking around for all parent
3708         * referencing XFRM dsts so that we can invalidate them.  It
3709         * is just too much work.  Instead we make the checks here on
3710         * every use.  For example:
3711         *
3712         *      XFRM dst A --> IPv4 dst X
3713         *
3714         * X is the "xdst->route" of A (X is also the "dst->path" of A
3715         * in this example).  If X is marked obsolete, "A" will not
3716         * notice.  That's what we are validating here via the
3717         * stale_bundle() check.
3718         *
3719         * When a dst is removed from the fib tree, DST_OBSOLETE_DEAD will
3720         * be marked on it.
3721         * This will force stale_bundle() to fail on any xdst bundle with
3722         * this dst linked in it.
3723         */
3724        if (dst->obsolete < 0 && !stale_bundle(dst))
3725                return dst;
3726
3727        return NULL;
3728}
3729
/* Returns 1 when the bundle rooted at @dst fails xfrm_bundle_ok(), else 0. */
static int stale_bundle(struct dst_entry *dst)
{
	struct xfrm_dst *xdst = (struct xfrm_dst *)dst;

	return xfrm_bundle_ok(xdst) ? 0 : 1;
}
3734
3735void xfrm_dst_ifdown(struct dst_entry *dst, struct net_device *dev)
3736{
3737        while ((dst = xfrm_dst_child(dst)) && dst->xfrm && dst->dev == dev) {
3738                dst->dev = dev_net(dev)->loopback_dev;
3739                dev_hold(dst->dev);
3740                dev_put(dev);
3741        }
3742}
3743EXPORT_SYMBOL(xfrm_dst_ifdown);
3744
/*
 * dst_ops->link_failure hook for xfrm dsts: intentionally empty.
 */
static void xfrm_link_failure(struct sk_buff *skb)
{
	/* Should never happen: such a dst has to be popped off before the
	 * packet gets to the point where a link failure is reported.
	 */
}
3749
3750static struct dst_entry *xfrm_negative_advice(struct dst_entry *dst)
3751{
3752        if (dst) {
3753                if (dst->obsolete) {
3754                        dst_release(dst);
3755                        dst = NULL;
3756                }
3757        }
3758        return dst;
3759}
3760
3761static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr)
3762{
3763        while (nr--) {
3764                struct xfrm_dst *xdst = bundle[nr];
3765                u32 pmtu, route_mtu_cached;
3766                struct dst_entry *dst;
3767
3768                dst = &xdst->u.dst;
3769                pmtu = dst_mtu(xfrm_dst_child(dst));
3770                xdst->child_mtu_cached = pmtu;
3771
3772                pmtu = xfrm_state_mtu(dst->xfrm, pmtu);
3773
3774                route_mtu_cached = dst_mtu(xdst->route);
3775                xdst->route_mtu_cached = route_mtu_cached;
3776
3777                if (pmtu > route_mtu_cached)
3778                        pmtu = route_mtu_cached;
3779
3780                dst_metric_set(dst, RTAX_MTU, pmtu);
3781        }
3782}
3783
/* Check that the bundle accepts the flow and its components are
 * still valid.
 *
 * Returns 1 if the bundle rooted at @first may still be used, 0 if not.
 * A bundle is rejected when its path dst fails dst_check(), its device is
 * not running, or any xfrm level has a non-VALID state, an outdated state
 * genid, an outdated policy genid, or a route that fails dst_check().
 *
 * Side effect: when a level's cached child MTU or route MTU has changed,
 * the caches are refreshed, and RTAX_MTU is recomputed for that level and
 * every level above it (towards @first).
 */

static int xfrm_bundle_ok(struct xfrm_dst *first)
{
	/* NOTE(review): no bounds check on nr; assumes the chain is never
	 * deeper than XFRM_MAX_DEPTH.
	 */
	struct xfrm_dst *bundle[XFRM_MAX_DEPTH];
	struct dst_entry *dst = &first->u.dst;
	struct xfrm_dst *xdst;
	int start_from, nr;
	u32 mtu;

	if (!dst_check(xfrm_dst_path(dst), ((struct xfrm_dst *)dst)->path_cookie) ||
	    (dst->dev && !netif_running(dst->dev)))
		return 0;

	/* DST_XFRM_QUEUE bundles skip the per-level checks below. */
	if (dst->flags & DST_XFRM_QUEUE)
		return 1;

	/* Walk down the child chain while the dst is an xfrm dst, recording
	 * each level in bundle[] (outermost first).  start_from becomes
	 * one past the index of the deepest level whose MTU cache changed.
	 */
	start_from = nr = 0;
	do {
		/* Shadows the outer xdst; only used inside this loop. */
		struct xfrm_dst *xdst = (struct xfrm_dst *)dst;

		if (dst->xfrm->km.state != XFRM_STATE_VALID)
			return 0;
		if (xdst->xfrm_genid != dst->xfrm->genid)
			return 0;
		if (xdst->num_pols > 0 &&
		    xdst->policy_genid != atomic_read(&xdst->pols[0]->genid))
			return 0;

		bundle[nr++] = xdst;

		mtu = dst_mtu(xfrm_dst_child(dst));
		if (xdst->child_mtu_cached != mtu) {
			start_from = nr;
			xdst->child_mtu_cached = mtu;
		}

		if (!dst_check(xdst->route, xdst->route_cookie))
			return 0;
		mtu = dst_mtu(xdst->route);
		if (xdst->route_mtu_cached != mtu) {
			start_from = nr;
			xdst->route_mtu_cached = mtu;
		}

		dst = xfrm_dst_child(dst);
	} while (dst->xfrm);

	if (likely(!start_from))
		return 1;

	/* Recompute RTAX_MTU from the deepest changed level upwards: each
	 * level's MTU is its child MTU adjusted by xfrm_state_mtu() and
	 * capped at its route MTU, and the result becomes the cached child
	 * MTU of the level above.
	 */
	xdst = bundle[start_from - 1];
	mtu = xdst->child_mtu_cached;
	while (start_from--) {
		dst = &xdst->u.dst;

		mtu = xfrm_state_mtu(dst->xfrm, mtu);
		if (mtu > xdst->route_mtu_cached)
			mtu = xdst->route_mtu_cached;
		dst_metric_set(dst, RTAX_MTU, mtu);
		if (!start_from)
			break;

		xdst = bundle[start_from - 1];
		xdst->child_mtu_cached = mtu;
	}

	return 1;
}
3855
/* Advertised MSS of an xfrm dst: taken from the bundle's path dst. */
static unsigned int xfrm_default_advmss(const struct dst_entry *dst)
{
	const struct dst_entry *path = xfrm_dst_path(dst);

	return dst_metric_advmss(path);
}
3860
3861static unsigned int xfrm_mtu(const struct dst_entry *dst)
3862{
3863        unsigned int mtu = dst_metric_raw(dst, RTAX_MTU);
3864
3865        return mtu ? : dst_mtu(xfrm_dst_path(dst));
3866}
3867
3868static const void *xfrm_get_dst_nexthop(const struct dst_entry *dst,
3869                                        const void *daddr)
3870{
3871        while (dst->xfrm) {
3872                const struct xfrm_state *xfrm = dst->xfrm;
3873
3874                dst = xfrm_dst_child(dst);
3875
3876                if (xfrm->props.mode == XFRM_MODE_TRANSPORT)
3877                        continue;
3878                if (xfrm->type->flags & XFRM_TYPE_REMOTE_COADDR)
3879                        daddr = xfrm->coaddr;
3880                else if (!(xfrm->type->flags & XFRM_TYPE_LOCAL_COADDR))
3881                        daddr = &xfrm->id.daddr;
3882        }
3883        return daddr;
3884}
3885
3886static struct neighbour *xfrm_neigh_lookup(const struct dst_entry *dst,
3887                                           struct sk_buff *skb,
3888                                           const void *daddr)
3889{
3890        const struct dst_entry *path = xfrm_dst_path(dst);
3891
3892        if (!skb)
3893                daddr = xfrm_get_dst_nexthop(dst, daddr);
3894        return path->ops->neigh_lookup(path, skb, daddr);
3895}
3896
3897static void xfrm_confirm_neigh(const struct dst_entry *dst, const void *daddr)
3898{
3899        const struct dst_entry *path = xfrm_dst_path(dst);
3900
3901        daddr = xfrm_get_dst_nexthop(dst, daddr);
3902        path->ops->confirm_neigh(path, daddr);
3903}
3904
3905int xfrm_policy_register_afinfo(const struct xfrm_policy_afinfo *afinfo, int family)
3906{
3907        int err = 0;
3908
3909        if (WARN_ON(family >= ARRAY_SIZE(xfrm_policy_afinfo)))
3910                return -EAFNOSUPPORT;
3911
3912        spin_lock(&xfrm_policy_afinfo_lock);
3913        if (unlikely(xfrm_policy_afinfo[family] != NULL))
3914                err = -EEXIST;
3915        else {
3916                struct dst_ops *dst_ops = afinfo->dst_ops;
3917                if (likely(dst_ops->kmem_cachep == NULL))
3918                        dst_ops->kmem_cachep = xfrm_dst_cache;
3919                if (likely(dst_ops->check == NULL))
3920                        dst_ops->check = xfrm_dst_check;
3921                if (likely(dst_ops->default_advmss == NULL))
3922                        dst_ops->default_advmss = xfrm_default_advmss;
3923                if (likely(dst_ops->mtu == NULL))
3924                        dst_ops->mtu = xfrm_mtu;
3925                if (likely(dst_ops->negative_advice == NULL))
3926                        dst_ops->negative_advice = xfrm_negative_advice;
3927                if (likely(dst_ops->link_failure == NULL))
3928                        dst_ops->link_failure = xfrm_link_failure;
3929                if (likely(dst_ops->neigh_lookup == NULL))
3930                        dst_ops->neigh_lookup = xfrm_neigh_lookup;
3931                if (likely(!dst_ops->confirm_neigh))
3932                        dst_ops->confirm_neigh = xfrm_confirm_neigh;
3933                rcu_assign_pointer(xfrm_policy_afinfo[family], afinfo);
3934        }
3935        spin_unlock(&xfrm_policy_afinfo_lock);
3936
3937        return err;
3938}
3939EXPORT_SYMBOL(xfrm_policy_register_afinfo);
3940
3941void xfrm_policy_unregister_afinfo(const struct xfrm_policy_afinfo *afinfo)
3942{
3943        struct dst_ops *dst_ops = afinfo->dst_ops;
3944        int i;
3945
3946        for (i = 0; i < ARRAY_SIZE(xfrm_policy_afinfo); i++) {
3947                if (xfrm_policy_afinfo[i] != afinfo)
3948                        continue;
3949                RCU_INIT_POINTER(xfrm_policy_afinfo[i], NULL);
3950                break;
3951        }
3952
3953        synchronize_rcu();
3954
3955        dst_ops->kmem_cachep = NULL;
3956        dst_ops->check = NULL;
3957        dst_ops->negative_advice = NULL;
3958        dst_ops->link_failure = NULL;
3959}
3960EXPORT_SYMBOL(xfrm_policy_unregister_afinfo);
3961
3962void xfrm_if_register_cb(const struct xfrm_if_cb *ifcb)
3963{
3964        spin_lock(&xfrm_if_cb_lock);
3965        rcu_assign_pointer(xfrm_if_cb, ifcb);
3966        spin_unlock(&xfrm_if_cb_lock);
3967}
3968EXPORT_SYMBOL(xfrm_if_register_cb);
3969
3970void xfrm_if_unregister_cb(void)
3971{
3972        RCU_INIT_POINTER(xfrm_if_cb, NULL);
3973        synchronize_rcu();
3974}
3975EXPORT_SYMBOL(xfrm_if_unregister_cb);
3976
3977#ifdef CONFIG_XFRM_STATISTICS
3978static int __net_init xfrm_statistics_init(struct net *net)
3979{
3980        int rv;
3981        net->mib.xfrm_statistics = alloc_percpu(struct linux_xfrm_mib);
3982        if (!net->mib.xfrm_statistics)
3983                return -ENOMEM;
3984        rv = xfrm_proc_init(net);
3985        if (rv < 0)
3986                free_percpu(net->mib.xfrm_statistics);
3987        return rv;
3988}
3989
3990static void xfrm_statistics_fini(struct net *net)
3991{
3992        xfrm_proc_fini(net);
3993        free_percpu(net->mib.xfrm_statistics);
3994}
3995#else
3996static int __net_init xfrm_statistics_init(struct net *net)
3997{
3998        return 0;
3999}
4000
4001static void xfrm_statistics_fini(struct net *net)
4002{
4003}
4004#endif
4005
/*
 * xfrm_policy_init - set up the per-namespace policy database of @net
 *
 * For init_net only, this also creates the global xfrm_dst_cache slab
 * (SLAB_PANIC) and the global inexact-policy rhashtable (BUG_ON on
 * failure).  For every namespace it allocates an 8-bucket by-index hash
 * and one 8-bucket by-dst hash per direction, and sets the initial hash
 * prefix bits and thresholds.  It also initialises the policy lists and
 * the resize/rebuild work items.
 *
 * Returns 0 on success or -ENOMEM; tables allocated so far are freed on
 * failure.
 */
static int __net_init xfrm_policy_init(struct net *net)
{
	unsigned int hmask, sz;
	int dir, err;

	if (net_eq(net, &init_net)) {
		xfrm_dst_cache = kmem_cache_create("xfrm_dst_cache",
					   sizeof(struct xfrm_dst),
					   0, SLAB_HWCACHE_ALIGN|SLAB_PANIC,
					   NULL);
		err = rhashtable_init(&xfrm_policy_inexact_table,
				      &xfrm_pol_inexact_params);
		BUG_ON(err);
	}

	/* Start with 8 buckets in each hash table. */
	hmask = 8 - 1;
	sz = (hmask+1) * sizeof(struct hlist_head);

	net->xfrm.policy_byidx = xfrm_hash_alloc(sz);
	if (!net->xfrm.policy_byidx)
		goto out_byidx;
	net->xfrm.policy_idx_hmask = hmask;

	for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
		struct xfrm_policy_hash *htab;

		net->xfrm.policy_count[dir] = 0;
		net->xfrm.policy_count[XFRM_POLICY_MAX + dir] = 0;
		INIT_HLIST_HEAD(&net->xfrm.policy_inexact[dir]);

		htab = &net->xfrm.policy_bydst[dir];
		htab->table = xfrm_hash_alloc(sz);
		if (!htab->table)
			goto out_bydst;
		htab->hmask = hmask;
		/* Initial prefix lengths: full address width. */
		htab->dbits4 = 32;
		htab->sbits4 = 32;
		htab->dbits6 = 128;
		htab->sbits6 = 128;
	}
	net->xfrm.policy_hthresh.lbits4 = 32;
	net->xfrm.policy_hthresh.rbits4 = 32;
	net->xfrm.policy_hthresh.lbits6 = 128;
	net->xfrm.policy_hthresh.rbits6 = 128;

	seqlock_init(&net->xfrm.policy_hthresh.lock);

	INIT_LIST_HEAD(&net->xfrm.policy_all);
	INIT_LIST_HEAD(&net->xfrm.inexact_bins);
	INIT_WORK(&net->xfrm.policy_hash_work, xfrm_hash_resize);
	INIT_WORK(&net->xfrm.policy_hthresh.work, xfrm_hash_rebuild);
	return 0;

out_bydst:
	/* Free the by-dst tables of the directions allocated before @dir. */
	for (dir--; dir >= 0; dir--) {
		struct xfrm_policy_hash *htab;

		htab = &net->xfrm.policy_bydst[dir];
		xfrm_hash_free(htab->table, sz);
	}
	xfrm_hash_free(net->xfrm.policy_byidx, sz);
out_byidx:
	return -ENOMEM;
}
4070
/*
 * xfrm_policy_fini - tear down the per-namespace policy database of @net
 *
 * Waits for pending hash-resize work, then flushes all policies (sub
 * policies first when CONFIG_XFRM_SUB_POLICY is set).  It WARNs if any
 * list or hash table is still populated, frees the by-dst and by-index
 * tables, and finally prunes every inexact bin under xfrm_policy_lock.
 */
static void xfrm_policy_fini(struct net *net)
{
	struct xfrm_pol_inexact_bin *b, *t;
	unsigned int sz;
	int dir;

	flush_work(&net->xfrm.policy_hash_work);
#ifdef CONFIG_XFRM_SUB_POLICY
	xfrm_policy_flush(net, XFRM_POLICY_TYPE_SUB, false);
#endif
	xfrm_policy_flush(net, XFRM_POLICY_TYPE_MAIN, false);

	WARN_ON(!list_empty(&net->xfrm.policy_all));

	for (dir = 0; dir < XFRM_POLICY_MAX; dir++) {
		struct xfrm_policy_hash *htab;

		WARN_ON(!hlist_empty(&net->xfrm.policy_inexact[dir]));

		/* Table size follows the current hmask, which may differ
		 * from the initial one after resizes.
		 */
		htab = &net->xfrm.policy_bydst[dir];
		sz = (htab->hmask + 1) * sizeof(struct hlist_head);
		WARN_ON(!hlist_empty(htab->table));
		xfrm_hash_free(htab->table, sz);
	}

	sz = (net->xfrm.policy_idx_hmask + 1) * sizeof(struct hlist_head);
	WARN_ON(!hlist_empty(net->xfrm.policy_byidx));
	xfrm_hash_free(net->xfrm.policy_byidx, sz);

	spin_lock_bh(&net->xfrm.xfrm_policy_lock);
	list_for_each_entry_safe(b, t, &net->xfrm.inexact_bins, inexact_bins)
		__xfrm_policy_inexact_prune_bin(b, true);
	spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
}
4105
/*
 * xfrm_net_init - pernet init hook for xfrm
 *
 * Initialises the per-net locks, then statistics, state, policy and
 * sysctl in that order.  When a step fails, the steps already done are
 * undone in reverse order and the error code is returned.
 */
static int __net_init xfrm_net_init(struct net *net)
{
	int rv;

	/* Initialize the per-net locks here */
	spin_lock_init(&net->xfrm.xfrm_state_lock);
	spin_lock_init(&net->xfrm.xfrm_policy_lock);
	mutex_init(&net->xfrm.xfrm_cfg_mutex);

	rv = xfrm_statistics_init(net);
	if (rv < 0)
		goto out_statistics;
	rv = xfrm_state_init(net);
	if (rv < 0)
		goto out_state;
	rv = xfrm_policy_init(net);
	if (rv < 0)
		goto out_policy;
	rv = xfrm_sysctl_init(net);
	if (rv < 0)
		goto out_sysctl;

	return 0;

	/* Unwind ladder: each label undoes the step before the failing one. */
out_sysctl:
	xfrm_policy_fini(net);
out_policy:
	xfrm_state_fini(net);
out_state:
	xfrm_statistics_fini(net);
out_statistics:
	return rv;
}
4139
/* Pernet exit hook: undo xfrm_net_init() in reverse order. */
static void __net_exit xfrm_net_exit(struct net *net)
{
	xfrm_sysctl_fini(net);
	xfrm_policy_fini(net);
	xfrm_state_fini(net);
	xfrm_statistics_fini(net);
}
4147
4148static struct pernet_operations __net_initdata xfrm_net_ops = {
4149        .init = xfrm_net_init,
4150        .exit = xfrm_net_exit,
4151};
4152
/*
 * xfrm_init - boot-time initialisation of the xfrm core
 *
 * Registers the pernet hooks, initialises the xfrm device layer, the
 * policy hash generation seqcount and the input path.  It then clears
 * the xfrm interface callback pointer and waits for an RCU grace period.
 */
void __init xfrm_init(void)
{
	register_pernet_subsys(&xfrm_net_ops);
	xfrm_dev_init();
	seqcount_init(&xfrm_policy_hash_generation);
	xfrm_input_init();

	RCU_INIT_POINTER(xfrm_if_cb, NULL);
	synchronize_rcu();
}
4163
4164#ifdef CONFIG_AUDITSYSCALL
4165static void xfrm_audit_common_policyinfo(struct xfrm_policy *xp,
4166                                         struct audit_buffer *audit_buf)
4167{
4168        struct xfrm_sec_ctx *ctx = xp->security;
4169        struct xfrm_selector *sel = &xp->selector;
4170
4171        if (ctx)
4172                audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s",
4173                                 ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str);
4174
4175        switch (sel->family) {
4176        case AF_INET:
4177                audit_log_format(audit_buf, " src=%pI4", &sel->saddr.a4);
4178                if (sel->prefixlen_s != 32)
4179                        audit_log_format(audit_buf, " src_prefixlen=%d",
4180                                         sel->prefixlen_s);
4181                audit_log_format(audit_buf, " dst=%pI4", &sel->daddr.a4);
4182                if (sel->prefixlen_d != 32)
4183                        audit_log_format(audit_buf, " dst_prefixlen=%d",
4184                                         sel->prefixlen_d);
4185                break;
4186        case AF_INET6:
4187                audit_log_format(audit_buf, " src=%pI6", sel->saddr.a6);
4188                if (sel->prefixlen_s != 128)
4189                        audit_log_format(audit_buf, " src_prefixlen=%d",
4190                                         sel->prefixlen_s);
4191                audit_log_format(audit_buf, " dst=%pI6", sel->daddr.a6);
4192                if (sel->prefixlen_d != 128)
4193                        audit_log_format(audit_buf, " dst_prefixlen=%d",
4194                                         sel->prefixlen_d);
4195                break;
4196        }
4197}
4198
4199void xfrm_audit_policy_add(struct xfrm_policy *xp, int result, bool task_valid)
4200{
4201        struct audit_buffer *audit_buf;
4202
4203        audit_buf = xfrm_audit_start("SPD-add");
4204        if (audit_buf == NULL)
4205                return;
4206        xfrm_audit_helper_usrinfo(task_valid, audit_buf);
4207        audit_log_format(audit_buf, " res=%u", result);
4208        xfrm_audit_common_policyinfo(xp, audit_buf);
4209        audit_log_end(audit_buf);
4210}
4211EXPORT_SYMBOL_GPL(xfrm_audit_policy_add);
4212
4213void xfrm_audit_policy_delete(struct xfrm_policy *xp, int result,
4214                              bool task_valid)
4215{
4216        struct audit_buffer *audit_buf;
4217
4218        audit_buf = xfrm_audit_start("SPD-delete");
4219        if (audit_buf == NULL)
4220                return;
4221        xfrm_audit_helper_usrinfo(task_valid, audit_buf);
4222        audit_log_format(audit_buf, " res=%u", result);
4223        xfrm_audit_common_policyinfo(xp, audit_buf);
4224        audit_log_end(audit_buf);
4225}
4226EXPORT_SYMBOL_GPL(xfrm_audit_policy_delete);
4227#endif
4228
4229#ifdef CONFIG_XFRM_MIGRATE
4230static bool xfrm_migrate_selector_match(const struct xfrm_selector *sel_cmp,
4231                                        const struct xfrm_selector *sel_tgt)
4232{
4233        if (sel_cmp->proto == IPSEC_ULPROTO_ANY) {
4234                if (sel_tgt->family == sel_cmp->family &&
4235                    xfrm_addr_equal(&sel_tgt->daddr, &sel_cmp->daddr,
4236                                    sel_cmp->family) &&
4237                    xfrm_addr_equal(&sel_tgt->saddr, &sel_cmp->saddr,
4238                                    sel_cmp->family) &&
4239                    sel_tgt->prefixlen_d == sel_cmp->prefixlen_d &&
4240                    sel_tgt->prefixlen_s == sel_cmp->prefixlen_s) {
4241                        return true;
4242                }
4243        } else {
4244                if (memcmp(sel_tgt, sel_cmp, sizeof(*sel_tgt)) == 0) {
4245                        return true;
4246                }
4247        }
4248        return false;
4249}
4250
4251static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector *sel,
4252                                                    u8 dir, u8 type, struct net *net)
4253{
4254        struct xfrm_policy *pol, *ret = NULL;
4255        struct hlist_head *chain;
4256        u32 priority = ~0U;
4257
4258        spin_lock_bh(&net->xfrm.xfrm_policy_lock);
4259        chain = policy_hash_direct(net, &sel->daddr, &sel->saddr, sel->family, dir);
4260        hlist_for_each_entry(pol, chain, bydst) {
4261                if (xfrm_migrate_selector_match(sel, &pol->selector) &&
4262                    pol->type == type) {
4263                        ret = pol;
4264                        priority = ret->priority;
4265                        break;
4266                }
4267        }
4268        chain = &net->xfrm.policy_inexact[dir];
4269        hlist_for_each_entry(pol, chain, bydst_inexact_list) {
4270                if ((pol->priority >= priority) && ret)
4271                        break;
4272
4273                if (xfrm_migrate_selector_match(sel, &pol->selector) &&
4274                    pol->type == type) {
4275                        ret = pol;
4276                        break;
4277                }
4278        }
4279
4280        xfrm_pol_hold(ret);
4281
4282        spin_unlock_bh(&net->xfrm.xfrm_policy_lock);
4283
4284        return ret;
4285}
4286
4287static int migrate_tmpl_match(const struct xfrm_migrate *m, const struct xfrm_tmpl *t)
4288{
4289        int match = 0;
4290
4291        if (t->mode == m->mode && t->id.proto == m->proto &&
4292            (m->reqid == 0 || t->reqid == m->reqid)) {
4293                switch (t->mode) {
4294                case XFRM_MODE_TUNNEL:
4295                case XFRM_MODE_BEET:
4296                        if (xfrm_addr_equal(&t->id.daddr, &m->old_daddr,
4297                                            m->old_family) &&
4298                            xfrm_addr_equal(&t->saddr, &m->old_saddr,
4299                                            m->old_family)) {
4300                                match = 1;
4301                        }
4302                        break;
4303                case XFRM_MODE_TRANSPORT:
4304                        /* in case of transport mode, template does not store
4305                           any IP addresses, hence we just compare mode and
4306                           protocol */
4307                        match = 1;
4308                        break;
4309                default:
4310                        break;
4311                }
4312        }
4313        return match;
4314}
4315
4316/* update endpoint address(es) of template(s) */
4317static int xfrm_policy_migrate(struct xfrm_policy *pol,
4318                               struct xfrm_migrate *m, int num_migrate)
4319{
4320        struct xfrm_migrate *mp;
4321        int i, j, n = 0;
4322
4323        write_lock_bh(&pol->lock);
4324        if (unlikely(pol->walk.dead)) {
4325                /* target policy has been deleted */
4326                write_unlock_bh(&pol->lock);
4327                return -ENOENT;
4328        }
4329
4330        for (i = 0; i < pol->xfrm_nr; i++) {
4331                for (j = 0, mp = m; j < num_migrate; j++, mp++) {
4332                        if (!migrate_tmpl_match(mp, &pol->xfrm_vec[i]))
4333                                continue;
4334                        n++;
4335                        if (pol->xfrm_vec[i].mode != XFRM_MODE_TUNNEL &&
4336                            pol->xfrm_vec[i].mode != XFRM_MODE_BEET)
4337                                continue;
4338                        /* update endpoints */
4339                        memcpy(&pol->xfrm_vec[i].id.daddr, &mp->new_daddr,
4340                               sizeof(pol->xfrm_vec[i].id.daddr));
4341                        memcpy(&pol->xfrm_vec[i].saddr, &mp->new_saddr,
4342                               sizeof(pol->xfrm_vec[i].saddr));
4343                        pol->xfrm_vec[i].encap_family = mp->new_family;
4344                        /* flush bundles */
4345                        atomic_inc(&pol->genid);
4346                }
4347        }
4348
4349        write_unlock_bh(&pol->lock);
4350
4351        if (!n)
4352                return -ENODATA;
4353
4354        return 0;
4355}
4356
4357static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate)
4358{
4359        int i, j;
4360
4361        if (num_migrate < 1 || num_migrate > XFRM_MAX_DEPTH)
4362                return -EINVAL;
4363
4364        for (i = 0; i < num_migrate; i++) {
4365                if (xfrm_addr_any(&m[i].new_daddr, m[i].new_family) ||
4366                    xfrm_addr_any(&m[i].new_saddr, m[i].new_family))
4367                        return -EINVAL;
4368
4369                /* check if there is any duplicated entry */
4370                for (j = i + 1; j < num_migrate; j++) {
4371                        if (!memcmp(&m[i].old_daddr, &m[j].old_daddr,
4372                                    sizeof(m[i].old_daddr)) &&
4373                            !memcmp(&m[i].old_saddr, &m[j].old_saddr,
4374                                    sizeof(m[i].old_saddr)) &&
4375                            m[i].proto == m[j].proto &&
4376                            m[i].mode == m[j].mode &&
4377                            m[i].reqid == m[j].reqid &&
4378                            m[i].old_family == m[j].old_family)
4379                                return -EINVAL;
4380                }
4381        }
4382
4383        return 0;
4384}
4385
4386int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type,
4387                 struct xfrm_migrate *m, int num_migrate,
4388                 struct xfrm_kmaddress *k, struct net *net,
4389                 struct xfrm_encap_tmpl *encap)
4390{
4391        int i, err, nx_cur = 0, nx_new = 0;
4392        struct xfrm_policy *pol = NULL;
4393        struct xfrm_state *x, *xc;
4394        struct xfrm_state *x_cur[XFRM_MAX_DEPTH];
4395        struct xfrm_state *x_new[XFRM_MAX_DEPTH];
4396        struct xfrm_migrate *mp;
4397
4398        /* Stage 0 - sanity checks */
4399        if ((err = xfrm_migrate_check(m, num_migrate)) < 0)
4400                goto out;
4401
4402        if (dir >= XFRM_POLICY_MAX) {
4403                err = -EINVAL;
4404                goto out;
4405        }
4406
4407        /* Stage 1 - find policy */
4408        if ((pol = xfrm_migrate_policy_find(sel, dir, type, net)) == NULL) {
4409                err = -ENOENT;
4410                goto out;
4411        }
4412
4413        /* Stage 2 - find and update state(s) */
4414        for (i = 0, mp = m; i < num_migrate; i++, mp++) {
4415                if ((x = xfrm_migrate_state_find(mp, net))) {
4416                        x_cur[nx_cur] = x;
4417                        nx_cur++;
4418                        xc = xfrm_state_migrate(x, mp, encap);
4419                        if (xc) {
4420                                x_new[nx_new] = xc;
4421                                nx_new++;
4422                        } else {
4423                                err = -ENODATA;
4424                                goto restore_state;
4425                        }
4426                }
4427        }
4428
4429        /* Stage 3 - update policy */
4430        if ((err = xfrm_policy_migrate(pol, m, num_migrate)) < 0)
4431                goto restore_state;
4432
4433        /* Stage 4 - delete old state(s) */
4434        if (nx_cur) {
4435                xfrm_states_put(x_cur, nx_cur);
4436                xfrm_states_delete(x_cur, nx_cur);
4437        }
4438
4439        /* Stage 5 - announce */
4440        km_migrate(sel, dir, type, m, num_migrate, k, encap);
4441
4442        xfrm_pol_put(pol);
4443
4444        return 0;
4445out:
4446        return err;
4447
4448restore_state:
4449        if (pol)
4450                xfrm_pol_put(pol);
4451        if (nx_cur)
4452                xfrm_states_put(x_cur, nx_cur);
4453        if (nx_new)
4454                xfrm_states_delete(x_new, nx_new);
4455
4456        return err;
4457}
4458EXPORT_SYMBOL(xfrm_migrate);
4459#endif
4460