Merge branch 'core-objtool-for-linus' of git://git.kernel.org/pub/scm/linux/kernel...
[linux-block.git] / drivers / net / wireguard / noise.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5
6 #include "noise.h"
7 #include "device.h"
8 #include "peer.h"
9 #include "messages.h"
10 #include "queueing.h"
11 #include "peerlookup.h"
12
13 #include <linux/rcupdate.h>
14 #include <linux/slab.h>
15 #include <linux/bitmap.h>
16 #include <linux/scatterlist.h>
17 #include <linux/highmem.h>
18 #include <crypto/algapi.h>
19
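/* This implements the Noise_IKpsk2 handshake pattern:
 *
 * <- s
 * ******
 * -> e, es, s, ss, {t}
 * <- e, ee, se, psk, {}
 *
 * The line above the asterisks is the pre-message: the initiator already
 * knows the responder's static public key. Below it are the initiation and
 * response messages; {t} is the encrypted timestamp and {} an encrypted
 * empty payload.
 */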
27
/* Protocol name and identifier hashed into the initial handshake state by
 * wg_noise_init(). Each array size equals its string's length, so neither
 * array stores a terminating NUL; sizeof() therefore yields exactly the
 * number of bytes that get hashed.
 */
static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
/* Starting hash and chaining key shared by every handshake; computed once in
 * wg_noise_init() and read-only afterwards.
 */
static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
/* Monotonic source of keypair->internal_id values (shown in debug output). */
static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33
34 void __init wg_noise_init(void)
35 {
36         struct blake2s_state blake;
37
38         blake2s(handshake_init_chaining_key, handshake_name, NULL,
39                 NOISE_HASH_LEN, sizeof(handshake_name), 0);
40         blake2s_init(&blake, NOISE_HASH_LEN);
41         blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
42         blake2s_update(&blake, identifier_name, sizeof(identifier_name));
43         blake2s_final(&blake, handshake_init_hash);
44 }
45
46 /* Must hold peer->handshake.static_identity->lock */
47 void wg_noise_precompute_static_static(struct wg_peer *peer)
48 {
49         down_write(&peer->handshake.lock);
50         if (!peer->handshake.static_identity->has_identity ||
51             !curve25519(peer->handshake.precomputed_static_static,
52                         peer->handshake.static_identity->static_private,
53                         peer->handshake.remote_static))
54                 memset(peer->handshake.precomputed_static_static, 0,
55                        NOISE_PUBLIC_KEY_LEN);
56         up_write(&peer->handshake.lock);
57 }
58
59 void wg_noise_handshake_init(struct noise_handshake *handshake,
60                              struct noise_static_identity *static_identity,
61                              const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62                              const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63                              struct wg_peer *peer)
64 {
65         memset(handshake, 0, sizeof(*handshake));
66         init_rwsem(&handshake->lock);
67         handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68         handshake->entry.peer = peer;
69         memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70         if (peer_preshared_key)
71                 memcpy(handshake->preshared_key, peer_preshared_key,
72                        NOISE_SYMMETRIC_KEY_LEN);
73         handshake->static_identity = static_identity;
74         handshake->state = HANDSHAKE_ZEROED;
75         wg_noise_precompute_static_static(peer);
76 }
77
78 static void handshake_zero(struct noise_handshake *handshake)
79 {
80         memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81         memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82         memset(&handshake->hash, 0, NOISE_HASH_LEN);
83         memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84         handshake->remote_index = 0;
85         handshake->state = HANDSHAKE_ZEROED;
86 }
87
88 void wg_noise_handshake_clear(struct noise_handshake *handshake)
89 {
90         wg_index_hashtable_remove(
91                         handshake->entry.peer->device->index_hashtable,
92                         &handshake->entry);
93         down_write(&handshake->lock);
94         handshake_zero(handshake);
95         up_write(&handshake->lock);
96         wg_index_hashtable_remove(
97                         handshake->entry.peer->device->index_hashtable,
98                         &handshake->entry);
99 }
100
101 static struct noise_keypair *keypair_create(struct wg_peer *peer)
102 {
103         struct noise_keypair *keypair = kzalloc(sizeof(*keypair), GFP_KERNEL);
104
105         if (unlikely(!keypair))
106                 return NULL;
107         keypair->internal_id = atomic64_inc_return(&keypair_counter);
108         keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
109         keypair->entry.peer = peer;
110         kref_init(&keypair->refcount);
111         return keypair;
112 }
113
114 static void keypair_free_rcu(struct rcu_head *rcu)
115 {
116         kzfree(container_of(rcu, struct noise_keypair, rcu));
117 }
118
119 static void keypair_free_kref(struct kref *kref)
120 {
121         struct noise_keypair *keypair =
122                 container_of(kref, struct noise_keypair, refcount);
123
124         net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
125                             keypair->entry.peer->device->dev->name,
126                             keypair->internal_id,
127                             keypair->entry.peer->internal_id);
128         wg_index_hashtable_remove(keypair->entry.peer->device->index_hashtable,
129                                   &keypair->entry);
130         call_rcu(&keypair->rcu, keypair_free_rcu);
131 }
132
133 void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
134 {
135         if (unlikely(!keypair))
136                 return;
137         if (unlikely(unreference_now))
138                 wg_index_hashtable_remove(
139                         keypair->entry.peer->device->index_hashtable,
140                         &keypair->entry);
141         kref_put(&keypair->refcount, keypair_free_kref);
142 }
143
144 struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
145 {
146         RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
147                 "Taking noise keypair reference without holding the RCU BH read lock");
148         if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
149                 return NULL;
150         return keypair;
151 }
152
153 void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
154 {
155         struct noise_keypair *old;
156
157         spin_lock_bh(&keypairs->keypair_update_lock);
158
159         /* We zero the next_keypair before zeroing the others, so that
160          * wg_noise_received_with_keypair returns early before subsequent ones
161          * are zeroed.
162          */
163         old = rcu_dereference_protected(keypairs->next_keypair,
164                 lockdep_is_held(&keypairs->keypair_update_lock));
165         RCU_INIT_POINTER(keypairs->next_keypair, NULL);
166         wg_noise_keypair_put(old, true);
167
168         old = rcu_dereference_protected(keypairs->previous_keypair,
169                 lockdep_is_held(&keypairs->keypair_update_lock));
170         RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
171         wg_noise_keypair_put(old, true);
172
173         old = rcu_dereference_protected(keypairs->current_keypair,
174                 lockdep_is_held(&keypairs->keypair_update_lock));
175         RCU_INIT_POINTER(keypairs->current_keypair, NULL);
176         wg_noise_keypair_put(old, true);
177
178         spin_unlock_bh(&keypairs->keypair_update_lock);
179 }
180
181 void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
182 {
183         struct noise_keypair *keypair;
184
185         wg_noise_handshake_clear(&peer->handshake);
186         wg_noise_reset_last_sent_handshake(&peer->last_sent_handshake);
187
188         spin_lock_bh(&peer->keypairs.keypair_update_lock);
189         keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
190                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
191         if (keypair)
192                 keypair->sending.is_valid = false;
193         keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
194                         lockdep_is_held(&peer->keypairs.keypair_update_lock));
195         if (keypair)
196                 keypair->sending.is_valid = false;
197         spin_unlock_bh(&peer->keypairs.keypair_update_lock);
198 }
199
/*
 * Install a freshly derived @new_keypair into @keypairs, rotating the
 * previous/current/next slots under keypair_update_lock.
 *
 * Initiator: @new_keypair becomes current immediately. If a next keypair was
 * pending, it moves to previous and the old current is dropped. Otherwise
 * the old current moves to previous. In both cases the old previous is
 * dropped.
 *
 * Responder: @new_keypair is parked in next until
 * wg_noise_received_with_keypair() promotes it. The old next and the old
 * previous are dropped, and current is left untouched.
 *
 * Every keypair evicted from a slot is released with
 * wg_noise_keypair_put(..., true), which also unlinks it from the index
 * hashtable. NOTE(review): no reference is taken on @new_keypair here, so
 * its slot presumably inherits the caller's reference — confirm at the
 * call site.
 */
static void add_new_keypair(struct noise_keypairs *keypairs,
			    struct noise_keypair *new_keypair)
{
	struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;

	spin_lock_bh(&keypairs->keypair_update_lock);
	/* Snapshot all three slots; they cannot change while the lock is held. */
	previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	next_keypair = rcu_dereference_protected(keypairs->next_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	current_keypair = rcu_dereference_protected(keypairs->current_keypair,
		lockdep_is_held(&keypairs->keypair_update_lock));
	if (new_keypair->i_am_the_initiator) {
		/* If we're the initiator, it means we've sent a handshake, and
		 * received a confirmation response, which means this new
		 * keypair can now be used.
		 */
		if (next_keypair) {
			/* If there already was a next keypair pending, we
			 * demote it to be the previous keypair, and free the
			 * existing current. Note that this means KCI can result
			 * in this transition. It would perhaps be more sound to
			 * always just get rid of the unused next keypair
			 * instead of putting it in the previous slot, but this
			 * might be a bit less robust. Something to think about
			 * for the future.
			 */
			RCU_INIT_POINTER(keypairs->next_keypair, NULL);
			rcu_assign_pointer(keypairs->previous_keypair,
					   next_keypair);
			wg_noise_keypair_put(current_keypair, true);
		} else /* If there wasn't an existing next keypair, we replace
			* the previous with the current one.
			*/
			rcu_assign_pointer(keypairs->previous_keypair,
					   current_keypair);
		/* At this point we can get rid of the old previous keypair, and
		 * set up the new keypair.
		 */
		wg_noise_keypair_put(previous_keypair, true);
		rcu_assign_pointer(keypairs->current_keypair, new_keypair);
	} else {
		/* If we're the responder, it means we can't use the new keypair
		 * until we receive confirmation via the first data packet, so
		 * we get rid of the existing previous one, the possibly
		 * existing next one, and slide in the new next one.
		 */
		rcu_assign_pointer(keypairs->next_keypair, new_keypair);
		wg_noise_keypair_put(next_keypair, true);
		RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
		wg_noise_keypair_put(previous_keypair, true);
	}
	spin_unlock_bh(&keypairs->keypair_update_lock);
}
254
255 bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
256                                     struct noise_keypair *received_keypair)
257 {
258         struct noise_keypair *old_keypair;
259         bool key_is_new;
260
261         /* We first check without taking the spinlock. */
262         key_is_new = received_keypair ==
263                      rcu_access_pointer(keypairs->next_keypair);
264         if (likely(!key_is_new))
265                 return false;
266
267         spin_lock_bh(&keypairs->keypair_update_lock);
268         /* After locking, we double check that things didn't change from
269          * beneath us.
270          */
271         if (unlikely(received_keypair !=
272                     rcu_dereference_protected(keypairs->next_keypair,
273                             lockdep_is_held(&keypairs->keypair_update_lock)))) {
274                 spin_unlock_bh(&keypairs->keypair_update_lock);
275                 return false;
276         }
277
278         /* When we've finally received the confirmation, we slide the next
279          * into the current, the current into the previous, and get rid of
280          * the old previous.
281          */
282         old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
283                 lockdep_is_held(&keypairs->keypair_update_lock));
284         rcu_assign_pointer(keypairs->previous_keypair,
285                 rcu_dereference_protected(keypairs->current_keypair,
286                         lockdep_is_held(&keypairs->keypair_update_lock)));
287         wg_noise_keypair_put(old_keypair, true);
288         rcu_assign_pointer(keypairs->current_keypair, received_keypair);
289         RCU_INIT_POINTER(keypairs->next_keypair, NULL);
290
291         spin_unlock_bh(&keypairs->keypair_update_lock);
292         return true;
293 }
294
295 /* Must hold static_identity->lock */
296 void wg_noise_set_static_identity_private_key(
297         struct noise_static_identity *static_identity,
298         const u8 private_key[NOISE_PUBLIC_KEY_LEN])
299 {
300         memcpy(static_identity->static_private, private_key,
301                NOISE_PUBLIC_KEY_LEN);
302         curve25519_clamp_secret(static_identity->static_private);
303         static_identity->has_identity = curve25519_generate_public(
304                 static_identity->static_public, private_key);
305 }
306
/* This is Hugo Krawczyk's HKDF:
 *  - https://eprint.iacr.org/2010/264.pdf
 *  - https://tools.ietf.org/html/rfc5869
 *
 * Instantiated with HMAC-BLAKE2s (blake2s256_hmac), keyed by @chaining_key:
 *
 *   secret = HMAC(chaining_key, data)
 *   T1     = HMAC(secret, 0x1)
 *   T2     = HMAC(secret, T1 || 0x2)
 *   T3     = HMAC(secret, T2 || 0x3)
 *
 * The first first_len / second_len / third_len bytes of T1 / T2 / T3 are
 * copied to first_dst / second_dst / third_dst. Expansion stops at the
 * first NULL destination or zero length, so a later output requires all
 * earlier ones. With DEBUG enabled, the WARN_ON checks this and that each
 * length is at most BLAKE2S_HASH_SIZE. @data may be NULL when @data_len is 0.
 *
 * A destination may alias @chaining_key (callers pass chaining_key as both
 * first_dst and the key). The key is only read in the extract step, before
 * anything is written.
 */
static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
		size_t first_len, size_t second_len, size_t third_len,
		size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
	/* One extra byte to append the expansion counter after the previous T. */
	u8 output[BLAKE2S_HASH_SIZE + 1];
	u8 secret[BLAKE2S_HASH_SIZE];

	WARN_ON(IS_ENABLED(DEBUG) &&
		(first_len > BLAKE2S_HASH_SIZE ||
		 second_len > BLAKE2S_HASH_SIZE ||
		 third_len > BLAKE2S_HASH_SIZE ||
		 ((second_len || second_dst || third_len || third_dst) &&
		  (!first_len || !first_dst)) ||
		 ((third_len || third_dst) && (!second_len || !second_dst))));

	/* Extract entropy from data into secret */
	blake2s256_hmac(secret, data, chaining_key, data_len, NOISE_HASH_LEN);

	if (!first_dst || !first_len)
		goto out;

	/* Expand first key: key = secret, data = 0x1 */
	/* NOTE(review): output is both the input and the destination of
	 * blake2s256_hmac() here and below; this relies on that function
	 * permitting in == out.
	 */
	output[0] = 1;
	blake2s256_hmac(output, output, secret, 1, BLAKE2S_HASH_SIZE);
	memcpy(first_dst, output, first_len);

	if (!second_dst || !second_len)
		goto out;

	/* Expand second key: key = secret, data = first-key || 0x2 */
	/* output[0..BLAKE2S_HASH_SIZE-1] still holds the full T1 here. */
	output[BLAKE2S_HASH_SIZE] = 2;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(second_dst, output, second_len);

	if (!third_dst || !third_len)
		goto out;

	/* Expand third key: key = secret, data = second-key || 0x3 */
	output[BLAKE2S_HASH_SIZE] = 3;
	blake2s256_hmac(output, output, secret, BLAKE2S_HASH_SIZE + 1,
			BLAKE2S_HASH_SIZE);
	memcpy(third_dst, output, third_len);

out:
	/* Clear sensitive data from stack */
	memzero_explicit(secret, BLAKE2S_HASH_SIZE);
	memzero_explicit(output, BLAKE2S_HASH_SIZE + 1);
}
360
361 static void symmetric_key_init(struct noise_symmetric_key *key)
362 {
363         spin_lock_init(&key->counter.receive.lock);
364         atomic64_set(&key->counter.counter, 0);
365         memset(key->counter.receive.backtrack, 0,
366                sizeof(key->counter.receive.backtrack));
367         key->birthdate = ktime_get_coarse_boottime_ns();
368         key->is_valid = true;
369 }
370
371 static void derive_keys(struct noise_symmetric_key *first_dst,
372                         struct noise_symmetric_key *second_dst,
373                         const u8 chaining_key[NOISE_HASH_LEN])
374 {
375         kdf(first_dst->key, second_dst->key, NULL, NULL,
376             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
377             chaining_key);
378         symmetric_key_init(first_dst);
379         symmetric_key_init(second_dst);
380 }
381
382 static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
383                                 u8 key[NOISE_SYMMETRIC_KEY_LEN],
384                                 const u8 private[NOISE_PUBLIC_KEY_LEN],
385                                 const u8 public[NOISE_PUBLIC_KEY_LEN])
386 {
387         u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
388
389         if (unlikely(!curve25519(dh_calculation, private, public)))
390                 return false;
391         kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
392             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
393         memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
394         return true;
395 }
396
397 static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
398                                             u8 key[NOISE_SYMMETRIC_KEY_LEN],
399                                             const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
400 {
401         static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
402         if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
403                 return false;
404         kdf(chaining_key, key, NULL, precomputed, NOISE_HASH_LEN,
405             NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
406             chaining_key);
407         return true;
408 }
409
410 static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
411 {
412         struct blake2s_state blake;
413
414         blake2s_init(&blake, NOISE_HASH_LEN);
415         blake2s_update(&blake, hash, NOISE_HASH_LEN);
416         blake2s_update(&blake, src, src_len);
417         blake2s_final(&blake, hash);
418 }
419
420 static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
421                     u8 key[NOISE_SYMMETRIC_KEY_LEN],
422                     const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
423 {
424         u8 temp_hash[NOISE_HASH_LEN];
425
426         kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
427             NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
428         mix_hash(hash, temp_hash, NOISE_HASH_LEN);
429         memzero_explicit(temp_hash, NOISE_HASH_LEN);
430 }
431
432 static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
433                            u8 hash[NOISE_HASH_LEN],
434                            const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
435 {
436         memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
437         memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
438         mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
439 }
440
441 static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
442                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
443                             u8 hash[NOISE_HASH_LEN])
444 {
445         chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
446                                  NOISE_HASH_LEN,
447                                  0 /* Always zero for Noise_IK */, key);
448         mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
449 }
450
451 static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
452                             size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
453                             u8 hash[NOISE_HASH_LEN])
454 {
455         if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
456                                       hash, NOISE_HASH_LEN,
457                                       0 /* Always zero for Noise_IK */, key))
458                 return false;
459         mix_hash(hash, src_ciphertext, src_len);
460         return true;
461 }
462
463 static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
464                               const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
465                               u8 chaining_key[NOISE_HASH_LEN],
466                               u8 hash[NOISE_HASH_LEN])
467 {
468         if (ephemeral_dst != ephemeral_src)
469                 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
470         mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
471         kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
472             NOISE_PUBLIC_KEY_LEN, chaining_key);
473 }
474
475 static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
476 {
477         struct timespec64 now;
478
479         ktime_get_real_ts64(&now);
480
481         /* In order to prevent some sort of infoleak from precise timers, we
482          * round down the nanoseconds part to the closest rounded-down power of
483          * two to the maximum initiations per second allowed anyway by the
484          * implementation.
485          */
486         now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
487                 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
488
489         /* https://cr.yp.to/libtai/tai64.html */
490         *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
491         *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
492 }
493
/*
 * Build a handshake initiation message, the first message of the pattern
 * "-> e, es, s, ss, {t}", into @dst from @handshake's state.
 *
 * On success, @handshake holds the new hash, chaining key and ephemeral
 * private key. It is inserted into the device's index hashtable, with its
 * index sent as dst->sender_index, and moves to
 * HANDSHAKE_CREATED_INITIATION. Returns true.
 *
 * Returns false if there is no local static identity, or if ephemeral key
 * generation or a DH step fails. handshake->state is then unchanged, but the
 * hash, chaining key and ephemeral key may already have been partly
 * overwritten.
 *
 * Takes static_identity->lock for reading and handshake->lock for writing.
 */
bool
wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
				     struct noise_handshake *handshake)
{
	u8 timestamp[NOISE_TIMESTAMP_LEN];
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (unlikely(!handshake->static_identity->has_identity))
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);

	/* Fresh transcript, seeded with the responder's static public key. */
	handshake_init(handshake->chaining_key, handshake->hash,
		       handshake->remote_static);

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* es */
	if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* s */
	message_encrypt(dst->encrypted_static,
			handshake->static_identity->static_public,
			NOISE_PUBLIC_KEY_LEN, key, handshake->hash);

	/* ss */
	if (!mix_precomputed_dh(handshake->chaining_key, key,
				handshake->precomputed_static_static))
		goto out;

	/* {t} */
	tai64n_now(timestamp);
	message_encrypt(dst->encrypted_timestamp, timestamp,
			NOISE_TIMESTAMP_LEN, key, handshake->hash);

	/* Register the handshake so that a response carrying this index as its
	 * receiver_index can be matched back to it (see the
	 * INDEX_HASHTABLE_HANDSHAKE lookup in the response consumer).
	 */
	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_INITIATION;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The message key lives only on this stack frame; scrub it. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
560
561 struct wg_peer *
562 wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
563                                       struct wg_device *wg)
564 {
565         struct wg_peer *peer = NULL, *ret_peer = NULL;
566         struct noise_handshake *handshake;
567         bool replay_attack, flood_attack;
568         u8 key[NOISE_SYMMETRIC_KEY_LEN];
569         u8 chaining_key[NOISE_HASH_LEN];
570         u8 hash[NOISE_HASH_LEN];
571         u8 s[NOISE_PUBLIC_KEY_LEN];
572         u8 e[NOISE_PUBLIC_KEY_LEN];
573         u8 t[NOISE_TIMESTAMP_LEN];
574         u64 initiation_consumption;
575
576         down_read(&wg->static_identity.lock);
577         if (unlikely(!wg->static_identity.has_identity))
578                 goto out;
579
580         handshake_init(chaining_key, hash, wg->static_identity.static_public);
581
582         /* e */
583         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
584
585         /* es */
586         if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e))
587                 goto out;
588
589         /* s */
590         if (!message_decrypt(s, src->encrypted_static,
591                              sizeof(src->encrypted_static), key, hash))
592                 goto out;
593
594         /* Lookup which peer we're actually talking to */
595         peer = wg_pubkey_hashtable_lookup(wg->peer_hashtable, s);
596         if (!peer)
597                 goto out;
598         handshake = &peer->handshake;
599
600         /* ss */
601         if (!mix_precomputed_dh(chaining_key, key,
602                                 handshake->precomputed_static_static))
603             goto out;
604
605         /* {t} */
606         if (!message_decrypt(t, src->encrypted_timestamp,
607                              sizeof(src->encrypted_timestamp), key, hash))
608                 goto out;
609
610         down_read(&handshake->lock);
611         replay_attack = memcmp(t, handshake->latest_timestamp,
612                                NOISE_TIMESTAMP_LEN) <= 0;
613         flood_attack = (s64)handshake->last_initiation_consumption +
614                                NSEC_PER_SEC / INITIATIONS_PER_SECOND >
615                        (s64)ktime_get_coarse_boottime_ns();
616         up_read(&handshake->lock);
617         if (replay_attack || flood_attack)
618                 goto out;
619
620         /* Success! Copy everything to peer */
621         down_write(&handshake->lock);
622         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
623         if (memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) > 0)
624                 memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
625         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
626         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
627         handshake->remote_index = src->sender_index;
628         if ((s64)(handshake->last_initiation_consumption -
629             (initiation_consumption = ktime_get_coarse_boottime_ns())) < 0)
630                 handshake->last_initiation_consumption = initiation_consumption;
631         handshake->state = HANDSHAKE_CONSUMED_INITIATION;
632         up_write(&handshake->lock);
633         ret_peer = peer;
634
635 out:
636         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
637         memzero_explicit(hash, NOISE_HASH_LEN);
638         memzero_explicit(chaining_key, NOISE_HASH_LEN);
639         up_read(&wg->static_identity.lock);
640         if (!ret_peer)
641                 wg_peer_put(peer);
642         return ret_peer;
643 }
644
/*
 * Build a handshake response message ("<- e, ee, se, psk, {}") into @dst.
 * This is only valid after an initiation has been consumed, i.e. when
 * @handshake is in HANDSHAKE_CONSUMED_INITIATION.
 *
 * On success, @handshake is inserted into the device's index hashtable (its
 * index is sent as dst->sender_index), moves to HANDSHAKE_CREATED_RESPONSE,
 * and true is returned. Returns false if the state is wrong, or if ephemeral
 * key generation or a DH step fails. In the latter case handshake->state is
 * unchanged, but the hash/chaining key may already have been partly updated.
 *
 * Takes static_identity->lock for reading and handshake->lock for writing.
 */
bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
					struct noise_handshake *handshake)
{
	u8 key[NOISE_SYMMETRIC_KEY_LEN];
	bool ret = false;

	/* We need to wait for crng _before_ taking any locks, since
	 * curve25519_generate_secret uses get_random_bytes_wait.
	 */
	wait_for_random_bytes();

	down_read(&handshake->static_identity->lock);
	down_write(&handshake->lock);

	if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
		goto out;

	dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
	/* Echo the initiator's sender_index so it can find its handshake. */
	dst->receiver_index = handshake->remote_index;

	/* e */
	curve25519_generate_secret(handshake->ephemeral_private);
	if (!curve25519_generate_public(dst->unencrypted_ephemeral,
					handshake->ephemeral_private))
		goto out;
	message_ephemeral(dst->unencrypted_ephemeral,
			  dst->unencrypted_ephemeral, handshake->chaining_key,
			  handshake->hash);

	/* ee */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_ephemeral))
		goto out;

	/* se */
	if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
		    handshake->remote_static))
		goto out;

	/* psk */
	mix_psk(handshake->chaining_key, handshake->hash, key,
		handshake->preshared_key);

	/* {} */
	message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);

	dst->sender_index = wg_index_hashtable_insert(
		handshake->entry.peer->device->index_hashtable,
		&handshake->entry);

	handshake->state = HANDSHAKE_CREATED_RESPONSE;
	ret = true;

out:
	up_write(&handshake->lock);
	up_read(&handshake->static_identity->lock);
	/* The message key lives only on this stack frame; scrub it. */
	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
	return ret;
}
704
705 struct wg_peer *
706 wg_noise_handshake_consume_response(struct message_handshake_response *src,
707                                     struct wg_device *wg)
708 {
709         enum noise_handshake_state state = HANDSHAKE_ZEROED;
710         struct wg_peer *peer = NULL, *ret_peer = NULL;
711         struct noise_handshake *handshake;
712         u8 key[NOISE_SYMMETRIC_KEY_LEN];
713         u8 hash[NOISE_HASH_LEN];
714         u8 chaining_key[NOISE_HASH_LEN];
715         u8 e[NOISE_PUBLIC_KEY_LEN];
716         u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
717         u8 static_private[NOISE_PUBLIC_KEY_LEN];
718
719         down_read(&wg->static_identity.lock);
720
721         if (unlikely(!wg->static_identity.has_identity))
722                 goto out;
723
724         handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
725                 wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
726                 src->receiver_index, &peer);
727         if (unlikely(!handshake))
728                 goto out;
729
730         down_read(&handshake->lock);
731         state = handshake->state;
732         memcpy(hash, handshake->hash, NOISE_HASH_LEN);
733         memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
734         memcpy(ephemeral_private, handshake->ephemeral_private,
735                NOISE_PUBLIC_KEY_LEN);
736         up_read(&handshake->lock);
737
738         if (state != HANDSHAKE_CREATED_INITIATION)
739                 goto fail;
740
741         /* e */
742         message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash);
743
744         /* ee */
745         if (!mix_dh(chaining_key, NULL, ephemeral_private, e))
746                 goto fail;
747
748         /* se */
749         if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e))
750                 goto fail;
751
752         /* psk */
753         mix_psk(chaining_key, hash, key, handshake->preshared_key);
754
755         /* {} */
756         if (!message_decrypt(NULL, src->encrypted_nothing,
757                              sizeof(src->encrypted_nothing), key, hash))
758                 goto fail;
759
760         /* Success! Copy everything to peer */
761         down_write(&handshake->lock);
762         /* It's important to check that the state is still the same, while we
763          * have an exclusive lock.
764          */
765         if (handshake->state != state) {
766                 up_write(&handshake->lock);
767                 goto fail;
768         }
769         memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
770         memcpy(handshake->hash, hash, NOISE_HASH_LEN);
771         memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
772         handshake->remote_index = src->sender_index;
773         handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
774         up_write(&handshake->lock);
775         ret_peer = peer;
776         goto out;
777
778 fail:
779         wg_peer_put(peer);
780 out:
781         memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
782         memzero_explicit(hash, NOISE_HASH_LEN);
783         memzero_explicit(chaining_key, NOISE_HASH_LEN);
784         memzero_explicit(ephemeral_private, NOISE_PUBLIC_KEY_LEN);
785         memzero_explicit(static_private, NOISE_PUBLIC_KEY_LEN);
786         up_read(&wg->static_identity.lock);
787         return ret_peer;
788 }
789
790 bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
791                                       struct noise_keypairs *keypairs)
792 {
793         struct noise_keypair *new_keypair;
794         bool ret = false;
795
796         down_write(&handshake->lock);
797         if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
798             handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
799                 goto out;
800
801         new_keypair = keypair_create(handshake->entry.peer);
802         if (!new_keypair)
803                 goto out;
804         new_keypair->i_am_the_initiator = handshake->state ==
805                                           HANDSHAKE_CONSUMED_RESPONSE;
806         new_keypair->remote_index = handshake->remote_index;
807
808         if (new_keypair->i_am_the_initiator)
809                 derive_keys(&new_keypair->sending, &new_keypair->receiving,
810                             handshake->chaining_key);
811         else
812                 derive_keys(&new_keypair->receiving, &new_keypair->sending,
813                             handshake->chaining_key);
814
815         handshake_zero(handshake);
816         rcu_read_lock_bh();
817         if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
818                                            handshake)->is_dead))) {
819                 add_new_keypair(keypairs, new_keypair);
820                 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
821                                     handshake->entry.peer->device->dev->name,
822                                     new_keypair->internal_id,
823                                     handshake->entry.peer->internal_id);
824                 ret = wg_index_hashtable_replace(
825                         handshake->entry.peer->device->index_hashtable,
826                         &handshake->entry, &new_keypair->entry);
827         } else {
828                 kzfree(new_keypair);
829         }
830         rcu_read_unlock_bh();
831
832 out:
833         up_write(&handshake->lock);
834         return ret;
835 }