linux/drivers/net/wireguard/noise.c
<<
>>
Prefs
   1// SPDX-License-Identifier: GPL-2.0
   2/*
   3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
   4 */
   5
   6#include "noise.h"
   7#include "device.h"
   8#include "peer.h"
   9#include "messages.h"
  10#include "queueing.h"
  11#include "peerlookup.h"
  12
  13#include <linux/rcupdate.h>
  14#include <linux/slab.h>
  15#include <linux/bitmap.h>
  16#include <linux/scatterlist.h>
  17#include <linux/highmem.h>
  18#include <crypto/algapi.h>
#include <asm/unaligned.h>
  19
  20/* This implements Noise_IKpsk2:
  21 *
  22 * <- s
  23 * ******
  24 * -> e, es, s, ss, {t}
  25 * <- e, ee, se, psk, {}
  26 */
  27
    /* Protocol name hashed by wg_noise_init() to seed the initial chaining
     * key. The array length (37) is exactly the length of the literal, so no
     * NUL terminator is stored and sizeof() covers only the hashed bytes.
     */
  28static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
    /* Identifier mixed into the initial transcript hash by wg_noise_init();
     * likewise sized (34) to exclude the NUL terminator.
     */
  29static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
    /* Starting hash and chaining key shared by every handshake, computed once
     * in wg_noise_init() and copied by handshake_init().
     */
  30static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
  31static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
    /* Source of keypair->internal_id values handed out by keypair_create(). */
  32static atomic64_t keypair_counter = ATOMIC64_INIT(0);
  33
  34void __init wg_noise_init(void)
  35{
  36        struct blake2s_state blake;
  37
  38        blake2s(handshake_init_chaining_key, handshake_name, NULL,
  39                NOISE_HASH_LEN, sizeof(handshake_name), 0);
  40        blake2s_init(&blake, NOISE_HASH_LEN);
  41        blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
  42        blake2s_update(&blake, identifier_name, sizeof(identifier_name));
  43        blake2s_final(&blake, handshake_init_hash);
  44}
  45
  46/* Must hold peer->handshake.static_identity->lock */
  47void wg_noise_precompute_static_static(struct wg_peer *peer)
  48{
  49        down_write(&peer->handshake.lock);
  50        if (!peer->handshake.static_identity->has_identity ||
  51            !curve25519(peer->handshake.precomputed_static_static,
  52                        peer->handshake.static_identity->static_private,
  53                        peer->handshake.remote_static))
  54                memset(peer->handshake.precomputed_static_static, 0,
  55                       NOISE_PUBLIC_KEY_LEN);
  56        up_write(&peer->handshake.lock);
  57}
  58
  59void wg_noise_handshake_init(struct noise_handshake *handshake,
  60                             struct noise_static_identity *static_identity,
  61                             const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
  62                             const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
  63                             struct wg_peer *peer)
  64{
  65        memset(handshake, 0, sizeof(*handshake));
  66        init_rwsem(&handshake->lock);
  67        handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
  68        handshake->entry.peer = peer;
  69        memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
  70        if (peer_preshared_key)
  71                memcpy(handshake->preshared_key, peer_preshared_key,
  72                       NOISE_SYMMETRIC_KEY_LEN);
  73        handshake->static_identity = static_identity;
  74        handshake->state = HANDSHAKE_ZEROED;
  75        wg_noise_precompute_static_static(peer);
  76}
  77
  78static void handshake_zero(struct noise_handshake *handshake)
  79{
  80        memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
  81        memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
  82        memset(&handshake->hash, 0, NOISE_HASH_LEN);
  83        memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
  84        handshake->remote_index = 0;
  85        handshake->state = HANDSHAKE_ZEROED;
  86}
  87
  88void wg_noise_handshake_clear(struct noise_handshake *handshake)
  89{
  90        down_write(&handshake->lock);
  91        wg_index_hashtable_remove(
  92                        handshake->entry.peer->device->index_hashtable,
  93                        &handshake->entry);
  94        handshake_zero(handshake);
  95        up_write(&handshake->lock);
  96}
  97
  98static struct noise_keypair *keypair_create(struct wg_peer *peer)
  99{
 100        struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
 101
 102        if (unlikely(!keypair))
 103                return NULL;
 104        spin_lock_init(&keypair->receiving_counter.lock);
 105        keypair->internal_id = atomic64_inc_return(&keypair_counter);
 106        keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
 107        keypair->entry.peer = peer;
 108        kref_init(&keypair->refcount);
 109        return keypair;
 110}
 111
 112static void keypair_free_rcu(struct rcu_head *rcu)
 113{
 114        kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
 115}
 116
 117static void keypair_free_kref(struct kref *kref)
 118{
 119        struct noise_keypair *keypair =
 120                container_of(kref, struct noise_keypair, refcount);
 121
 122        net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
 123                            keypair->entry.peer->device->dev->name,
 124                            keypair->internal_id,
 125                            keypair->entry.peer->internal_id);
 126        wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
 127                                  &keypair->entry);
 128        call_rcu(&keypair->rcu, keypair_free_rcu);
 129}
 130
 131void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
 132{
 133        if (unlikely(!keypair))
 134                return;
 135        if (unlikely(unreference_now))
 136                wg_index_hashtable_remove(
 137                        keypair->entry.peer->device->index_hashtable,
 138                        &keypair->entry);
 139        kref_put(&keypair->refcount, keypair_free_kref);
 140}
 141
 142struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
 143{
 144        RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
 145                "Taking noise keypair reference without holding the RCU BH read lock");
 146        if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
 147                return NULL;
 148        return keypair;
 149}
 150
    /* Empty all three keypair slots of @keypairs (next, then previous, then
     * current) under keypair_update_lock. Each slot is unpublished by storing
     * NULL and its reference is dropped with unreference_now=true, which also
     * removes the keypair from the index hashtable immediately.
     */
 151void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
 152{
 153        struct noise_keypair *old;
 154
 155        spin_lock_bh(&keypairs->keypair_update_lock);
 156
 157        /* We zero the next_keypair before zeroing the others, so that
 158         * wg_noise_received_with_keypair returns early before subsequent ones
 159         * are zeroed.
 160         */
            /* Storing NULL needs no publish barrier, hence RCU_INIT_POINTER. */
 161        old = rcu_dereference_protected(keypairs->next_keypair,
 162                lockdep_is_held(&keypairs->keypair_update_lock));
 163        RCU_INIT_POINTER(keypairs->next_keypair, NULL);
 164        wg_noise_keypair_put(old, true);
 165
 166        old = rcu_dereference_protected(keypairs->previous_keypair,
 167                lockdep_is_held(&keypairs->keypair_update_lock));
 168        RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
 169        wg_noise_keypair_put(old, true);
 170
 171        old = rcu_dereference_protected(keypairs->current_keypair,
 172                lockdep_is_held(&keypairs->keypair_update_lock));
 173        RCU_INIT_POINTER(keypairs->current_keypair, NULL);
 174        wg_noise_keypair_put(old, true);
 175
 176        spin_unlock_bh(&keypairs->keypair_update_lock);
 177}
 178
 179void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
 180{
 181        struct noise_keypair *keypair;
 182
 183        wg_noise_handshake_clear(&peer->handshake);
 184        wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
 185
 186        spin_lock_bh(&peer->keypairs.keypair_update_lock);
 187        keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
 188                        lockdep_is_held(&peer->keypairs.keypair_update_lock));
 189        if (keypair)
 190                keypair->sending.is_valid = false;
 191        keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
 192                        lockdep_is_held(&peer->keypairs.keypair_update_lock));
 193        if (keypair)
 194                keypair->sending.is_valid = false;
 195        spin_unlock_bh(&peer->keypairs.keypair_update_lock);
 196}
 197
 198static void add_new_keypair(struct noise_keypairs *keypairs,
 199                            struct noise_keypair *new_keypair)
 200{
 201        struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
 202
 203        spin_lock_bh(&keypairs->keypair_update_lock);
 204        previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
 205                lockdep_is_held(&keypairs->keypair_update_lock));
 206        next_keypair = rcu_dereference_protected(keypairs->next_keypair,
 207                lockdep_is_held(&keypairs->keypair_update_lock));
 208        current_keypair = rcu_dereference_protected(keypairs->current_keypair,
 209                lockdep_is_held(&keypairs->keypair_update_lock));
 210        if (new_keypair->i_am_the_initiator) {
 211                /* If we're the initiator, it means we've sent a handshake, and
 212                 * received a confirmation response, which means this new
 213                 * keypair can now be used.
 214                 */
 215                if (next_keypair) {
 216                        /* If there already was a next keypair pending, we
 217                         * demote it to be the previous keypair, and free the
 218                         * existing current. Note that this means KCI can result
 219                         * in this transition. It would perhaps be more sound to
 220                         * always just get rid of the unused next keypair
 221                         * instead of putting it in the previous slot, but this
 222                         * might be a bit less robust. Something to think about
 223                         * for the future.
 224                         */
 225                        RCU_INIT_POINTER(keypairs->next_keypair, NULL);
 226                        rcu_assign_pointer(keypairs->previous_keypair,
 227                                           next_keypair);
 228                        wg_noise_keypair_put(current_keypair, true);
 229                } else /* If there wasn't an existing next keypair, we replace
 230                        * the previous with the current one.
 231                        */
 232                        rcu_assign_pointer(keypairs->previous_keypair,
 233                                           current_keypair);
 234                /* At this point we can get rid of the old previous keypair, and
 235                 * set up the new keypair.
 236                 */
 237                wg_noise_keypair_put(previous_keypair, true);
 238                rcu_assign_pointer(keypairs->current_keypair, new_keypair);
 239        } else {
 240                /* If we're the responder, it means we can't use the new keypair
 241                 * until we receive confirmation via the first data packet, so
 242                 * we get rid of the existing previous one, the possibly
 243                 * existing next one, and slide in the new next one.
 244                 */
 245                rcu_assign_pointer(keypairs->next_keypair, new_keypair);
 246                wg_noise_keypair_put(next_keypair, true);
 247                RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
 248                wg_noise_keypair_put(previous_keypair, true);
 249        }
 250        spin_unlock_bh(&keypairs->keypair_update_lock);
 251}
 252
    /* Called when data has been received on @received_keypair (presumably from
     * the receive path -- confirm at callers). If it is the pending next
     * keypair that add_new_keypair() parked for the responder, promote it:
     * next -> current, current -> previous, and release the old previous.
     *
     * Returns true only if this call performed the promotion; false if
     * @received_keypair is not (or is no longer) the next keypair.
     */
 253bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
 254                                    struct noise_keypair *received_keypair)
 255{
 256        struct noise_keypair *old_keypair;
 257        bool key_is_new;
 258
 259        /* We first check without taking the spinlock. */
            /* rcu_access_pointer() only fetches the pointer for comparison; the
             * result is never dereferenced here.
             */
 260        key_is_new = received_keypair ==
 261                     rcu_access_pointer(keypairs->next_keypair);
 262        if (likely(!key_is_new))
 263                return false;
 264
 265        spin_lock_bh(&keypairs->keypair_update_lock);
 266        /* After locking, we double check that things didn't change from
 267         * beneath us.
 268         */
 269        if (unlikely(received_keypair !=
 270                    rcu_dereference_protected(keypairs->next_keypair,
 271                            lockdep_is_held(&keypairs->keypair_update_lock)))) {
 272                spin_unlock_bh(&keypairs->keypair_update_lock);
 273                return false;
 274        }
 275
 276        /* When we've finally received the confirmation, we slide the next
 277         * into the current, the current into the previous, and get rid of
 278         * the old previous.
 279         */
 280        old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
 281                lockdep_is_held(&keypairs->keypair_update_lock));
 282        rcu_assign_pointer(keypairs->previous_keypair,
 283                rcu_dereference_protected(keypairs->current_keypair,
 284                        lockdep_is_held(&keypairs->keypair_update_lock)));
 285        wg_noise_keypair_put(old_keypair, true);
 286        rcu_assign_pointer(keypairs->current_keypair, received_keypair);
 287        RCU_INIT_POINTER(keypairs->next_keypair, NULL);
 288
 289        spin_unlock_bh(&keypairs->keypair_update_lock);
 290        return true;
 291}
 292
 293/* Must hold static_identity->lock */
 294void wg_noise_set_static_identity_private_key(
 295        struct noise_static_identity *static_identity,
 296        const u8 private_key[NOISE_PUBLIC_KEY_LEN])
 297{
 298        memcpy(static_identity->static_private, private_key,
 299               NOISE_PUBLIC_KEY_LEN);
 300        curve25519_clamp_secret(static_identity->static_private);
 301        static_identity->has_identity = curve25519_generate_public(
 302                static_identity->static_public, private_key);
 303}
 304
 305/* This is Hugo Krawczyk's HKDF:
 306 *  - https://eprint.iacr.org/2010/264.pdf
 307 *  - https://tools.ietf.org/html/rfc5869
     *
     * Extract: secret = HMAC-BLAKE2s(key = chaining_key, msg = data[0..data_len)).
     * Expand:  T1 = HMAC(secret, 0x1)
     *          T2 = HMAC(secret, T1 || 0x2)
     *          T3 = HMAC(secret, T2 || 0x3)
     * The first first_len bytes of T1 go to first_dst, and so on. Expansion
     * stops at the first output whose dst is NULL or whose len is 0, so a
     * later output is only produced when all earlier ones are requested.
     *
     * Each *_len must be <= BLAKE2S_HASH_SIZE; this is only checked (WARN_ON)
     * when DEBUG is enabled. A dst may alias chaining_key: chaining_key is
     * read only during the extract step, before any dst is written.
 308 */
 309static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
 310                size_t first_len, size_t second_len, size_t third_len,
 311                size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
 312{
            /* One HMAC output block plus a trailing byte for the 0x2/0x3 counter. */
 313        u8 output[BLAKE2S_HASH_SIZE + 1];
 314        u8 secret[BLAKE2S_HASH_SIZE];
 315
            /* Debug-only check of the length/destination contract above. */
 316        WARN_ON(IS_ENABLED(DEBUG) &&
 317                (first_len > BLAKE2S_HASH_SIZE ||
 318                 second_len > BLAKE2S_HASH_SIZE ||
 319                 third_len > BLAKE2S_HASH_SIZE ||
 320                 ((second_len || second_dst || third_len || third_dst) &&
 321                  (!first_len || !first_dst)) ||
 322                 ((third_len || third_dst) && (!second_len || !second_dst))));
 323
 324        /* Extract entropy from data into secret */
 325        blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);
 326
 327        if (!first_dst || !first_len)
 328                goto out;
 329
 330        /* Expand first key: key = secret, data = 0x1 */
 331        output[0] = 1;
 332        blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
 333        memcpy(first_dst, output, first_len);
 334
 335        if (!second_dst || !second_len)
 336                goto out;
 337
 338        /* Expand second key: key = secret, data = first-key || 0x2 */
 339        output[BLAKE2S_HASH_SIZE] = 2;
 340        blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
 341                        BLAKE2S_HASH_SIZE);
 342        memcpy(second_dst, output, second_len);
 343
 344        if (!third_dst || !third_len)
 345                goto out;
 346
 347        /* Expand third key: key = secret, data = second-key || 0x3 */
 348        output[BLAKE2S_HASH_SIZE] = 3;
 349        blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
 350                        BLAKE2S_HASH_SIZE);
 351        memcpy(third_dst, output, third_len);
 352
 353out:
 354        /* Clear sensitive data from stack */
 355        memzero_explicit(secret, BLAKE2S_HASH_SIZE);
 356        memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
 357}
 358
 359static void derive_keys(struct noise_symmetric_key *first_dst,
 360                        struct noise_symmetric_key *second_dst,
 361                        const u8 chaining_key[NOISE_HASH_LEN])
 362{
 363        u64 birthdate = ktime_get_coarse_boottime_ns();
 364        kdf(first_dst->key, second_dst->key, NULL, NULL,
 365            NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
 366            chaining_key);
 367        first_dst->birthdate = second_dst->birthdate = birthdate;
 368        first_dst->is_valid = second_dst->is_valid = true;
 369}
 370
 371static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
 372                                u8 key[NOISE_SYMMETRIC_KEY_LEN],
 373                                const u8 private[NOISE_PUBLIC_KEY_LEN],
 374                                const u8 public[NOISE_PUBLIC_KEY_LEN])
 375{
 376        u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
 377
 378        if (unlikely(!curve25519(dh_calculation, private, public)))
 379                return false;
 380        kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
 381            NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
 382        memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
 383        return true;
 384}
 385
 386static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
 387                                            u8 key[NOISE_SYMMETRIC_KEY_LEN],
 388                                            const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
 389{
 390        static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
 391        if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
 392                return false;
 393        kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
 394            NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
 395            chaining_key);
 396        return true;
 397}
 398
 399static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
 400{
 401        struct blake2s_state blake;
 402
 403        blake2s_init(&blake, NOISE_HASH_LEN);
 404        blake2s_update(&blake, hash, NOISE_HASH_LEN);
 405        blake2s_update(&blake, src, src_len);
 406        blake2s_final(&blake, hash);
 407}
 408
 409static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
 410                    u8 key[NOISE_SYMMETRIC_KEY_LEN],
 411                    const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
 412{
 413        u8 temp_hash[NOISE_HASH_LEN];
 414
 415        kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
 416            NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
 417        mix_hash(hash, temp_hash, NOISE_HASH_LEN);
 418        memzero_explicit(temp_hash, NOISE_HASH_LEN);
 419}
 420
 421static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
 422                           u8 hash[NOISE_HASH_LEN],
 423                           const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
 424{
 425        memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
 426        memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
 427        mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
 428}
 429
 430static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
 431                            size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
 432                            u8 hash[NOISE_HASH_LEN])
 433{
 434        chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
 435                                 NOISE_HASH_LEN,
 436                                 0 /* Always zero for Noise_IK */, key);
 437        mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
 438}
 439
 440static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
 441                            size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
 442                            u8 hash[NOISE_HASH_LEN])
 443{
 444        if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
 445                                      hash, NOISE_HASH_LEN,
 446                                      0 /* Always zero for Noise_IK */, key))
 447                return false;
 448        mix_hash(hash, src_ciphertext, src_len);
 449        return true;
 450}
 451
 452static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
 453                              const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
 454                              u8 chaining_key[NOISE_HASH_LEN],
 455                              u8 hash[NOISE_HASH_LEN])
 456{
 457        if (ephemeral_dst != ephemeral_src)
 458                memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
 459        mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
 460        kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
 461            NOISE_PUBLIC_KEY_LEN, chaining_key);
 462}
 463
 464static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
 465{
 466        struct timespec64 now;
 467
 468        ktime_get_real_ts64(&now);
 469
 470        /* In order to prevent some sort of infoleak from precise timers, we
 471         * round down the nanoseconds part to the closest rounded-down power of
 472         * two to the maximum initiations per second allowed anyway by the
 473         * implementation.
 474         */
 475        now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
 476                rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
 477
 478        /* https://cr.yp.to/libtai/tai64.html */
 479        *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
 480        *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
 481}
 482
    /* Build a handshake initiation message (Noise_IKpsk2
     * "-> e, es, s, ss, {t}") into @dst, advancing @handshake's chaining key
     * and hash in place.
     *
     * Takes handshake->static_identity->lock for reading, then handshake->lock
     * for writing. On success the handshake entry is inserted into the device
     * index hashtable, the value returned by wg_index_hashtable_insert() is
     * stored in dst->sender_index, state becomes HANDSHAKE_CREATED_INITIATION,
     * and true is returned. Returns false if there is no static identity or a
     * key-generation/DH step fails; @dst and the handshake transcript may then
     * be partially written.
     */
 483bool
 484wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
 485                                     struct noise_handshake *handshake)
 486{
 487        u8 timestamp[NOISE_TIMESTAMP_LEN];
 488        u8 key[NOISE_SYMMETRIC_KEY_LEN];
 489        bool ret = false;
 490
 491        /* We need to wait for crng _before_ taking any locks, since
 492         * curve25519_generate_secret uses get_random_bytes_wait.
 493         */
 494        wait_for_random_bytes();
 495
 496        down_read(&handshake->static_identity->lock);
 497        down_write(&handshake->lock);
 498
 499        if (unlikely(!handshake->static_identity->has_identity))
 500                goto out;
 501
 502        dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
 503
 504        handshake_init(handshake->chaining_key, handshake->hash,
 505                       handshake->remote_static);
 506
 507        /* e */
 508        curve25519_generate_secret(handshake->ephemeral_private);
 509        if (!curve25519_generate_public(dst->unencrypted_ephemeral,
 510                                        handshake->ephemeral_private))
 511                goto out;
 512        message_ephemeral(dst->unencrypted_ephemeral,
 513                          dst->unencrypted_ephemeral, handshake->chaining_key,
 514                          handshake->hash);
 515
 516        /* es */
 517        if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
 518                    handshake->remote_static))
 519                goto out;
 520
 521        /* s */
 522        message_encrypt(dst->encrypted_static,
 523                        handshake->static_identity->static_public,
 524                        NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
 525
 526        /* ss */
 527        if (!mix_precomputed_dh(handshake->chaining_key, key,
 528                                handshake->precomputed_static_static))
 529                goto out;
 530
 531        /* {t} */
 532        tai64n_now(timestamp);
 533        message_encrypt(dst->encrypted_timestamp, timestamp,
 534                        NOISE_TIMESTAMP_LEN, key, handshake->hash);
 535
            /* Every step succeeded: publish the handshake under a fresh index. */
 536        dst->sender_index = wg_index_hashtable_insert(
 537                handshake->entry.peer->device->index_hashtable,
 538                &handshake->entry);
 539
 540        handshake->state = HANDSHAKE_CREATED_INITIATION;
 541        ret = true;
 542
 543out:
 544        up_write(&handshake->lock);
 545        up_read(&handshake->static_identity->lock);
            /* key held derived symmetric keys; wipe it on every exit path. */
 546        memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
 547        return ret;
 548}
 549
 550struct wg_peer *
 551wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
 552                                      struct wg_device *wg)
 553{
 554        struct wg_peer *peer = NULL, *ret_peer = NULL;
 555        struct noise_handshake *handshake;
 556        bool replay_attack, flood_attack;
 557        u8 key[NOISE_SYMMETRIC_KEY_LEN];
 558        u8 chaining_key[NOISE_HASH_LEN];
 559        u8 hash[NOISE_HASH_LEN];
 560        u8 s[NOISE_PUBLIC_KEY_LEN];
 561        u8 e[NOISE_PUBLIC_KEY_LEN];
 562        u8 t[NOISE_TIMESTAMP_LEN];
 563        u64 initiation_consumption;
 564
 565        down_read(&wg->static_identity.lock);
 566        if (unlikely(!wg->static_identity.has_identity))
 567                goto out;
 568
 569        handshake_init(chaining_key, hash, wg->static_identity.static_public);
 570
 571        /* e */
 572        message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
 573
 574        /* es */
 575        if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
 576                goto out;
 577
 578        /* s */
 579        if (!message_decrypt(s, src->encrypted_static,
 580                             sizeof(src->encrypted_static), key, hash))
 581                goto out;
 582
 583        /* Lookup which peer we're actually talking to */
 584        peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
 585        if (!peer)
 586                goto out;
 587        handshake = &peer->handshake;
 588
 589        /* ss */
 590        if (!mix_precomputed_dh(chaining_key, key,
 591                                handshake->precomputed_static_static))
 592            goto out;
 593
 594        /* {t} */
 595        if (!message_decrypt(t, src->encrypted_timestamp,
 596                             sizeof(src->encrypted_timestamp), key, hash))
 597                goto out;
 598
 599        down_read(&handshake->lock);
 600        replay_attack = memcmp(t, handshake->latest_timestamp,
 601                               NOISE_TIMESTAMP_LEN) <= 0;
 602        flood_attack = (s64)handshake->last_initiation_consumption +
 603                               NSEC_PER_SEC / INITIATIONS_PER_SECOND >
 604                       (s64)ktime_get_coarse_boottime_ns();
 605        up_read(&handshake->lock);
 606        if (replay_attack || flood_attack)
 607                goto out;
 608
 609        /* Success! Copy everything to peer */
 610        down_write(&handshake->lock);
 611        memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
 612        if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
 613                memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
 614        memcpy(handshake->hash, hash, NOISE_HASH_LEN);
 615        memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
 616        handshake->remote_index = src->sender_index;
 617        initiation_consumption = ktime_get_coarse_boottime_ns();
 618        if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
 619                handshake->last_initiation_consumption = initiation_consumption;
 620        handshake->state = HANDSHAKE_CONSUMED_INITIATION;
 621        up_write(&handshake->lock);
 622        ret_peer = peer;
 623
 624out:
 625        memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
 626        memzero_explicit(hash, NOISE_HASH_LEN);
 627        memzero_explicit(chaining_key, NOISE_HASH_LEN);
 628        up_read(&wg->static_identity.lock);
 629        if (!ret_peer)
 630                wg_peer_put(peer);
 631        return ret_peer;
 632}
 633
    /* Build the handshake response (Noise_IKpsk2 "<- e, ee, se, psk, {}") into
     * @dst. Only proceeds when @handshake is in HANDSHAKE_CONSUMED_INITIATION;
     * its chaining key and hash are advanced in place.
     *
     * Takes handshake->static_identity->lock for reading, then handshake->lock
     * for writing. On success the handshake entry is inserted into the device
     * index hashtable, the value returned by wg_index_hashtable_insert() is
     * stored in dst->sender_index, state becomes HANDSHAKE_CREATED_RESPONSE,
     * and true is returned; otherwise false.
     */
 634bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
 635                                        struct noise_handshake *handshake)
 636{
 637        u8 key[NOISE_SYMMETRIC_KEY_LEN];
 638        bool ret = false;
 639
 640        /* We need to wait for crng _before_ taking any locks, since
 641         * curve25519_generate_secret uses get_random_bytes_wait.
 642         */
 643        wait_for_random_bytes();
 644
 645        down_read(&handshake->static_identity->lock);
 646        down_write(&handshake->lock);
 647
 648        if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
 649                goto out;
 650
 651        dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
 652        dst->receiver_index = handshake->remote_index;
 653
 654        /* e */
 655        curve25519_generate_secret(handshake->ephemeral_private);
 656        if (!curve25519_generate_public(dst->unencrypted_ephemeral,
 657                                        handshake->ephemeral_private))
 658                goto out;
 659        message_ephemeral(dst->unencrypted_ephemeral,
 660                          dst->unencrypted_ephemeral, handshake->chaining_key,
 661                          handshake->hash);
 662
            /* ee and se only advance the chaining key (key == NULL); the
             * symmetric key used for {} comes from the psk step.
             */
 663        /* ee */
 664        if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
 665                    handshake->remote_ephemeral))
 666                goto out;
 667
 668        /* se */
 669        if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
 670                    handshake->remote_static))
 671                goto out;
 672
 673        /* psk */
 674        mix_psk(handshake->chaining_key, handshake->hash, key,
 675                handshake->preshared_key);
 676
 677        /* {} */
 678        message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
 679
 680        dst->sender_index = wg_index_hashtable_insert(
 681                handshake->entry.peer->device->index_hashtable,
 682                &handshake->entry);
 683
 684        handshake->state = HANDSHAKE_CREATED_RESPONSE;
 685        ret = true;
 686
 687out:
 688        up_write(&handshake->lock);
 689        up_read(&handshake->static_identity->lock);
            /* key held the derived symmetric key; wipe it on every exit path. */
 690        memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
 691        return ret;
 692}
 693
 694struct wg_peer *
 695wg_noise_handshake_consume_response(struct message_handshake_response *src,
 696                                    struct wg_device *wg)
 697{
 698        enum noise_handshake_state state = HANDSHAKE_ZEROED;
 699        struct wg_peer *peer = NULL, *ret_peer = NULL;
 700        struct noise_handshake *handshake;
 701        u8 key[NOISE_SYMMETRIC_KEY_LEN];
 702        u8 hash[NOISE_HASH_LEN];
 703        u8 chaining_key[NOISE_HASH_LEN];
 704        u8 e[NOISE_PUBLIC_KEY_LEN];
 705        u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
 706        u8 static_private[NOISE_PUBLIC_KEY_LEN];
 707        u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
 708
 709        down_read(&wg->static_identity.lock);
 710
 711        if (unlikely(!wg->static_identity.has_identity))
 712                goto out;
 713
 714        handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
 715                wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
 716                src->receiver_index, &peer);
 717        if (unlikely(!handshake))
 718                goto out;
 719
 720        down_read(&handshake->lock);
 721        state = handshake->state;
 722        memcpy(hash, handshake->hash, NOISE_HASH_LEN);
 723        memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
 724        memcpy(ephemeral_private, handshake->ephemeral_private,
 725               NOISE_PUBLIC_KEY_LEN);
 726        memcpy(preshared_key, handshake->preshared_key,
 727               NOISE_SYMMETRIC_KEY_LEN);
 728        up_read(&handshake->lock);
 729
 730        if (state != HANDSHAKE_CREATED_INITIATION)
 731                goto fail;
 732
 733        /* e */
 734        message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
 735
 736        /* ee */
 737        if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
 738                goto fail;
 739
 740        /* se */
 741        if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
 742                goto fail;
 743
 744        /* psk */
 745        mix_psk(chaining_key, hash, key, preshared_key);
 746
 747        /* {} */
 748        if (!message_decrypt(NULL, src->encrypted_nothing,
 749                             sizeof(src->encrypted_nothing), key, hash))
 750                goto fail;
 751
 752        /* Success! Copy everything to peer */
 753        down_write(&handshake->lock);
 754        /* It's important to check that the state is still the same, while we
 755         * have an exclusive lock.
 756         */
 757        if (handshake->state != state) {
 758                up_write(&handshake->lock);
 759                goto fail;
 760        }
 761        memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
 762        memcpy(handshake->hash, hash, NOISE_HASH_LEN);
 763        memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
 764        handshake->remote_index = src->sender_index;
 765        handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
 766        up_write(&handshake->lock);
 767        ret_peer = peer;
 768        goto out;
 769
 770fail:
 771        wg_peer_put(peer);
 772out:
 773        memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
 774        memzero_explicit(hash, NOISE_HASH_LEN);
 775        memzero_explicit(chaining_key, NOISE_HASH_LEN);
 776        memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
 777        memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
 778        memzero_explicit(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
 779        up_read(&wg->static_identity.lock);
 780        return ret_peer;
 781}
 782
 783bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
 784                                      struct noise_keypairs *keypairs)
 785{
 786        struct noise_keypair *new_keypair;
 787        bool ret = false;
 788
 789        down_write(&handshake->lock);
 790        if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
 791            handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
 792                goto out;
 793
 794        new_keypair = keypair_create(handshake->entry.peer);
 795        if (!new_keypair)
 796                goto out;
 797        new_keypair->i_am_the_initiator = handshake->state ==
 798                                          HANDSHAKE_CONSUMED_RESPONSE;
 799        new_keypair->remote_index = handshake->remote_index;
 800
 801        if (new_keypair->i_am_the_initiator)
 802                derive_keys(&new_keypair->sending, &new_keypair->receiving,
 803                            handshake->chaining_key);
 804        else
 805                derive_keys(&new_keypair->receiving, &new_keypair->sending,
 806                            handshake->chaining_key);
 807
 808        handshake_zero(handshake);
 809        rcu_read_lock_bh();
 810        if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
 811                                           handshake)->is_dead))) {
 812                add_new_keypair(keypairs, new_keypair);
 813                net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
 814                                    handshake->entry.peer->device->dev->name,
 815                                    new_keypair->internal_id,
 816                                    handshake->entry.peer->internal_id);
 817                ret = wg_index_hashtable_replace(
 818                        handshake->entry.peer->device->index_hashtable,
 819                        &handshake->entry, &new_keypair->entry);
 820        } else {
 821                kfree_sensitive(new_keypair);
 822        }
 823        rcu_read_unlock_bh();
 824
 825out:
 826        up_write(&handshake->lock);
 827        return ret;
 828}
 829